DPO(Direct Preference Optimization)是一种不需要显式训练奖励模型、也不需要强化学习(RL)的偏好对齐方法。它直接将人类偏好信号融入策略模型的损失函数中,通过一个闭式推导绕过了传统 RLHF 中的 PPO 步骤。
本文从 Bradley-Terry 模型出发,逐步推导 DPO 的损失函数,并讨论其局限性与改进方向。
BT模型
Bradley-Terry模型约定人类偏好分布可以表示为如下形式: \[ p^*(y_1 \succ y_2|x)=\frac{ \exp(r^*(x, y_1)) }{ \exp(r^*(x, y_1)) + \exp(r^*(x, y_2 )) } \] 其中\(r^*(x,y)\)是一个隐式的奖励函数。
设我们训练一个模型\(r_\phi\)用于近似\(r^*\)。将上面的问题视为二分类问题,可以得到如下对数似然损失: \[ \mathcal L_R(r_\phi, \mathcal D)= - \mathbb E_{(x, y_w, y_l)\sim \mathcal D}\left[ \log \sigma(r_\phi(x, y_w) - r_\phi(x, y_l)) \right] \] 其中\(\sigma(z)=\frac{1}{1 + e^{-z}}\).
注意到 \[ \begin{aligned} p^*(y_1 \succ y_2|x) &=\frac{ 1 } { 1 + \exp(r^*(x, y_2) - r^*(x, y_1)) } \\ &= \sigma(r^*(x, y_1) - r^*(x, y_2)) \end{aligned} \] 因此\(\mathcal L_R(r_\phi, \mathcal D)\)是显然的。
RL的目标函数
通常,我们考虑以下RLHF的优化目标 \[ \max_{\pi_\theta} \mathbb E_{x\sim D, y\sim \pi_\theta(y|x)}[r_\phi(x, y)] -\beta \mathbb D_\text{KL}\left[ \pi_\theta(y|x) \Vert \pi_\text{ref}(y|x) \right] \] 对于离散的token sampling,这个expectation无法直接用普通反向传播对样本做微分,故通常用policy gradient类的强化学习方法来优化它,并将奖励函数设为 \[ r(x, y) = r_\phi(x, y) - \beta (\log \pi_\theta(y|x) - \log \pi_\text{ref}(y | x)). \]
小技巧:将KL项吸收到reward里。 \[ \mathbb D_\text{KL} (\pi_\theta \Vert \pi_\text{ref}) = \mathbb E_{y\sim\pi_\theta} \left[ \log \pi_\theta(y|x) - \log \pi_\text{ref}(y|x) \right] \] 所以原本的目标可以写成: \[ \max_{\pi_\theta} \mathbb E_{y\sim\pi_\theta}\left[ r_\phi(x, y) - \beta(\log \pi_\theta(y|x) - \log \pi_\text{ref}(y|x)) \right] \] 这就实现了将KL项吸收到了reward里。
奖励函数和最优策略之间的关系
在这个函数下,我们可以推导出策略模型的最优解为: \[ \pi^*(y|x) = \frac{1}{Z(x)}\pi_\text{ref}(y|x)\exp(\frac{1}{\beta} r(x, y)) \] 其中 \[ Z(x) = \sum_y \pi_\text{ref}(y|x) \exp(\frac{1}{\beta} r(x, y)) \]
推导过程如下: \[ \begin{aligned} &\max_\pi \mathbb E_{x\sim\mathcal D, y\sim \pi} \left[r(x, y)\right] - \beta \mathbb D_\text{KL}\left[ \pi(y|x)\Vert \pi_\text{ref}(y|x) \right]\\ &=\max_\pi \mathbb E_{x\sim \mathcal D} \mathbb E_{y\sim \pi(y|x)} \left[r(x, y) - \beta \log\frac{\pi(y|x)}{\pi_\text{ref}(y|x)}\right]\\ &= \min_\pi \mathbb E_{x\sim \mathcal D} \mathbb E_{y\sim \pi(y|x)} \left[ \log \frac{\pi(y|x)}{\pi_\text{ref}(y|x)} - \frac{1}{\beta}r(x, y)\right]\\ &= \min_\pi \mathbb E_{x\sim \mathcal D} \mathbb E_{y\sim \pi(y|x)} \left[\log\frac{\pi(y|x)}{\frac{1}{Z(x)} \pi_\text{ref}(y|x)\exp(\frac{1}{\beta} r(x, y))} - \log Z(x)\right]\\ &= \min_\pi \mathbb E_{x\sim \mathcal D} \left( \mathbb D_\text{KL} (\pi(y|x)\Vert\frac{1}{Z(x)} \pi_\text{ref}(y|x)\exp(\frac{1}{\beta} r(x, y))) - \log Z(x) \right)\\ \end{aligned} \] 这里我们需要\(Z(x)\)的取值能够使得\(\frac{1}{Z(x)} \pi_\text{ref}(y|x)\exp(\frac{1}{\beta} r(x, y))\)能够成为一个合理的分布,因此显然 \[ Z(x) = \sum_y \pi_\text{ref}(y|x)\exp(\frac{1}{\beta}r(x, y)). \] 于是可以看出令目标最小化的最优策略为 \[ \pi^*(y|x) = \frac{1}{Z(x)}\pi_\text{ref} (y|x)\exp(\frac{1}{\beta} r(x, y)) \]
根据这个推导结果,如果已知奖励函数\(r^*(x,y)\)对应的最优策略\(\pi^*(y|x)\),那么就可以反过来,用\(\pi^*\)表示奖励函数: \[ r^*(x, y) =\beta\log \frac{\pi^*(y|x)}{\pi_\text{ref}(y|x)} +\beta \log Z(x) \]
DPO的损失函数
这样一来,就可以直接得到DPO的损失函数:
\[ \begin{aligned} p^*(y_1 \succ y_2|x) &= \frac{ \exp(r^*(x, y_1)) }{\exp(r^*(x, y_1)) + \exp(r^*(x, y_2))}\\ &= \frac{\exp\left(\beta\log \frac{\pi^*(y_1|x)}{\pi_\text{ref}(y_1|x)} +\beta \log Z(x)\right)}{\exp\left(\beta\log \frac{\pi^*(y_1|x)}{\pi_\text{ref}(y_1|x)} +\beta \log Z(x)\right)+\exp\left(\beta\log \frac{\pi^*(y_2|x)}{\pi_\text{ref}(y_2|x)} +\beta \log Z(x)\right)} \\ &= \frac{1}{1 + \exp\left(\beta\log \frac{\pi^*(y_2|x)}{\pi_\text{ref}(y_2|x)} - \beta\log \frac{\pi^*(y_1|x)}{\pi_\text{ref}(y_1|x)}\right)} \\ &= \sigma(\beta\log \frac{\pi^*(y_1|x)}{\pi_\text{ref}(y_1|x)} - \beta\log \frac{\pi^*(y_2|x)}{\pi_\text{ref}(y_2|x)}) \end{aligned} \] 因此DPO的损失函数设计为 \[ \mathcal L_\text{DPO}(\pi_\theta; \pi_\text{ref}) = - \mathbb E_{(x, y_w, y_l)\sim D} \left[ \log \sigma(\beta\log \frac{\pi_\theta(y_w|x)}{\pi_\text{ref}(y_w|x)} - \beta\log \frac{\pi_\theta(y_l|x)}{\pi_\text{ref}(y_l|x)}) \right] \]
DPO的局限性和优化点
本节记录一下我在使用DPO的过程中遇到的一些优化问题和采取的改进策略,仅供参考:
- DPO的损失函数旨在拉大正负例样本之间的距离(和对比学习还蛮像的),但是有可能造成正负例的概率同时增大或减小,从而导致最终的训练结果不稳定。工程上一个简单的修复策略,是在DPO损失的基础上简单增加一个SFT监督损失,如\(\mathcal{L}_{\text{final}} = \mathcal{L}_{\text{DPO}}(x, y_w, y_l) + \lambda \cdot \mathcal{L}_{\text{SFT}}(x, y_w)\);
- 在一个实际项目的数据集中,正例的文本长度常常大于负例的文本长度,训练出来的模型也偏好长文本,而且比训练集文本的长度还长。针对这个问题,我们提出对样本进行采样和丢弃,使得正负例的文本长度均值接近,最终训练出来的模型的回复长度得以恢复正常。