序列到序列学习(sequence to sequence learning, Seq2seq)是将一个输入的单词序列转换为另一个输出的单词序列的任务,相当于有条件的语言生成。自然语言处理、语音处理等领域中的机器翻译、摘要生成、对话生成、语音识别等任务都属于这类问题。
序列到序列模型是执行这种任务的神经网络,由编码器网络和解码器网络组成。编码器将输入的单词序列转换成中间表示的序列(编码),解码器将中间表示的序列转换成输出的单词序列(解码)。有代表性的模型有基本模型、RNN Search、Transformer 模型等。基本模型使用 RNN 实现编码和解码,只讲编码器最终位置的中间表示传递到解码器。RNN Search 模型也以 RNN 为编码器和解码器,使用注意力机制将编码器的各个位置的中间表示有选择地传递到解码器。Transformer 模型完全基于注意力机制,使用注意力实现编码、解码以及编码器和解码器之间的信息传递。注意力实现相似或相关向量的检索计算,可以有效地表示概念的组合,是深度学习的核心技术。
假设输入的单词序列是 x1,x2,⋯,xmx_1,x_2,\cdots,x_mx1,x2,⋯,xm,输出的单词序列是 y1,y2,⋯,yny_1,y_2,\cdots,y_ny1,y2,⋯,yn,单词都来自词表。给定输入单词序列条件下输出单词序列的条件概率是P(y1,y2,⋯,yn∣x1,x2,⋯,xm)=∏i=1nP(yi∣y1,y2,⋯,yi−1,x1,x2,⋯,xm)P(y_1,y_2,\cdots,y_n|x_1,x_2,\cdots,x_m)=\prod_{i=1}^n P(y_i|y_1,y_2,\cdots,y_{i-1},x_1,x_2,\cdots,x_m) P(y1,y2,⋯,yn∣x1,x2,⋯,xm)=i=1∏nP(yi∣y1,y2,⋯,yi−1,x1,x2,⋯,xm) 即预测给定输入序列及已生成输出序列的条件下,下一个位置上单词出现的条件概率。
序列到序列模型由编码器和解码器网络组成。编码器将输入单词序列 x1,x2,⋯,xmx_1,x_2,\cdots,x_mx1,x2,⋯,xm 转换成中间表示序列 z1,z2,⋯,zmz_1,z_2,\cdots,z_mz1,z2,⋯,zm,每一个中间表示是一个实数向量。解码器根据中间表示序列 z1,z2,⋯,zmz_1,z_2,\cdots,z_mz1,z2,⋯,zm 依次生成输出单词序列 y1,y2,⋯,yny_1,y_2,\cdots,y_ny1,y2,⋯,yn。编码器网络可以写作
z1,z2,⋯,zm=F(x1,x2,⋯,xm)z_1,z_2,\cdots,z_m=F(x_1,x_2,\cdots,x_m) z1,z2,⋯,zm=F(x1,x2,⋯,xm) 编码器定义在整体输入单词序列上。解码器网络可以写作
P(yi∣y1,y2,⋯,yi−1,x1,x2,⋯,xm)=G(y1,y2,⋯,yi−1,z1,z2,⋯,zm)P(y_i|y_1,y_2,\cdots,y_{i-1},x_1,x_2,\cdots,x_m)=G(y_1,y_2,\cdots,y_{i-1},z_1,z_2,\cdots,z_m) P(yi∣y1,y2,⋯,yi−1,x1,x2,⋯,xm)=G(y1,y2,⋯,yi−1,z1,z2,⋯,zm) 解码器定义在单词输出序列的每一个位置上。
序列到序列学习有几个特点:编码器和解码器联合训练、反向传播、强制教学。学习时,训练数据的每一个样本由一个输入单词序列和一个输出单词序列组成。利用大量样本通过端到端学习的方式进行模型的参数估计。学习是强制教学的,可以在所有位置上并行处理。
序列到序列学习的预测通常使用束搜索(beam search)。目标是计算给定输入单词序列条件下概率最大的输出单词序列,束搜索用递归的方法近似计算条件概率最大的 kkk 个输出单词序列,其中 kkk 是束宽。
基本模型使用 RNN 实现编码和解码,实际上是一个有条件的 RNN 语言模型。RNN 通常是 LSTM 和 GRU。基本模型编码器是 RNN,如 LSTM,状态是
hj=a(xj,hj−1),j=1,2,⋯,m\bm{h}_j = a(\bm{x}_j,\bm{h}_{j-1}),\quad j=1,2,\cdots,m hj=a(xj,hj−1),j=1,2,⋯,m xj\bm{x}_jxj 是当前位置的输入单词的词向量,aaa 是处理单元,如 LSTM 单元;假设 h0=0\bm{h}_0=0h0=0.
解码器也是 RNN,如 LSTM,状态是
si=a(yi−1,si−1),i=1,2,⋯,n\bm{s}_i=a(\bm{y}_{i-1}, \bm{s}_{i-1}),\quad i=1,2,\cdots,n si=a(yi−1,si−1),i=1,2,⋯,n yi−1\bm{y}_{i-1}yi−1 是前一个位置的输出单词的词向量。输出是
pi=g(si),i=1,2,⋯,n\bm{p}_i = g(\bm{s}_i),\quad i=1,2,\cdots,n pi=g(si),i=1,2,⋯,n ggg 是输出层函数,由线性变换和软最大化函数组成。pi\bm{p}_ipi 表示的是下一个位置单词出现的条件概率。
编码器将其最终状态 hm\bm{h}_mhm 作为整个输入序列的表示传递给解码器。解码器将 hm\bm{h}_mhm 作为解码器的初始状态 s0\bm{s}_0s0,决定其状态序列,以及输出单词序列。
下图是一个用基本模型进行机器翻译的例子。
基本模型用一个中间表示描述整个输入序列,其表示能力有限。RNN Search 模型利用注意力机制在输出序列的每一个位置上产生一个组合的输入序列的中间表示,以解决这个问题。
假设有键-值数据库,存储键-值对数据 {(k1,v1),(k2,v2),⋯,(kn,vn)}\{(\bm{k}_1,\bm{v}_1),(\bm{k}_2,\bm{v}_2),\cdots,(\bm{k}_n,\bm{v}_n)\}{(k1,v1),(k2,v2),⋯,(kn,vn)},其中每一个键-值对的键和值都是实数向量。另有查询(query)q\bm{q}q 也是实数向量。向量 q\bm{q}q 和 ki\bm{k}_iki 的维度相同,向量 ki\bm{k}_iki 和 vi\bm{v}_ivi 的维度一般也相同。考虑从键-值数据库中搜索与查询 q\bm{q}q 相似的键所对应的值。注意力是实现检索的一种计算方法。计算查询 q\bm{q}q 和各个键 ki\bm{k}_iki 的归一化相似度 α(q,ki)\alpha(\bm{q},\bm{k}_i)α(q,ki),以归一化相似度为权重,计算各个值 vi\bm{v}_ivi 的加权平均 v\bm{v}v,将计算结果 v\bm{v}v 作为检索结果返回。
v=∑i=1nα(q,ki)⋅vi\bm{v} = \sum_{i=1}^n \alpha(\bm{q},\bm{k}_i)\cdot \bm{v}_i v=i=1∑nα(q,ki)⋅vi 满足 ∑i=1nα(q,ki)=1\sum_{i=1}^n \alpha(\bm{q},\bm{k}_i)=1 i=1∑nα(q,ki)=1 归一化的权重称作注意力权重,一般通过软最大化计算 (softmax):α(q,ki)=e(q,ki)∑j=1ne(q,kj)\alpha(\bm{q},\bm{k}_i) = \frac{e(\bm{q},\bm{k}_i)}{\sum_{j=1}^n e(\bm{q},\bm{k}_j)} α(q,ki)=∑j=1ne(q,kj)e(q,ki) 其中 e(q,ki)e(\bm{q},\bm{k}_i)e(q,ki) 是查询 q\bm{q}q 和 ki\bm{k}_iki 的相似度。相似度计算可以有多种方法,包括加法注意力和乘法注意力。乘法注意力要求查询和键向量的维度相同,而加法注意力没有这个要求。乘法注意力比加法注意力计算效率更高。
加法注意力使用一层神经网络计算相似度:
e(q,ki)=σ(w⊤⋅[q;ki]+b)e(\bm{q},\bm{k}_i) = \sigma(\bm{w}^\top \cdot [\bm{q};\bm{k}_i]+b) e(q,ki)=σ(w⊤⋅[q;ki]+b) 其中 [;][;][;] 表示向量的拼接。
乘法注意力使用内积或尺度变换的内积计算相似度:
e(q,ki)=q⊤⋅kie(\bm{q},\bm{k}_i) = \bm{q}^\top \cdot \bm{k}_i e(q,ki)=q⊤⋅ki e(q,ki)=q⊤⋅kide(\bm{q},\bm{k}_i) = \frac{\bm{q}^\top \cdot \bm{k}_i}{\sqrt{d}} e(q,ki)=dq⊤⋅ki 其中 ddd 是向量 q\bm{q}q 和 ki\bm{k}_iki 的维度。尺度变换保证相似度的取值在一定范围内,避免学习时发生梯度消失。
我们在 PyTorch 中分别实现这两种注意力。
加法注意力:
import torch
from torch import nnclass AdditiveAttention(nn.Module):def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):super(AdditiveAttention, self).__init__(**kwargs)self.W_k = nn.Linear(key_size, num_hiddens, bias=False)self.W_q = nn.Linear(query_size, num_hiddens, bias=False)self.w_v = nn.Linear(num_hiddens, 1, bias=False)self.dropout = nn.Dropout(dropout)def forward(self, queries, keys, values):queries, keys = self.W_q(queries), self.W_k(keys)features = queries.unsqueeze(2) + keys.unsqueeze(1)features = torch.tanh(features)scores = self.w_v(features).squeeze(-1)self.attention_weights = nn.functional.softmax(scores)return torch.bmm(self.dropout(self.attention_weights), values)
乘法注意力:
class DotProductAttention(nn.Module):def __init__(self, dropout, **kwargs):super(DotProductAttention, self).__init__(**kwargs)self.dropout = nn.Dropout(dropout)def forward(self, queries, keys, values, valid_lens=None):d = queries.shape[-1]scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)self.attention_weights = nn.functional.softmax(scores)return torch.bmm(self.dropout(self.attention_weights), values)
RNN Search 模型对基本模型进行两个大的改动。用双向 LSTM 实现编码器,用注意力实现从编码器到解码器的信息传递。
编码时,正向 LSTM 的状态是
hj(1)=a(xj,hj−1(1)),j=1,2,⋯,m\bm{h}_j^{(1)} = a(\bm{x}_j, \bm{h}_{j-1}^{(1)}),\quad j=1,2,\cdots, m hj(1)=a(xj,hj−1(1)),j=1,2,⋯,m 反向 LSTM 的状态是
hj(2)=a(xj,hj+1(2)),j=m,m−1,⋯,1\bm{h}_j^{(2)} = a(\bm{x}_j, \bm{h}_{j+1}^{(2)}),\quad j=m,m-1,\cdots, 1 hj(2)=a(xj,hj+1(2)),j=m,m−1,⋯,1 在各个位置对正向和反向状态进行拼接,得到各个位置的状态,也就是中间表示:
hj=[hj(1);hj(2)],j=1,2,⋯,m\bm{h}_j=[\bm{h}_j^{(1)};\bm{h}_j^{(2)}],\quad j=1,2,\cdots, m hj=[hj(1);hj(2)],j=1,2,⋯,m
解码器使用单向 LSTM,解码基于已生成的输出序列,状态是
si=a(yi−1,si−1,ci),i=1,2,⋯,n\bm{s}_i = a(\bm{y}_{i-1}, \bm{s}_{i-1}, c_i),\quad i=1,2,\cdots, n si=a(yi−1,si−1,ci),i=1,2,⋯,n 这里 cic_ici 是当前位置的上下文向量(context vector),上下文向量表示在当前位置的注意力计算结果。在解码器的每一个位置,通过加法注意力计算上下文向量。注意力的查询是前一个位置的状态 si−1\bm{s}_{i-1}si−1,键和值相同,是编码器各个位置的状态 hj\bm{h}_jhj。上下文向量是
ci=∑j=1mαijhj,i=1,2,⋯,n\bm{c}_i = \sum_{j=1}^m \alpha_{ij} \bm{h}_j, \quad i=1,2,\cdots, n ci=j=1∑mαijhj,i=1,2,⋯,n 其中,αij\alpha_{ij}αij 是注意力权重
αij=exp(eij)∑k=1mexp(eik)\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^m\exp(e_{ik})} αij=∑k=1mexp(eik)exp(eij) 相似度 eije_{ij}eij 通过一层神经网络计算
eij=σ(w⊤⋅[si−1;hj]+b)e_{ij}=\sigma(\bm{w}^\top \cdot [\bm{s}_{i-1};\bm{h}_j]+b) eij=σ(w⊤⋅[si−1;hj]+b)
传递的上下文向量实际是从输出序列的当前位置看到的输入序列的相关内容。
RNN Search 的最大特点是在输出序列的每一个位置,通过注意力搜索到输入单词序列中的相关内容,和已生成的输出单词序列一起决定下一个位置的单词生成。在机器翻译中,在目标语言中每生成一个单词,都会在源语言中搜索相关的单词,基于搜索得到的单词和目前为止生成的单词做出下一个单词选择的判断。
在每一个位置使用一个动态的中间表示,而不是始终使用一个静态的中间表示。输入序列与输出序列的相关性由单词的内容决定,而不是由单词的位置决定。注意力的参数个数是固定的,可以处理任意长度的输入单词序列。
[1] 《机器学习方法》,李航,清华大学出版社。
[2] 《动手学深度学习》,Aston Zhang, Zachary C. Lipton, Mu Li, and Alexander J. Smola.