本人求职中!大模型 Post-Training 方向 · 长三角、南方地区优先 · 如果您有兴趣,欢迎联系 zhimi64@foxmail.com 👀 了解更多

【论文精读】Linear Transformer

本文瀏覽次數

前言

Transformer 模型相较于传统的 RNN 模型更适合并行训练,因此得以在 NLP 领域得到广泛应用。但是它也存在一些问题:

  1. 时间复杂度和空间复杂度\(O(n^2)\) 空间和时间复杂度是导致 Transformer 模型消耗显存和算力的痛点。
  2. 缺少对上下文的压缩:有一个观点是”压缩即智能”。但是我们在 Transformer 上没有看到任何”记忆压缩”的影子。它只是将所有 token 的特征排列在那里,可供后续随时取用。有没有一种办法,能够实现上下文的压缩,用有限的空间管理我们感兴趣的知识呢?

出于对以上问题的兴趣,我著手调研(或者也可以说是考古吧😀)和实验一些线性 Transformer 的工作。

本文是我阅读 Transformers are RNNs [1] 这篇文章的笔记。这篇文章有一些有意思的贡献:

  1. 从 softmax attention 这个特例推广到一般的衡量 query 和 key 相似度的方法;
  2. 提出线性 Transformer,它既能像传统 Transformer 一样并行训练,也能像 RNN 模型那样,在推理阶段维持常数的空间复杂度和时间复杂度。

Transformer 的结构

Transformer 模型 \(T\) 可以看做是一个类型为 \(\mathbb R^{N\times F}\to \mathbb R^{N\times F}\) 的函数,这里 \(N\) 是序列长度,\(F\) 是特征向量的维度。

一个 Transformer 有 \(L\) 层,记第 \(l\) 层为 \(T_l\),其结构为

\[ T_l(x) = f_l(A_l(x) + x), \]

其中 \(A_l\) 是 Attention 模块,\(f_l\) 是 Feed-forward 模块。

在 Attention 模块 \(A_l\) 中,\(x\) 先经过线性转换 \(\mathbf W_Q\in \mathbb R^{F\times D},\mathbf W_K\in \mathbb R^{F\times D},\mathbf W_V\in \mathbb R^{F\times M}\),得到 \(\mathbf Q, \mathbf K, \mathbf V\)(即 query、key 和 value 向量):

\[ \begin{aligned} \mathbf Q &= \mathbf x\mathbf W_Q,\\ \mathbf K &= \mathbf x \mathbf W_K,\\ \mathbf V &= \mathbf x\mathbf W_V. \end{aligned} \]

接著,Attention 模块的输出为

