前言
Transformer 模型相较于传统的 RNN 模型更适合并行训练,因此得以在 NLP 领域得到广泛应用。但是它也存在一些问题:
- 时间复杂度和空间复杂度:\(O(n^2)\) 空间和时间复杂度是导致 Transformer 模型消耗显存和算力的痛点。
- 缺少对上下文的压缩:有一个观点是”压缩即智能”。但是我们在 Transformer 上没有看到任何”记忆压缩”的影子。它只是将所有 token 的特征排列在那里,可供后续随时取用。有没有一种办法,能够实现上下文的压缩,用有限的空间管理我们感兴趣的知识呢?
出于对以上问题的兴趣,我著手调研(或者也可以说是考古吧😀)和实验一些线性 Transformer 的工作。
本文是我阅读 Transformers are RNNs [1] 这篇文章的笔记。这篇文章有一些有意思的贡献:
- 从 softmax attention 这个特例推广到一般的衡量 query 和 key 相似度的方法;
- 提出线性 Transformer,它既能像传统 Transformer 一样并行训练,也能像 RNN 模型那样,在推理阶段维持常数的空间复杂度和时间复杂度。
Transformer 的结构
Transformer 模型 \(T\) 可以看做是一个类型为 \(\mathbb R^{N\times F}\to \mathbb R^{N\times F}\) 的函数,这里 \(N\) 是序列长度,\(F\) 是特征向量的维度。
一个 Transformer 有 \(L\) 层,记第 \(l\) 层为 \(T_l\),其结构为
\[ T_l(x) = f_l(A_l(x) + x), \]
其中 \(A_l\) 是 Attention 模块,\(f_l\) 是 Feed-forward 模块。
在 Attention 模块 \(A_l\) 中,\(x\) 先经过线性转换 \(\mathbf W_Q\in \mathbb R^{F\times D},\mathbf W_K\in \mathbb R^{F\times D},\mathbf W_V\in \mathbb R^{F\times M}\),得到 \(\mathbf Q, \mathbf K, \mathbf V\)(即 query、key 和 value 向量):
\[ \begin{aligned} \mathbf Q &= \mathbf x\mathbf W_Q,\\ \mathbf K &= \mathbf x \mathbf W_K,\\ \mathbf V &= \mathbf x\mathbf W_V. \end{aligned} \]
接著,Attention 模块的输出为
\[ \mathbf A_l(x) = \mathbf V' = \text{softmax}\left(\frac{\mathbf Q\mathbf K^T}{\sqrt{D}}\right)V. \tag{1}\]
可以看到 \(V'\) 实际上是 \(V\) 的加权和,其中每个位置特征的权重由 \(Q\) 和 \(K\) 计算得到。
Equation 1 使用 softmax 函数计算权重,这实际是下面的这个函数的一个特例:
\[ \mathbf V'_i = \frac{\sum_{j=1}^N \text{sim}(\mathbf Q_i, \mathbf K_j)V_j}{\sum_{j=1}^N \text{sim}(\mathbf Q_i, \mathbf K_j)}. \tag{2}\]
softmax 函数等同于将上式中的 \(\text{sim}\) 函数定义为 \(\exp(\frac{\mathbf Q_i^T\mathbf K_j}{\sqrt D})\)。
线性注意力
假设对于某 \(\text{sim}\) 函数,存在 \(\phi\) 函数,使得
\[ \text{sim}(\mathbf Q_i, \mathbf K_j) = \phi(\mathbf Q_i)^T \phi(\mathbf K_j) \]
这样我们就将 Attention 模块线性化了。利用矩阵乘法的结合律,上式可以改写为:
\[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T \sum_{j=1}^N \phi(\mathbf K_j) \mathbf V_j^T }{ \phi(\mathbf Q_i)^T \sum_{j=1}^N \phi(\mathbf K_j) } \tag{3}\]
由于 query 和 key 两两之间都需要计算相似度,softmax-attention [2] 的时间复杂度是 \(O(N^2)\);其空间复杂度同理也是 \(O(N^2)\),因为需要存储注意力矩阵用于后续的反向传播。而线性注意力的时间复杂度是 \(O(N)\),空间复杂度也是 \(O(N)\)。由 Equation 3 可见,\(\sum_{j=1}^N \phi(\mathbf K_j) \mathbf V_j^T\) 和 \(\sum_{j=1}^N \phi(\mathbf K_j)\) 均可以在计算时可以被缓存和复用。
为了使 \(\text{sim}(\cdot)\) 是非负的,文章取 \(\phi(x)\) 为
\[ \phi(x) = \text{elu}(x) + 1 \]
因果掩膜
在训练自回归模型时,每个位置的 token 都不能受到未来 token 的影响。位置 \(i\) 上的 token 受到位置 \(j\) 上的 token 的影响,当且仅当 \(j\leq i\)。这应用于线性 Transformer,需要将 Equation 2 改为
\[ \mathbf V'_i=\frac{\sum_{j=1}^i \text{sim}(\mathbf Q_i, \mathbf K_j)\mathbf V_j}{\sum_{j=1}^i \text{sim}(\mathbf Q_i,\mathbf K_j)}. \tag{4}\]
类似前面的推理,上式可改写为
\[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T \sum_{j=1}^i \phi(\mathbf K_j) \mathbf V_j^T }{ \phi(\mathbf Q_i)^T \sum_{j=1}^i \phi(\mathbf K_j) }, \tag{5}\]
引入变量 \(S_i\) 和 \(Z_i\) 如下:
\[ \begin{aligned} \mathbf S_i &= \sum_{j=1}^i \phi(\mathbf K_j) \mathbf V_j^T,\\ \mathbf Z_i &= \sum_{j=1}^i \phi(\mathbf K_j). \end{aligned} \]
Equation 5 可以改写为
\[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T \mathbf S_i }{ \phi(\mathbf Q_i)^T \mathbf Z_i }. \tag{6}\]
注意 \(\mathbf S_i, \mathbf Z_i\) 可以由 \(\mathbf S_{i-1}, \mathbf Z_{i-1}\) 计算得到,时间复杂度为常数。
梯度计算
如果根据 Equation 6 实现简单的深度学习模型,那么为了梯度计算,我们需要缓存所有的 \(S_i, Z_i\),这将带来很大的空间复杂度。文章提出将对 Equation 5 分子的求导实现为一种基于累加和(cumulative sum)的方法,实现前向传播和反向传播时的线性时间复杂度和常数空间复杂度。
为了简便,本节的推导假设 \(\mathbf Q, \mathbf K\) 中已经包含了 \(\phi\) 函数。
假设 \(\overline{\mathbf V}_i\) 是 Equation 5 的分子,即
\[ \overline{\mathbf V}_i = \mathbf Q_i^T \sum_{j=1}^i \mathbf K_j \mathbf V_j^T. \tag{7}\]
设 \(L\) 为损失,已知 \(\overline{\mathbf V}_i\) 和 \(L\),则 \(L\) 对 \(\mathbf Q, \mathbf K, \mathbf V\) 的梯度分别为
\[ \nabla_{\mathbf Q_i}L = (\nabla_{\overline{\mathbf V}_i} L)(\sum_{j=1}^i \mathbf K_j \mathbf V_j^T)^T, \tag{8}\]
\[ \nabla_{\mathbf K_i}L = \left(\sum_{j=i}^N\mathbf Q_j (\nabla_{\overline{\mathbf V}_j}L)^T\right)\mathbf V_i, \tag{9}\]
\[ \nabla_{\mathbf V_i}L = \left(\sum_{j=i}^N \mathbf Q_j (\nabla_{\overline{V}_j}L)^T\right)^T \mathbf K_i, \tag{10}\]
其中 \(\mathbf Q\in\mathbb R^{N\times D},\mathbf K\in\mathbb R^{N\times D}, \mathbf V\in \mathbb R^{N\times M}\)。
文章只考虑了分子的梯度计算。分母、整个分式的梯度计算交给 torch 自动处理。
以下是详细推导:
首先我们考虑矩阵中每个元素的计算,将 Equation 7 中的矩阵、向量记号去除,得到
\[ \overline{V}_{ie} = \sum_{d=1}^D Q_{id} \sum_{j=1}^i K_{jd} V_{je} = \sum_{d=1}^D \sum_{j=1}^i Q_{id} K_{jd} V_{je}. \]
于是对于任意的 \(Q_{lt}\),我们可以推导出梯度为
\[ \frac{\partial L}{\partial Q_{lt}} = \sum_{e=1}^M \frac{\partial L}{\partial \overline{V}_{le}}\frac{\partial \overline{V}_{le}}{\partial Q_{lt}} = \sum_{e=1}^M\frac{\partial L}{\partial \overline{V}_{le}}(\sum_{j=1}^l K_{jt} V_{je}). \tag{11}\]
将其整理成矩阵形式,得到 Equation 8。
在 Equation 11 中,我们利用了 \(\overline{\mathbf V}_l\) 只受 \(l\) 位置的 query(即 \(\mathbf Q_l\))影响的性质。query 只影响当下,而每个 key 和 value 都会对未来的计算产生影响。对于 key,其梯度的计算方式为:
\[ \begin{aligned} \frac{\partial L}{\partial K_{lt}} &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial \overline{V}_{ie}}{\partial K_{lt}} \\ &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial(\sum_{d=1}^D\sum_{j=1}^i Q_{id}K_{jd}V_{je})}{\partial K_{lt}} \\ &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} Q_{it} V_{le} \end{aligned} \]
将其整理为矩阵形式,得到 Equation 9。
类似的,对于 value,其梯度的计算方式为:
\[ \begin{aligned} \frac{\partial L}{\partial V_{lt}} &= \sum_{e=1}^M \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\overline{V}_{ie}}{\partial V_{lt}} \\ &= \sum_{e=1}^M \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial (\sum_{d=1}^D \sum_{j=1}^i Q_{id} K_{jd} V_{je})}{\partial V_{lt}}\\ &= \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{it}} \sum_{d=1}^D Q_{id}K_{ld} \end{aligned} \]
将其整理为矩阵形式,得到 Equation 10。
训练和推理
训练时,完整的训练序列是已知的,这允许 Transformer 模型实现并行的训练;而受限于计算方式,传统的 RNN 模型一般难并行训练。Transformer 模型的每一步推理的时间复杂度是不同的,随著上下文长度的增加而增加;而 RNN 模型的时间复杂度是固定的。
文章提出的线性 Transformer 结合了两者的优点。
Transformer 模型是 RNN 模型的特例
从本文的讨论,我们可以明显的看出,带因果掩膜的 Transformer 模型可以视作是 RNN 模型的特例,即 Transformer 模型可以看做是一个能维护一个内部状态(\(\mathbf S_i\) 和 \(\mathbf Z_i\)),在每次获得新输入时更新内部状态的模型。
推荐阅读
- 《线性Attention的探索:Attention必须有个Softmax吗?》。苏剑林的文章,非常好的总结分享了各种各样的线性 Attention,甚至提出了自己的新设计。