本人求职中!大模型 Post-Training 方向 · 长三角、南方地区优先 · 如果您有兴趣,欢迎联系 zhimi64@foxmail.com 👀 了解更多

DPO算法的學習和推導

DPO(Direct Preference Optimization)是一種不需要顯式訓練獎勵模型、也不需要強化學習(RL)的偏好對齊方法。它直接將人類偏好信號融入策略模型的損失函數中,通過一個閉式推導繞過了傳統 RLHF 中的 PPO 步驟。

本文從 Bradley-Terry 模型出發,逐步推導 DPO 的損失函數,並討論其局限性與改進方向。

BT模型

Bradley-Terry模型約定人類偏好分佈可以表示為如下形式: \[ p^*(y_1 \succ y_2|x)=\frac{ \exp(r^*(x, y_1)) }{ \exp(r^*(x, y_1)) + \exp(r^*(x, y_2 )) } \] 其中\(r^*(x,y)\)是一個隱式的獎勵函數。

設我們訓練一個模型\(r_\phi\)用於近似\(r^*\)。將上面的問題視為二分類問題,可以得到如下對數似然損失: \[ \mathcal L_R(r_\phi, \mathcal D)= - \mathbb E_{(x, y_w, y_l)\sim \mathcal D}\left[ \log \sigma(r_\phi(x, y_w) - r_\phi(x, y_l)) \right] \] 其中\(\sigma(z)=\frac{1}{1 + e^{-z}}\).

注意到 \[ \begin{aligned} p^*(y_1 \succ y_2|x) &=\frac{ 1 } { 1 + \exp(r^*(x, y_2) - r^*(x, y_1)) } \\ &= \sigma(r^*(x, y_1) - r^*(x, y_2)) \end{aligned} \] 因此\(\mathcal L_R(r_\phi, \mathcal D)\)是顯然的。

RL的目標函數

通常,我們考慮以下RLHF的優化目標 \[ \max_{\pi_\theta} \mathbb E_{x\sim D, y\sim \pi_\theta(y|x)}[r_\phi(x, y)] -\beta \mathbb D_\text{KL}\left[ \pi_\theta(y|x) \Vert \pi_\text{ref}(y|x) \right] \] 對於離散的token sampling,這個expectation無法直接用普通反向傳播對樣本做微分,故通常用policy gradient類的強化學習方法來優化它,並將獎勵函數設為 \[ r(x, y) = r_\phi(x, y) - \beta (\log \pi_\theta(y|x) - \log \pi_\text{ref}(y | x)). \]

小技巧:將KL項吸收到reward裡。 \[ \mathbb D_\text{KL} (\pi_\theta \Vert \pi_\text{ref}) = \mathbb E_{y\sim\pi_\theta} \left[ \log \pi_\theta(y|x) - \log \pi_\text{ref}(y|x) \right] \] 所以原本的目標可以寫成: \[ \max_{\pi_\theta} \mathbb E_{y\sim\pi_\theta}\left[ r_\phi(x, y) - \beta(\log \pi_\theta(y|x) - \log \pi_\text{ref}(y|x)) \right] \] 這就實現了將KL項吸收到了reward裡。

獎勵函數和最優策略之間的關係

在這個函數下,我們可以推導出策略模型的最優解為: \[ \pi^*(y|x) = \frac{1}{Z(x)}\pi_\text{ref}(y|x)\exp(\frac{1}{\beta} r(x, y)) \] 其中 \[ Z(x) = \sum_y \pi_\text{ref}(y|x) \exp(\frac{1}{\beta} r(x, y)) \]

