1 前言
Transformer模型相較於傳統的RNN模型更適合并行訓練,因此得以在NLP領域得到廣汎應用。但是\(O(n^2)\)空間和時間複雜度是導致Transformer模型消耗顯存和算力的痛點。這是第一個問題。
有一個觀點是“壓縮即智能”。但是我們在Transformer上沒有看到任何“記憶壓縮”的影子。它只是將所有token的特徵排列在那裏,可供後續隨時取用。這是第二個問題,也是我最近常常思考的問題。
從上面兩點出發,我最近著手調研(或者也可以説是考古吧😀)和實驗一些綫性Transformer的工作。
本文是我閱讀Transformers are RNNs[1]這篇文章的筆記。這篇文章有一些有意思的貢獻:
- 從softmax attention這個特例推廣到一般的衡量query和key相似度的方法;
- 提出線性Transformer,它既能像傳統Transformer一樣並行訓練,也能像RNN模型那樣,在推理階段維持常數的空間複雜度和時間複雜度。
2 Transformer的结构
Transformer模型\(T\)可以看做是一個類型為\(\mathbb R^{N\times F}\to \mathbb R^{N\times F}\)的函數,這裡\(N\)是序列長度,\(F\)是特征向量的維度。
一個Transformer有L層,記第\(l\)層為\(T_l\),其結構為 \[ T_l(x) = f_l(A_l(x) + x), \] 其中\(A_l\)是Attention模塊,\(f_l\)是Feed-forward模塊。
在Attention模塊\(A_l\)中,\(x\)先經過線性轉換\(\mathbf W_Q\in \mathbb R^{F\times D},\mathbf W_K\in \mathbb R^{F\times D},\mathbf W_V\in \mathbb R^{F\times M}\),得到\(\mathbf Q, \mathbf K, \mathbf V\)(即query、key和value向量): \[ \begin{aligned} \mathbf Q &= \mathbf x\mathbf W_Q,\\ \mathbf K &= \mathbf x \mathbf W_K,\\ \mathbf V &= \mathbf x\mathbf W_V. \end{aligned} \] 接著,Attention模塊的輸出為 \[ \mathbf A_l(x) = \mathbf V' = \text{softmax}\left(\frac{\mathbf Q\mathbf K^T}{\sqrt{D}}\right)V. \tag{1}\] 可以看到\(V'\)實際上是\(V\)的加權和,其中每個位置特征的權重由\(Q\)和\(K\)計算得到。
公式 1使用softmax函數計算權重,這實際是下面的這個函數的一個特例: \[ \mathbf V'_i = \frac{\sum_{j=1}^N \text{sim}(\mathbf Q_i, \mathbf K_j)V_j}{\sum_{j=1}^N \text{sim}(\mathbf Q_i, \mathbf K_j)}. \tag{2}\] softmax函數等同於將上式中的\(\text{sim}\)函數定義為\(\exp(\frac{\mathbf Q_i^T\mathbf K_j}{\sqrt D})\)
3 線性注意力
假設對於某\(\text{sim}\)函數,存在\(\phi\)函數,使得 \[ \text{sim}(\mathbf Q_i, \mathbf K_j) = \phi(\mathbf Q_i)^T \phi(\mathbf K_j) \] 這樣我們就將Attention模塊線性化了。 利用矩陣乘法的結合律,上式可以改寫為: \[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T \sum_{j=1}^N \phi(\mathbf K_j) \mathbf V_j^T }{ \phi(\mathbf Q_i)^T \sum_{j=1}^N \phi(\mathbf K_j) } \tag{3}\] 由於query和key兩兩之間都需要計算相似度,softmax-attention[2]的時間複雜度是\(O(N^2)\);其空間複雜度同理也是\(O(N^2)\),因為需要存儲注意力矩陣用於後續的反向傳播。而線性注意力的時間複雜度是\(O(N)\),空間複雜度也是\(O(N)\)。由公式 3可見,\(\sum_{j=1}^N \phi(\mathbf K_j) \mathbf V_j^T\)和\(\sum_{j=1}^N \phi(\mathbf K_j)\)均可以在計算時可以被緩存和複用。
為了使\(\text{sim}(\cdot)\)是非負的,文章取\(\phi(x)\)為 \[ \phi(x) = \text{elu}(x) + 1 \]
4 因果掩膜
在訓練自回歸模型時,每個位置的token都不能受到未來token的影響。位置\(i\)上的token受到位置\(j\)上的token的影響,當且僅當\(j\leq i\). 這應用於線性Transformer,需要將公式 2改為 \[ \mathbf V'_i=\frac{\sum_{j=1}^i \text{sim}(\mathbf Q_i, \mathbf K_j)\mathbf V_j}{\sum_{j=1}^i \text{sim}(\mathbf Q_i,\mathbf K_j)}. \tag{4}\] 類似前面的推理,上式可改寫為 \[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T \sum_{j=1}^i \phi(\mathbf K_j) \mathbf V_j^T }{ \phi(\mathbf Q_i)^T \sum_{j=1}^i \phi(\mathbf K_j) }, \tag{5}\] 引入變量\(S_i\)和\(Z_i\)如下: \[ \begin{aligned} \mathbf S_i &= \sum_{j=1}^i \phi(\mathbf K_j) \mathbf V_j^T,\\ \mathbf Z_i &= \sum_{j=1}^i \phi(\mathbf K_j). \end{aligned} \] 公式 5可以改寫為 \[ \mathbf V'_i = \frac{ \phi(\mathbf Q_i)^T\mathbf S_i }{ \phi(\mathbf Q_i)^T \mathbf Z_i }. \tag{6}\] 注意\(\mathbf S_i,\mathbf Z_i\)可以由\(\mathbf S_{i-1},\mathbf Z_{i-1}\)計算得到,時間複雜度為常數。
5 梯度計算
如果根據公式 6實現簡單的深度學習模型,那麼為了梯度計算,我們需要緩存所有的\(S_i, Z_i\),這將帶來很大的空間複雜度。文章提出將對公式 5分子的求導實現為一種基於累加和(cumulative sum)的方法,實現前向傳播和反向傳播時的線性時間複雜度和常數空間複雜度。
为了简便,本节的推导假设\(\mathbf Q, \mathbf K\)中已经包含了\(\phi\)函数。
假设\(\overline{\mathbf V}_i\)是公式 5的分子,即 \[ \overline{\mathbf V}_i = \mathbf Q_i^T \sum_{j=1}^i \mathbf K_j \mathbf V_j^T. \tag{7}\] 设\(L\)为损失,已知\(\overline{\mathbf V}_i\)和\(L\),則\(L\)对\(\mathbf Q, \mathbf K, \mathbf V\)的梯度分别为 \[ \nabla_{\mathbf Q_i}L = (\nabla_{\overline{\mathbf V}_i} L)(\sum_{j=1}^i \mathbf K_j \mathbf V_j^T)^T, \tag{8}\] \[ \nabla_{\mathbf K_i}L = \left(\sum_{j=i}^N\mathbf Q_j (\nabla_{\overline{\mathbf V}_j}L)^T\right)\mathbf V_i, \tag{9}\] \[ \nabla_{\mathbf V_i}L = \left(\sum_{j=i}^N \mathbf Q_j (\nabla_{\overline{V}_j}L)^T\right)^T \mathbf K_i, \tag{10}\] 其中\(\mathbf Q\in\mathbb R^{N\times D},\mathbf K\in\mathbb R^{N\times D}, \mathbf V\in \mathbb R^{N\times M}\).
文章只考虑了分子的梯度计算。分母、整个分式的梯度计算交给torch
自动处理。
以下是详细推导:
首先我们考虑矩阵中每个元素的计算,将公式 7中的矩阵、向量记号去除,得到 \[ \overline{V}_{ie} = \sum_{d=1}^D Q_{id} \sum_{j=1}^i K_{jd} V_{je} = \sum_{d=1}^D \sum_{j=1}^i Q_{id} K_{jd} V_{je}。 \]
于是对于任意的\(Q_{lt}\),我们可以推导出梯度为 \[ \frac{\partial L}{\partial Q_{lt}} = \sum_{e=1}^M \frac{\partial L}{\partial \overline{V}_{le}}\frac{\partial \overline{V}_{le}}{\partial Q_{lt}} = \sum_{e=1}^M\frac{\partial L}{\partial \overline{V}{le}}(\sum_{j=1}^lK_{jt}V_{je}). \tag{11}\]
將其整理成矩陣形式,得到公式 8
在公式 11中,我們利用了\(\overline{\mathbf V}_l\)只受\(l\)位置的query(即\(\mathbf Q_l\))影響的性質。query只影響當下,而每個key和value都會對未來的計算產生影響。對於key,其梯度的計算方式為: \[ \begin{aligned} \frac{\partial L}{\partial K_{lt}} &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial \overline{V}_{ie}}{\partial K_{lt}} \\ &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial(\sum_{d=1}^D\sum_{j=1}^i Q_{id}K_{jd}V_{je})}{\partial K_{lt}} \\ &= \sum_{e=1}^M\sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} Q_{it} V_{le} \end{aligned} \] 將其整理為矩陣形式,得到公式 9
類似的,對於value,其梯度的計算方式為: \[ \begin{aligned} \frac{\partial L}{\partial V_{lt}} &= \sum_{e=1}^M \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\overline{V}_{ie}}{\partial V_{lt}} \\ &= \sum_{e=1}^M \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{ie}} \frac{\partial (\sum_{d=1}^D \sum_{j=1}^i Q_{id} K_{jd} V_{je})}{\partial V_{lt}}\\ &= \sum_{i=l}^N \frac{\partial L}{\partial \overline{V}_{it}} \sum_{d=1}^D Q_{id}K_{ld} \end{aligned} \] 將其整理為矩陣形式,得到公式 10
6 訓練和推理
訓練時,完整的訓練序列是已知的,這允許Transformer模型實現並行的訓練;而受限於計算方式,傳統的RNN模型一般難並行訓練。Transformer模型的每一步推理的時間複雜度是不同的,隨著上下文長度的增加而增加;而RNN模型的時間複雜度是固定的。
文章提出的線性Transformer結合了兩者的優點。
7 Transformer模型是RNN模型的特例
從本文的討論,我們可以明顯的看出,帶因果掩膜的Transformer模型可以視作是RNN模型的特例,即Transformer模型可以看做是一個能維護一個內部狀態(\(\mathbf S_i\)和\(\mathbf Z_i\)),在每次獲得新輸入時更新內部狀態的模型。
8 推薦閲讀
- 《线性Attention的探索:Attention必须有个Softmax吗?》。蘇劍林的文章,非常好的總結分享了各種各樣的綫性Attention,甚至提出了自己的新設計。