\[ \mathbf A_l(x) = \mathbf V' = \text{softmax}\left(\frac{\mathbf Q\mathbf K^T}{\sqrt{D}}\right)V. \tag{1}\]

可以看到 \(V'\) 实际上是 \(V\) 的加权和,其中每个位置特征的权重由 \(Q\)\(K\) 计算得到。

Equation 1 使用 softmax 函数计算权重,这实际是下面的这个函数的一个特例:

\[ \mathbf V'_i = \frac{\sum_{j=1}^N \text{sim}(\mathbf Q_i, \mathbf K_j)V_j}{\sum_{j=1}^N \text{sim}(\mathbf Q_i, \mathbf K_j)}. \tag{2}\]

softmax 函数等同于将上式中的 \(\text{sim}\) 函数定义为 \(\exp(\frac{\mathbf Q_i^T\mathbf K_j}{\sqrt D})\)

线性注意力

假设对于某 \(\text{sim}\) 函数,存在 \(\phi\) 函数,使得

\[ \text{sim}(\mathbf Q_i, \mathbf K_j) = \phi(\mathbf Q_i)^T \phi(\mathbf K_j) \]

这样我们就将 Attention 模块线性化了。利用矩阵乘法的结合律,上式可以改写为:

\[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T \sum_{j=1}^N \phi(\mathbf K_j) \mathbf V_j^T }{ \phi(\mathbf Q_i)^T \sum_{j=1}^N \phi(\mathbf K_j) } \tag{3}\]

由于 query 和 key 两两之间都需要计算相似度,softmax-attention [2] 的时间复杂度是 \(O(N^2)\);其空间复杂度同理也是 \(O(N^2)\),因为需要存储注意力矩阵用于后续的反向传播。而线性注意力的时间复杂度是 \(O(N)\),空间复杂度也是 \(O(N)\)。由 Equation 3 可见,\(\sum_{j=1}^N \phi(\mathbf K_j) \mathbf V_j^T\)\(\sum_{j=1}^N \phi(\mathbf K_j)\) 均可以在计算时可以被缓存和复用。

为了使 \(\text{sim}(\cdot)\) 是非负的,文章取 \(\phi(x)\)

\[ \phi(x) = \text{elu}(x) + 1 \]

因果掩膜

在训练自回归模型时,每个位置的 token 都不能受到未来 token 的影响。位置 \(i\) 上的 token 受到位置 \(j\) 上的 token 的影响,当且仅当 \(j\leq i\)。这应用于线性 Transformer,需要将 Equation 2 改为

\[ \mathbf V'_i=\frac{\sum_{j=1}^i \text{sim}(\mathbf Q_i, \mathbf K_j)\mathbf V_j}{\sum_{j=1}^i \text{sim}(\mathbf Q_i,\mathbf K_j)}. \tag{4}\]

类似前面的推理,上式可改写为

\[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T \sum_{j=1}^i \phi(\mathbf K_j) \mathbf V_j^T }{ \phi(\mathbf Q_i)^T \sum_{j=1}^i \phi(\mathbf K_j) }, \tag{5}\]

引入变量 \(S_i\)\(Z_i\) 如下:

\[ \begin{aligned} \mathbf S_i &= \sum_{j=1}^i \phi(\mathbf K_j) \mathbf V_j^T,\\ \mathbf Z_i &= \sum_{j=1}^i \phi(\mathbf K_j). \end{aligned} \]

Equation 5 可以改写为

\[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T \mathbf S_i }{ \phi(\mathbf Q_i)^T \mathbf Z_i }. \tag{6}\]

注意 \(\mathbf S_i, \mathbf Z_i\) 可以由 \(\mathbf S_{i-1}, \mathbf Z_{i-1}\) 计算得到,时间复杂度为常数。

梯度计算

如果根据 Equation 6 实现简单的深度学习模型,那么为了梯度计算,我们需要缓存所有的 \(S_i, Z_i\),这将带来很大的空间复杂度。文章提出将对 Equation 5 分子的求导实现为一种基于累加和(cumulative sum)的方法,实现前向传播和反向传播时的线性时间复杂度和常数空间复杂度。

为了简便,本节的推导假设 \(\mathbf Q, \mathbf K\) 中已经包含了 \(\phi\) 函数。

假设 \(\overline{\mathbf V}_i\)Equation 5 的分子,即

\[ \overline{\mathbf V}_i = \mathbf Q_i^T \sum_{j=1}^i \mathbf K_j \mathbf V_j^T. \tag{7}\]

\(L\) 为损失,已知 \(\overline{\mathbf V}_i\)\(L\),则 \(L\)\(\mathbf Q, \mathbf K, \mathbf V\) 的梯度分别为

\[ \nabla_{\mathbf Q_i}L = (\nabla_{\overline{\mathbf V}_i} L)(\sum_{j=1}^i \mathbf K_j \mathbf V_j^T)^T, \tag{8}\]

\[ \nabla_{\mathbf K_i}L = \left(\sum_{j=i}^N\mathbf Q_j (\nabla_{\overline{\mathbf V}_j}L)^T\right)\mathbf V_i, \tag{9}\]

\[ \nabla_{\mathbf V_i}L = \left(\sum_{j=i}^N \mathbf Q_j (\nabla_{\overline{V}_j}L)^T\right)^T \mathbf K_i, \tag{10}\]

其中 \(\mathbf Q\in\mathbb R^{N\times D},\mathbf K\in\mathbb R^{N\times D}, \mathbf V\in \mathbb R^{N\times M}\)

注记

文章只考虑了分子的梯度计算。分母、整个分式的梯度计算交给 torch 自动处理。

以下是详细推导:

首先我们考虑矩阵中每个元素的计算,将 Equation 7 中的矩阵、向量记号去除,得到

\[ \overline{V}_{ie} = \sum_{d=1}^D Q_{id} \sum_{j=1}^i K_{jd} V_{je} = \sum_{d=1}^D \sum_{j=1}^i Q_{id} K_{jd} V_{je}. \]

于是对于任意的 \(Q_{lt}\),我们可以推导出梯度为

\[ \frac{\partial L}{\partial Q_{lt}} = \sum_{e=1}^M \frac{\partial L}{\partial \overline{V}_{le}}\frac{\partial \overline{V}_{le}}{\partial Q_{lt}} = \sum_{e=1}^M\frac{\partial L}{\partial \overline{V}_{le}}(\sum_{j=1}^l K_{jt} V_{je}). \tag{11}\]

将其整理成矩阵形式,得到 Equation 8

Equation 11 中,我们利用了 \(\overline{\mathbf V}_l\) 只受 \(l\) 位置的 query(即 \(\mathbf Q_l\))影响的性质。query 只影响当下,而每个 key 和 value 都会对未来的计算产生影响。对于 key,其梯度的计算方式为:

\[ \begin{aligned} \frac{\partial L}{\partial K_{lt}} &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial \overline{V}_{ie}}{\partial K_{lt}} \\ &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial(\sum_{d=1}^D\sum_{j=1}^i Q_{id}K_{jd}V_{je})}{\partial K_{lt}} \\ &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} Q_{it} V_{le} \end{aligned} \]