推導過程如下: \[ \begin{aligned} &\max_\pi \mathbb E_{x\sim\mathcal D, y\sim \pi} \left[r(x, y)\right] - \beta \mathbb D_\text{KL}\left[ \pi(y|x)\Vert \pi_\text{ref}(y|x) \right]\\ &=\max_\pi \mathbb E_{x\sim \mathcal D} \mathbb E_{y\sim \pi(y|x)} \left[r(x, y) - \beta \log\frac{\pi(y|x)}{\pi_\text{ref}(y|x)}\right]\\ &= \min_\pi \mathbb E_{x\sim \mathcal D} \mathbb E_{y\sim \pi(y|x)} \left[ \log \frac{\pi(y|x)}{\pi_\text{ref}(y|x)} - \frac{1}{\beta}r(x, y)\right]\\ &= \min_\pi \mathbb E_{x\sim \mathcal D} \mathbb E_{y\sim \pi(y|x)} \left[\log\frac{\pi(y|x)}{\frac{1}{Z(x)} \pi_\text{ref}(y|x)\exp(\frac{1}{\beta} r(x, y))} - \log Z(x)\right]\\ &= \min_\pi \mathbb E_{x\sim \mathcal D} \left( \mathbb D_\text{KL} (\pi(y|x)\Vert\frac{1}{Z(x)} \pi_\text{ref}(y|x)\exp(\frac{1}{\beta} r(x, y))) - \log Z(x) \right)\\ \end{aligned} \] 這裡我們需要\(Z(x)\)的取值能夠使得\(\frac{1}{Z(x)} \pi_\text{ref}(y|x)\exp(\frac{1}{\beta} r(x, y))\)能夠成為一個合理的分佈,因此顯然 \[ Z(x) = \sum_y \pi_\text{ref}(y|x)\exp(\frac{1}{\beta}r(x, y)). \] 於是可以看出令目標最小化的最優策略為 \[ \pi^*(y|x) = \frac{1}{Z(x)}\pi_\text{ref} (y|x)\exp(\frac{1}{\beta} r(x, y)) \]

根據這個推導結果,如果已知獎勵函數\(r^*(x,y)\)對應的最優策略\(\pi^*(y|x)\),那麼就可以反過來,用\(\pi^*\)表示獎勵函數\[ r^*(x, y) =\beta\log \frac{\pi^*(y|x)}{\pi_\text{ref}(y|x)} +\beta \log Z(x) \]

DPO的損失函數

這樣一來,就可以直接得到DPO的損失函數:

\[ \begin{aligned} p^*(y_1 \succ y_2|x) &= \frac{ \exp(r^*(x, y_1)) }{\exp(r^*(x, y_1)) + \exp(r^*(x, y_2))}\\ &= \frac{\exp\left(\beta\log \frac{\pi^*(y_1|x)}{\pi_\text{ref}(y_1|x)} +\beta \log Z(x)\right)}{\exp\left(\beta\log \frac{\pi^*(y_1|x)}{\pi_\text{ref}(y_1|x)} +\beta \log Z(x)\right)+\exp\left(\beta\log \frac{\pi^*(y_2|x)}{\pi_\text{ref}(y_2|x)} +\beta \log Z(x)\right)} \\ &= \frac{1}{1 + \exp\left(\beta\log \frac{\pi^*(y_2|x)}{\pi_\text{ref}(y_2|x)} - \beta\log \frac{\pi^*(y_1|x)}{\pi_\text{ref}(y_1|x)}\right)} \\ &= \sigma(\beta\log \frac{\pi^*(y_1|x)}{\pi_\text{ref}(y_1|x)} - \beta\log \frac{\pi^*(y_2|x)}{\pi_\text{ref}(y_2|x)}) \end{aligned} \] 因此DPO的損失函數設計為 \[ \mathcal L_\text{DPO}(\pi_\theta; \pi_\text{ref}) = - \mathbb E_{(x, y_w, y_l)\sim D} \left[ \log \sigma(\beta\log \frac{\pi_\theta(y_w|x)}{\pi_\text{ref}(y_w|x)} - \beta\log \frac{\pi_\theta(y_l|x)}{\pi_\text{ref}(y_l|x)}) \right] \]

DPO的局限性和優化點

本節記錄一下我在使用DPO的過程中遇到的一些優化問題和採取的改進策略,僅供參考:

  1. DPO的損失函數旨在拉大正負例樣本之間的距離(和對比學習還蠻像的),但是有可能造成正負例的概率同時增大或減小,從而導致最終的訓練結果不穩定。工程上一個簡單的修復策略,是在DPO損失的基礎上簡單增加一個SFT監督損失,如\(\mathcal{L}_{\text{final}} = \mathcal{L}_{\text{DPO}}(x, y_w, y_l) + \lambda \cdot \mathcal{L}_{\text{SFT}}(x, y_w)\);
  2. 在一個實際項目的數據集中,正例的文本長度常常大於負例的文本長度,訓練出來的模型也偏好長文本,而且比訓練集文本的長度還長。針對這個問題,我們提出對樣本進行採樣和丟棄,使得正負例的文本長度均值接近,最終訓練出來的模型的回覆長度得以恢復正常。
By @執迷 in
Tags :