15.6 KL 散度深度解析
在前几节中,我们反复看到一个关键约束项出现在 RLHF 的目标函数里:
这个 KL 散度(Kullback-Leibler Divergence) 惩罚项看似只是一个正则化技巧,实则蕴含着深刻的统计学原理。它的方向选择(谁在前、谁在后)直接决定了优化行为的根本性差异——这正是本节要彻底讲清楚的问题。
15.6.1 KL 散度回顾 [必读]
给定两个概率分布
两个关键性质需要牢记:
- 非负性:
,等号成立当且仅当 。 - 不对称性:
,因此 KL 散度不是距离度量。
不对称性的直觉可以这样理解:前面的分布负责采样(决定在哪些点上"看"),后面的分布负责被比较(决定在这些点上"差多远")。
15.6.2 Forward KL vs Reverse KL:不对称性的行为差异 [必读]
假设真实分布
Forward KL:最小化
采样来自
Reverse KL:最小化
采样来自
下表总结了这一核心对比:
| 特征 | Forward KL: | Reverse KL: |
|---|---|---|
| 采样来源 | ||
| 严厉惩罚 | ||
| 行为倾向 | 覆盖 | 锁定 |
| 别名 | Mode-covering(模式覆盖) | Mode-seeking(模式追寻) |
| 拟合结果 | 宽而平(可能过于分散) | 窄而尖(可能遗漏模式) |
表 15-6:Forward KL 与 Reverse KL 的行为对比。
以下代码可视化了这一关键差异,直观展示单峰高斯去拟合双峰混合分布时的不同行为:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from scipy.optimize import minimize
# 真实分布 P:双峰混合高斯
def p_pdf(x):
return 0.5 * norm.pdf(x, loc=-2, scale=0.8) + 0.5 * norm.pdf(x, loc=2, scale=0.8)
# 近似分布 Q:单峰高斯,用 (mu, sigma) 参数化
def q_pdf(x, mu, sigma):
return norm.pdf(x, loc=mu, scale=sigma)
# Forward KL: D_KL(P || Q) — 用数值积分近似
def forward_kl(params):
mu, log_sigma = params
sigma = np.exp(log_sigma)
x = np.linspace(-6, 6, 2000)
p = p_pdf(x)
q = q_pdf(x, mu, sigma) + 1e-10
# 只在 p > 0 的地方计算
mask = p > 1e-10
return np.trapz(p[mask] * np.log(p[mask] / q[mask]), x[mask])
# Reverse KL: D_KL(Q || P) — 用数值积分近似
def reverse_kl(params):
mu, log_sigma = params
sigma = np.exp(log_sigma)
x = np.linspace(-6, 6, 2000)
p = p_pdf(x) + 1e-10
q = q_pdf(x, mu, sigma)
mask = q > 1e-10
return np.trapz(q[mask] * np.log(q[mask] / p[mask]), x[mask])
# 优化
res_fwd = minimize(forward_kl, x0=[0.0, np.log(1.0)], method='Nelder-Mead')
res_rev = minimize(reverse_kl, x0=[1.5, np.log(0.5)], method='Nelder-Mead')
mu_fwd, sigma_fwd = res_fwd.x[0], np.exp(res_fwd.x[1])
mu_rev, sigma_rev = res_rev.x[0], np.exp(res_rev.x[1])
# 绘图
x = np.linspace(-6, 6, 500)
fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))
for ax, mu, sigma, title, color in [
(axes[0], mu_fwd, sigma_fwd, f'Forward KL (mode-covering)\n$\\mu$={mu_fwd:.2f}, $\\sigma$={sigma_fwd:.2f}', 'royalblue'),
(axes[1], mu_rev, sigma_rev, f'Reverse KL (mode-seeking)\n$\\mu$={mu_rev:.2f}, $\\sigma$={sigma_rev:.2f}', 'crimson'),
]:
ax.fill_between(x, p_pdf(x), alpha=0.3, color='gray', label='P (true)')
ax.plot(x, q_pdf(x, mu, sigma), linewidth=2, color=color, label='Q (approx)')
ax.set_title(title, fontsize=13)
ax.legend(fontsize=11)
ax.set_xlabel('x')
ax.set_ylabel('density')
plt.tight_layout()
plt.savefig('forward_vs_reverse_kl.png', dpi=150)
plt.show()运行上述代码可以清楚看到:Forward KL 的结果是一个均值接近 0、方差很大的宽高斯,"铺开"覆盖两个峰;而 Reverse KL 的结果是一个紧紧贴住右峰(或左峰,取决于初始值)的窄高斯,完全忽略了另一个峰。
15.6.3 Forward KL 在经典强化学习中的应用 [选读]
经典策略优化算法 TRPO(Trust Region Policy Optimization) 使用的 KL 约束是:
注意这里是
为什么 TRPO 选择 Forward KL? 原因在于 TRPO 的更新建立在旧策略的样本之上。训练数据来自
15.6.4 Reverse KL 在 RLHF 中的三重理由 [必读]
在 RLHF 中,KL 惩罚项的形式为:
这里是
理由一:采样便利性。 训练时的序列是从当前策略
这个期望可以直接用当前策略的采样来做蒙特卡洛估计,不需要额外的重要性采样修正。在 token 级别,每个 token 的 KL 贡献就是:
这与 PPO 的 per-token 奖励计算完美衔接——实现代码中直接用 log_probs - ref_log_probs 即可。
理由二:惩罚策略逃逸。 Reverse KL 会强烈惩罚当前策略把概率放到参考策略几乎不给概率的区域。如果
理由三:Mode-seeking 的实用价值。 RLHF 的目标是在参考模型附近找到高奖励的回答模式,而不是完整复制参考模型的全部行为。Reverse KL 的 mode-seeking 性质恰好匹配这一需求:它允许策略集中到少数高奖励的回答模式上,放弃参考模型中那些低奖励的回答模式,同时又被 KL 约束限制在合理范围内。打个比方,如果参考模型对某个问题有 10 种可能的回答风格,RLHF 不需要全部保留——只要找到其中最好的 2-3 种并做到极致就够了。
下表将三重理由与 Forward KL 的行为对比:
| 维度 | Reverse KL(RLHF 选择) | Forward KL(假如采用) |
|---|---|---|
| 采样 | 从 | 需要从 |
| 逃逸惩罚 | 强烈惩罚离开 | 惩罚 |
| 模式行为 | 集中到高奖励模式(mode-seeking) | 保留所有模式(mode-covering) |
| 实际效果 | 精准提升 + 不跑飞 | 可能过度分散,无法聚焦高奖励模式 |
表 15-7:RLHF 中 Reverse KL 与假设的 Forward KL 对比。
15.6.5 Token 级 KL 计算的工程实现 [必读]
在 RLHF 的 PPO 训练中,KL 惩罚不是在整个序列级别一次性计算的,而是被分配到每个 token,形成密集的逐步奖励信号。这是理解 RLHF 工程实现的关键细节。
Per-token 奖励的构成。 对于序列中的第
其中
以下代码展示了这一计算逻辑的核心实现:
import torch
def compute_rewards(log_probs, ref_log_probs, reward_score,
prompt_length, kl_coef=0.1, clip_value=5.0):
"""
计算 RLHF 中每个 token 的总奖励。
Args:
log_probs: Actor 模型对 response 每个 token 的 log 概率, [B, seq_len]
ref_log_probs: Reference 模型对 response 每个 token 的 log 概率, [B, seq_len]
reward_score: 奖励模型对整句的打分, [B]
prompt_length: prompt 的长度(response 从此处开始)
kl_coef: KL 惩罚系数 beta
clip_value: 奖励裁剪范围
Returns:
rewards: 每个 token 的总奖励, [B, seq_len]
"""
# ---- Step 1: 计算每个 token 的 KL 惩罚 ----
# kl_t = log pi_theta(a_t|s_t) - log pi_ref(a_t|s_t)
# reward_t = -beta * kl_t = beta * (ref_log_probs - log_probs)
kl_reward = -kl_coef * (log_probs - ref_log_probs)
# ---- Step 2: 裁剪 RM 奖励,防止极端值 ----
reward_clip = torch.clamp(reward_score, -clip_value, clip_value)
# ---- Step 3: 在最后一个 token 处叠加 RM 奖励 ----
rewards = kl_reward.clone()
batch_size = log_probs.shape[0]
start = prompt_length - 1 # response 开始位置
# 找到每个序列的最后一个有效 token(EOS 位置)
# 简化实现:假设 response 占满剩余序列
last_pos = log_probs.shape[1] - 1
for j in range(batch_size):
rewards[j, last_pos] += reward_clip[j]
return rewards
# ---- 使用示例 ----
B, seq_len, prompt_len = 2, 20, 5
log_probs = torch.randn(B, seq_len) * 0.5 - 3.0
ref_log_probs = torch.randn(B, seq_len) * 0.5 - 3.0
reward_score = torch.tensor([0.8, -0.3])
rewards = compute_rewards(log_probs, ref_log_probs, reward_score, prompt_len)
print(f"rewards shape: {rewards.shape}") # [2, 20]
print(f"中间 token 奖励示例: {rewards[0, 10]:.4f}") # 纯 KL 惩罚
print(f"最后 token 奖励示例: {rewards[0, -1]:.4f}") # KL 惩罚 + RM 分数KL 散度的估计方式。 上面的 log_probs - ref_log_probs 实际上是 KL 散度的单样本蒙特卡洛估计。更精确的估计器(Schulman 等人提出)使用:
这个估计器具有更低的方差,在 GRPO 等算法的实现中被广泛采用。
15.6.6 与变分推断的深层联系 [挑战]
Reverse KL 在 RLHF 中的使用并非偶然——它与统计学习中最重要的框架之一变分推断(Variational Inference, VI) 有着深刻的数学联系。理解这一联系,能帮助我们从更高的视角审视 RLHF 的优化目标。
变分推断的核心问题。 在贝叶斯推断中,我们希望计算后验分布
为什么选择 Reverse KL? 变分推断最小化的是:
这是
注意
取负号并翻转优化方向,就得到了著名的 ELBO(Evidence Lower Bound,证据下界):
核心等式为:
由于
与 RLHF 的类比。 这个联系不仅仅是形式上的类似。回顾 RLHF 的带 KL 约束的目标函数,其最优策略的解析解为:
其中
| 变分推断 | RLHF |
|---|---|
| 最大化 ELBO | 最大化 |
表 15-8:变分推断与 RLHF 的结构对应。
就可以发现:RLHF 的优化目标在数学结构上等价于用当前策略
这一联系也解释了 DPO 的理论根基:既然最优策略有解析形式,就可以直接从策略反推奖励、代入 Bradley-Terry 模型,绕开配分函数
15.6.7 KL 系数 的实践考量 [选读]
过小:KL 约束松弛,策略可以大幅偏离参考模型。代理奖励(RM 分数)可能持续上升,但真实质量先升后降——这就是reward hacking(奖励欺骗)。模型可能生成看起来高分但实际无意义的输出。 过大:KL 约束过强,策略几乎等于参考模型,训练形同虚设。模型无法从奖励信号中学到有意义的改进。 - 最优
:需要通过实验调参。InstructGPT 中使用 左右,不同任务和模型规模需要不同的值。
在工程实现中,TRL 库的 PPOTrainer 会记录 objective/kl 指标(当前策略与参考策略的平均 KL 散度),这是监控训练健康度的关键信号。如果 KL 持续增大,说明策略正在"跑飞";如果 KL 几乎不动,说明策略没有有效学习。
15.6.8 小结
本节从 KL 散度的基本定义出发,深入剖析了 Forward KL 与 Reverse KL 的行为差异:
- Forward KL 是 mode-covering 的——它迫使近似分布覆盖真实分布的所有模式,适合经典策略优化(如 TRPO)中"不丢弃已有行为"的需求。
- Reverse KL 是 mode-seeking 的——它允许近似分布只锁定真实分布的某个模式,这恰好匹配 RLHF 中"聚焦高奖励回答、不跑出参考模型范围"的需求。
- RLHF 选择 Reverse KL 有三重理由:采样便利性(从当前策略采样直接可估计)、防止策略逃逸(强惩罚离开参考模型支撑集的行为)、mode-seeking 的实用价值(集中到高奖励模式)。
- 从变分推断的视角看,RLHF 的优化等价于用当前策略去变分逼近最优策略,最大化 RLHF 目标函数等价于最小化
。 - 在工程实现中,KL 惩罚被分解到每个 token 上形成密集奖励信号,与奖励模型的句子级打分互补配合。