将其整理为矩阵形式,得到 Equation 9

类似的,对于 value,其梯度的计算方式为:

\[ \begin{aligned} \frac{\partial L}{\partial V_{lt}} &= \sum_{e=1}^M \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\overline{V}_{ie}}{\partial V_{lt}} \\ &= \sum_{e=1}^M \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial (\sum_{d=1}^D \sum_{j=1}^i Q_{id} K_{jd} V_{je})}{\partial V_{lt}}\\ &= \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{it}} \sum_{d=1}^D Q_{id}K_{ld} \end{aligned} \]

将其整理为矩阵形式,得到 Equation 10

训练和推理

训练时,完整的训练序列是已知的,这允许 Transformer 模型实现并行的训练;而受限于计算方式,传统的 RNN 模型一般难并行训练。Transformer 模型的每一步推理的时间复杂度是不同的,随著上下文长度的增加而增加;而 RNN 模型的时间复杂度是固定的。

文章提出的线性 Transformer 结合了两者的优点。

Transformer 模型是 RNN 模型的特例

从本文的讨论,我们可以明显的看出,带因果掩膜的 Transformer 模型可以视作是 RNN 模型的特例,即 Transformer 模型可以看做是一个能维护一个内部状态(\(\mathbf S_i\)\(\mathbf Z_i\)),在每次获得新输入时更新内部状态的模型。

推荐阅读

参考文献

[1]
A. Katharopoulos, A. Vyas, N. Pappas, and F. Fleuret, “Transformers are RNNs: Fast autoregressive transformers with linear attention,” in Proceedings of the 37th international conference on machine learning, 2020. Available: https://arxiv.org/abs/2006.16236
[2]
A. Vaswani et al., “Attention is all you need,” Advances in Neural Information Processing Systems, vol. 30, 2017, Available: https://arxiv.org/abs/1706.03762
By @執迷 in
Tags : #Transformer, #Linear Transformer, #線性注意力, #RNN, #深度學習,