本人求职中!大模型 Post-Training 方向 · 长三角、福建地区优先 · 如果您有兴趣,欢迎联系 zhimi64@foxmail.com 👀 了解更多

KTO算法的學習和推導

本文瀏覽次數

之前我在一個實際項目裡面負責一個LLM的偏好的對齊。神奇的是這個項目只有單個樣本的拒絕/採樣信號,沒有pair-wise的數據。我還是第一次遇到這種情況。

很自然的,我想到要試試KTO的效果。KTO是一種只需要單個樣本偏好數據的偏好對齊算法。

這篇文章記錄了我阅读KTO這篇論文的筆記,以及對KTO背後思想的一些理解。

HALO

KTO提出了human-aware losses(HALO)的分類學概念。損失函數可以分爲符合HALO和不符合HALO規範的。雖然沒有辦法證明HALO類的損失函數一定優於非HALO的,但是論文表示,從經驗上看,HALO損失函數的效果一般更好。

HALO的定義: 假設\(\theta\)表示模型\(\pi_\theta:\mathcal X\rightarrow\mathcal P(\mathcal Y)\)的可訓練參數。\(\pi_\text{ref}\)是基準模型,\(l:\mathcal Y\rightarrow \mathbb R^+\)是一個歸一化係數,\(r_\theta(x, y)=l(y) \log [\pi_\theta(y | x)/\pi_\text{ref}(y|x)]\)是隱含的獎勵函數。數據\((x, y)\)的“人類價值”應該可以形式化為: \[ v(r_\theta(x, y) - \mathbb E_Q[r_\theta(x, y')]), \] 其中\(Q(Y'|x)\)\(\mathcal Y\)上的一個基準點數據的分佈,\(v:\mathbb R\rightarrow \mathbb R\)是非降的,在\((0, \infty)\)內為凹的函數。

函數\(f\)如果滿足以下性質,就可以視為基於v的“human-aware的損失函數”: \[ f(\pi_\theta, \pi_\text{ref})=\mathbb E_{x, y\sim\mathcal D}[a_{x, y} v(r_\theta(x, y) - \mathbb E_Q[r_\theta(x, y')])] + C_{\mathbfcal D} \] 其中 \(a_{x, y}\in \{-1, +1\}\). \(\mathcal D\)是數據集,\(C_{\mathcal D}\in\mathbb R\)是由數據集決定的常數。

如何判斷一個對齊方法是否屬於HALO

CSFT不是HALO

CSFT會在prompt中注入control token,例如<good><bad>,與相應的回復數據拼接在一起,構造得到訓練數據。推理的時候,一律採用<good>這個control token來獲得回復。

所以CSFT非常類似普通的SFT。

\[ L_\text{CSFT} = -\log \pi_\theta(y|x, c), \] 其中\(c=<\text{good}>\)或者\(<\text{bad}>\)

假設\(L_\text{CSFT}\)符合HALO的定義,那麼存在 \(r_\theta(x, y)=l(y)\log \frac{\pi_\theta(y | x)}{\pi_\text{ref}(y|x)}\)\(R(x) = E_Q[r_\theta(x, y')]\),使得 \[ L_\text{CSFT} = a v(r_\theta(x, y) - R(x)) + C \] 代入HALO對\(r_\theta\)的定義,上式展開得到 \[ \begin{aligned} L_\text{CSFT} = a v \left(l(y)\log \pi_\theta(y|x) - l(y)\log \pi_\text{ref}(y|x)- R(x) \right) + C\\ \end{aligned} \] 要看CSFT能否構成HALO,就是看能否找到合適的\(a, v, l(y), R(x), C\)使得下面的等式成立: \[ -\log \pi_\theta(y|x, c) = a v \left(l(y)\log \pi_\theta(y|x) - l(y)\log \pi_\text{ref}(y|x)- R(x) \right) + C \]

我們比較等式兩邊的依賴:

  • 左邊:只依賴\(\pi_\theta\)
  • 右邊:
    • \(\log \pi_\theta(y|x)\)
    • \(\log \pi_\text{ref}(y|x)\)
    • \(R(x)\)

為了讓等式對所有的y成立,\(R(x)\)就必須包含\(\pi_\text{ref}\)。除非\(\pi_\text{ref}\)是不依賴\(y\)的均勻分佈,否則是無法成功構造的。

因此論文得出結論:CSFT不是一種HALO。

SLiC不是HALO

SLiC(Sequence Likelihood Calibration)的損失函數如下: \[ \begin{aligned} &L_\text{cal} (\pi_\theta)= \mathbb E_{x, y_w, y_l\sim D}\left[\max\left(0, \delta - \log\frac{\pi_\theta(y_w|x)}{\pi_\theta(y_l|x)}\right)\right]\\ &L_\text{reg} (\pi_\theta, \pi_\text{ref}) = \mathbb E_{x\sim D, y\sim\pi_\text{ref}(x)} [-\log \pi_\theta(y|x)]\\ &L_\text{SLiC}(\pi_\theta, \pi_\text{ref}) = L_\text{cal} (\pi_\theta) + \lambda_\text{reg}L_\text{reg}(\pi_\theta, \pi_\text{ref}) \end{aligned} \] 分析SLiC的損失函數,可以看到 \[ \begin{aligned} &max(0, \delta - \log \frac{\pi_\theta(y_w|x)}{\pi_\theta(y_l|x)})\\ =&max(0, \delta - (\log \pi_\theta(y_w|x) - \log\pi_\theta(y_l|x))) \end{aligned} \] 這部分實際上希望模型輸出\(y_w\)的對數似然大於輸出\(y_l\)的對數似然,但差距不要差太大。

\(L_\text{reg}\)做的事情就是讓\(\pi_\text{ref}\)負責\(y\)的採樣,用\(\text{ref}\)模型產生的數據約束當前訓練中的模型,約束模型不要偏離基礎模型太遠。

和CSFT的推導過程類似,由於\(\pi_\text{ref}\)只用來採樣數據,但不直接構成損失函數的一項,SLiC也不是HALO。

DPO是HALO

DPO的损失函数是 \[ L_\text{DPO}(\pi_\theta, \pi_\text{ref}) = \mathbb E_{x, y_w, y_l}\left[ -\log \sigma \left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_\text{ref}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_\text{ref}(y_l|x)} \right) \right], \]\(l(y)=\beta, r_\theta = \beta \log(\frac{\pi_\theta(y|x)}{\pi_\text{ref}(y|x)}), v(\cdot) = \log \sigma(\cdot)\)\(Q\)是一个将全部概率质量集中在\((x, y_l)\)上的分布。\(a_{x, y}=-1\). 通過這樣的構造,我們可以驗證DPO符合HALO的定義。

PPO-Clip是HALO

\[ L_\text{PPO~(offline)} = - \mathbb E _{x, y, t\sim D}\left[ \min(q_\theta A(x:y_{<t}, y_t), \text{clip}(q_\theta, 1 - \epsilon, 1 + \epsilon)A(x:y_{<t}, y_t)) \right] \] 其中\(q_\theta = \frac{\pi_\theta(y_t|x:y_{<t})}{\pi_\text{ref}(y_t|x:y_{<t})}\)是token级别的概率比值。

\(A(x:y_{<t}, y_t)\)是token级别的优势函数,可以表示为\(Q^\pi(x:y_{<t}, y_t) - V^\pi(x:y_{<t})\),即动作-价值函数和价值函数的差值。因为\(V^\pi(x:y_{<t}) = \mathbb E_{y\sim \pi}Q^\pi(x:y_{<t}, y)\),因此基准点分布(reference distribution)就是policy本身。

那么根据HALO的定义,\(r_\theta\)可以构造为\(q_\theta Q^\pi(x:y_{<t}, t)\)。这里需要不失一般性地,假设\(Q^\pi\)是非负的,因为\(Q^\pi\)总是可以加上一个正数而不改变优势函数。这意味着\(\exists u \geq 1, q_\theta Q^\pi(x:y_{<t}, y)=\log u = \log \hat \pi_\theta(x:y_{<t}, y) / \hat \pi_\text{ref}(x:y_{<t}, y)\),其中\(\hat \pi_\theta, \hat \pi_\text{ref}\)是隐含的策略分布和参考分布。 \(q_\theta A = r_\theta - z_0, v(q_\theta A) = min(q_\theta A, A(1 + \text{sign}(q_\theta A)\epsilon)), a_{x, y}=-1\)就能完成构造。

重新梳理一遍:

\[ \begin{aligned} & v(r_\theta(x, y)) - \mathbb E_Q[r_\theta(x, y')] & \\ = & v(q_\theta Q^\pi - \mathbb E_Q[q_\theta Q^\pi]) & (\text{令} r_\theta = q_\theta Q^\pi)\\ = & v(q_\theta Q^\pi - \mathbb E_{y\sim\pi_\text{ref}}[q_\theta Q^\pi]) & (\text{基准分布就是策略本身})\\ = & v(q_\theta Q^\pi - q_\theta V) & (\text{因为}\mathbb E_{\pi_\text{ref}}[q_\theta Q^\pi] = \mathbb E_{\pi_\theta} [Q^{\pi_\theta}] = V^{\pi_\theta})\\ = & v(q_\theta A) & \text{因为}(A = Q^{\pi_\theta} - V^{\pi_\theta}) \\ = & min(q_\theta A, \text{clip}(q_\theta, 1-\epsilon, 1+\epsilon) A) & \text{代入}v\text{的构造} \end{aligned}\\ \]

Note

根据论文的证明,\(r_\theta\)的构造实际对应着隐含的策略\(\hat \pi_\theta\)\(\hat \pi_\text{ref}\)。实际上这两个策略并不是真实存在的,或者可以直接由真实的策略构造出来的。因此这里的证明与其说是说明“对于PPO所优化的策略”而言,PPO是一种HALO,倒不如说“存在一种隐含的策略”,对于它来说PPO可以解释为HALO。

KTO

前文中已經羅列了一些偏好對齊方法,并且判斷他們是否屬於HALO。接著論文作者提出了自己的設計:Kahneman-Tversky Optimization(KTO)。Kahneman和Tversky認爲人類價值函數可以用如下的算式建模: \[ v(z; \lambda ,\alpha, z_0) = \left\{ \begin{aligned} & (z - z_0)^\alpha & \text{if} z \geq z_0 \\ & -\lambda (z_0 - z) ^\alpha & \text{if} z < z_0 \end{aligned} \right. \] 但是這樣的設計在訓練模型的時候會遇到數值穩定性問題,因此作者將指數函數替換爲logistic函數。

爲了模擬人類的“損失厭惡”傾向,KTO引入了\(\beta \in \mathbb R^+\)參數。\(\beta\)越大,價值函數越容易飽和,對應著人類在正收益時的損失厭惡,在面臨損失時又傾向於接受風險。

原模型中的\(\lambda\)參數被分化爲\({\lambda_D, \lambda_U}\)兩個參數,用於分別控制desirale和undesirable兩種數據的權重。

KTO函數設計爲: \[ L_\text{KTO} (\pi_\theta, \pi_\text{ref}) = \mathbb E_{x, y\sim D}[\lambda_y - v(x, y)], \] 其中 \[ \begin{aligned} r_\theta(x, y) & = \log \frac{\pi_\theta(y|x)}{\pi_\text{ref}(y|x)} \\ z_0 &= \text{KL}(\pi_\theta(y'|x)\Vert \pi_\text{ref}(y'|x))\\ v(x, y) &= \left \{ \begin{aligned} \lambda_D \sigma(\beta(r_\theta(x, y) - z_0))~~~~& \text{if} ~y\sim y_\text{desirable} | x \\ \lambda_U \sigma(\beta(z_0 - r_\theta(x, y)))~~~~& \text{if} ~y\sim y_\text{undesirable} | x \end{aligned} \right. \end{aligned} \]

其中\(z_0\)不參與梯度反向傳播,以保持訓練過程穩定。

\(z_0\)的計算本質是估計一個KL散度。理論上需要用蒙特卡洛方法從\(\pi_\theta\)中采樣\(y\),再平均。采樣過程是很慢的,代價很大。所以作者這裏用了一個取巧的方法。對一個同一個batch内的一組數據\(\{(x_1, y_1), (x_2, y_2), \dots, (x_m, y_m)\}\),KTO首先讓\(x\)\(y\)錯位組合變成\(\{(x_1, y_2), (x_2, y_3), \dots, (x_m, y_1)\}\). KTO使用以下的公式估計\(z_0\): \[ \begin{aligned} &\hat z_0 = \max\left( 0, \frac{1}{m}\sum_{i\leq i\lt m} \log \frac{\pi_\theta(y_i|x_i)}{\pi_\text{ref}(y_j|x_i)} \right),\\ &\text{where} ~j = (i + 1) \mod m \end{aligned} \]

也就是用訓練集中固有的\(y\)來避免重新采樣。但是\(y_j\)\(x_i\)完全沒有關係怎麽辦?作者認爲這是biased,但是方便啊。估計需要使用\(\max(0, \cdot)\)保證估計得到的KL散度是非負的。

\(\hat z_0\)對於每個batch計算一次,在batch内是共享的。

一些思考

KTO其實依賴一個比較強的假設——損失厭惡。但是這在我的實際項目裡是不成立的。在我的項目中,被拒絕的數據不一定是壞數據,被採納的數據不一定是好數據。噪聲太大了。我不得不將undesired_weight調整到一個比較低的水平,才能跑起來,否則訓練就不穩定。

在“偏好”這個信號有很強的噪聲時,該如何利用偏好數據蘊含的信息進行訓練呢?這個問題沒法很好地用KTO解決,但很有實際研究價值。

By @執迷 in
Tags :