Skip to content

15.6 KL 散度深度解析

在前几节中,我们反复看到一个关键约束项出现在 RLHF 的目标函数里:

maxθEx,yπθ[r(x,y)]βDKL(πθ(|x)πref(|x))

这个 KL 散度(Kullback-Leibler Divergence) 惩罚项看似只是一个正则化技巧,实则蕴含着深刻的统计学原理。它的方向选择(谁在前、谁在后)直接决定了优化行为的根本性差异——这正是本节要彻底讲清楚的问题。


15.6.1 KL 散度回顾 [必读]

给定两个概率分布 PQ,KL 散度的定义为:

DKL(PQ)=ExP[logP(x)Q(x)]=xP(x)logP(x)Q(x)

两个关键性质需要牢记:

  1. 非负性DKL(PQ)0,等号成立当且仅当 P=Q
  2. 不对称性DKL(PQ)DKL(QP),因此 KL 散度不是距离度量。

不对称性的直觉可以这样理解:前面的分布负责采样(决定在哪些点上"看")后面的分布负责被比较(决定在这些点上"差多远")DKL(PQ) 是"站在 P 的视角看 Q 离自己多远",而 DKL(QP) 则是"站在 Q 的视角看自己离 P 多远"。


15.6.2 Forward KL vs Reverse KL:不对称性的行为差异 [必读]

假设真实分布 P 是一个双峰分布(两个高斯的混合),我们用一个单峰分布 Q 去逼近它。根据优化方向的不同,Q 会表现出截然相反的行为。

Forward KL:最小化 DKL(PQ) — 模式覆盖(Mode-Covering)

DKL(PQ)=ExP[logP(x)Q(x)]

采样来自 P,这意味着 P 有概率质量的地方都会被考察。如果 Q 在某个 P 的高概率区域给出了极小的概率(即 Q(x)0),则 logP(x)Q(x)+,惩罚极其严厉。因此 Q 被迫覆盖 P 的所有模式——即使这意味着在两个峰之间的低概率区域也要分配不少概率质量。结果是一个宽而平的分布,宁可"过度覆盖",也不敢遗漏任何峰。

Reverse KL:最小化 DKL(QP) — 模式追寻(Mode-Seeking)

DKL(QP)=ExQ[logQ(x)P(x)]

采样来自 Q,这意味着 Q 没有覆盖到的区域根本不会被采样,因此不产生任何惩罚。但如果 Q 把大量概率放到了 P 几乎为零的区域(即 P(x)0),则 logQ(x)P(x)+,惩罚同样极其严厉。因此 Q 倾向于紧紧锁定 P 的某一个峰,完全忽略其他峰的存在。结果是一个窄而尖的分布,集中火力在一个模式上。

下表总结了这一核心对比:

特征Forward KL:DKL(P|Q)Reverse KL:DKL(Q|P)
采样来源P(真实分布)Q(近似分布)
严厉惩罚QP 高概率处给低概率QP 低概率处给高概率
行为倾向覆盖 P 的所有模式锁定 P 的某个模式
别名Mode-covering(模式覆盖)Mode-seeking(模式追寻)
拟合结果宽而平(可能过于分散)窄而尖(可能遗漏模式)

表 15-6:Forward KL 与 Reverse KL 的行为对比。

以下代码可视化了这一关键差异,直观展示单峰高斯去拟合双峰混合分布时的不同行为:

python
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from scipy.optimize import minimize

# 真实分布 P:双峰混合高斯
def p_pdf(x):
    return 0.5 * norm.pdf(x, loc=-2, scale=0.8) + 0.5 * norm.pdf(x, loc=2, scale=0.8)

# 近似分布 Q:单峰高斯,用 (mu, sigma) 参数化
def q_pdf(x, mu, sigma):
    return norm.pdf(x, loc=mu, scale=sigma)

# Forward KL: D_KL(P || Q) — 用数值积分近似
def forward_kl(params):
    mu, log_sigma = params
    sigma = np.exp(log_sigma)
    x = np.linspace(-6, 6, 2000)
    p = p_pdf(x)
    q = q_pdf(x, mu, sigma) + 1e-10
    # 只在 p > 0 的地方计算
    mask = p > 1e-10
    return np.trapz(p[mask] * np.log(p[mask] / q[mask]), x[mask])

# Reverse KL: D_KL(Q || P) — 用数值积分近似
def reverse_kl(params):
    mu, log_sigma = params
    sigma = np.exp(log_sigma)
    x = np.linspace(-6, 6, 2000)
    p = p_pdf(x) + 1e-10
    q = q_pdf(x, mu, sigma)
    mask = q > 1e-10
    return np.trapz(q[mask] * np.log(q[mask] / p[mask]), x[mask])

# 优化
res_fwd = minimize(forward_kl, x0=[0.0, np.log(1.0)], method='Nelder-Mead')
res_rev = minimize(reverse_kl, x0=[1.5, np.log(0.5)], method='Nelder-Mead')

mu_fwd, sigma_fwd = res_fwd.x[0], np.exp(res_fwd.x[1])
mu_rev, sigma_rev = res_rev.x[0], np.exp(res_rev.x[1])

# 绘图
x = np.linspace(-6, 6, 500)
fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))

for ax, mu, sigma, title, color in [
    (axes[0], mu_fwd, sigma_fwd, f'Forward KL (mode-covering)\n$\\mu$={mu_fwd:.2f}, $\\sigma$={sigma_fwd:.2f}', 'royalblue'),
    (axes[1], mu_rev, sigma_rev, f'Reverse KL (mode-seeking)\n$\\mu$={mu_rev:.2f}, $\\sigma$={sigma_rev:.2f}', 'crimson'),
]:
    ax.fill_between(x, p_pdf(x), alpha=0.3, color='gray', label='P (true)')
    ax.plot(x, q_pdf(x, mu, sigma), linewidth=2, color=color, label='Q (approx)')
    ax.set_title(title, fontsize=13)
    ax.legend(fontsize=11)
    ax.set_xlabel('x')
    ax.set_ylabel('density')

plt.tight_layout()
plt.savefig('forward_vs_reverse_kl.png', dpi=150)
plt.show()

运行上述代码可以清楚看到:Forward KL 的结果是一个均值接近 0、方差很大的宽高斯,"铺开"覆盖两个峰;而 Reverse KL 的结果是一个紧紧贴住右峰(或左峰,取决于初始值)的窄高斯,完全忽略了另一个峰。


15.6.3 Forward KL 在经典强化学习中的应用 [选读]

经典策略优化算法 TRPO(Trust Region Policy Optimization) 使用的 KL 约束是:

Esdπold[DKL(πold(|s)πθ(|s))]δ

注意这里是 πold 在前、πθ 在后,属于 Forward KL

为什么 TRPO 选择 Forward KL? 原因在于 TRPO 的更新建立在旧策略的样本之上。训练数据来自 πold,站在旧策略的采样分布上估计 KL 散度是最自然的选择——这保证了在旧策略有概率质量的所有区域,新策略都不会偏移太远。Forward KL 的 mode-covering 性质在这里是合理的:我们不希望新策略"丢弃"旧策略已经学到的任何行为。


15.6.4 Reverse KL 在 RLHF 中的三重理由 [必读]

在 RLHF 中,KL 惩罚项的形式为:

βDKL(πθ(|x)πref(|x))

这里是 πθ(当前策略)在前、πref(参考策略)在后,属于 Reverse KL。RLHF 选择这个方向,背后有三重相互强化的理由。

理由一:采样便利性。 训练时的序列是从当前策略 πθ 采样得到的(on-policy 生成),因此我们可以天然地估计以下期望:

DKL(πθπref)=Eyπθ[logπθ(y|x)logπref(y|x)]

这个期望可以直接用当前策略的采样来做蒙特卡洛估计,不需要额外的重要性采样修正。在 token 级别,每个 token 的 KL 贡献就是:

klt=logπθ(atst)logπref(atst)

这与 PPO 的 per-token 奖励计算完美衔接——实现代码中直接用 log_probs - ref_log_probs 即可。

理由二:惩罚策略逃逸。 Reverse KL 会强烈惩罚当前策略把概率放到参考策略几乎不给概率的区域。如果 πθ(y|x) 很大但 πref(y|x)0,则 logπθ(y|x)πref(y|x)+,产生巨大的惩罚。这正是我们在 RLHF 中需要的行为——允许策略在参考模型的分布范围内改进,但绝不允许"跑飞"到参考模型从未见过的荒诞输出。这有效防止了 reward hacking(奖励欺骗):模型不能通过生成参考模型从未产生过的"神秘咒语"来骗取高奖励。

理由三:Mode-seeking 的实用价值。 RLHF 的目标是在参考模型附近找到高奖励的回答模式,而不是完整复制参考模型的全部行为。Reverse KL 的 mode-seeking 性质恰好匹配这一需求:它允许策略集中到少数高奖励的回答模式上,放弃参考模型中那些低奖励的回答模式,同时又被 KL 约束限制在合理范围内。打个比方,如果参考模型对某个问题有 10 种可能的回答风格,RLHF 不需要全部保留——只要找到其中最好的 2-3 种并做到极致就够了。

下表将三重理由与 Forward KL 的行为对比:

维度Reverse KL(RLHF 选择)Forward KL(假如采用)
采样πθ 采样,直接可估计需要从 πref 采样或重要性采样
逃逸惩罚强烈惩罚离开 πref 的支撑集惩罚 πθπref 高概率处给低概率
模式行为集中到高奖励模式(mode-seeking)保留所有模式(mode-covering)
实际效果精准提升 + 不跑飞可能过度分散,无法聚焦高奖励模式

表 15-7:RLHF 中 Reverse KL 与假设的 Forward KL 对比。


15.6.5 Token 级 KL 计算的工程实现 [必读]

在 RLHF 的 PPO 训练中,KL 惩罚不是在整个序列级别一次性计算的,而是被分配到每个 token,形成密集的逐步奖励信号。这是理解 RLHF 工程实现的关键细节。

Per-token 奖励的构成。 对于序列中的第 t 个 token,其获得的总奖励为:

Rt=β(logπθ(atst)logπref(atst))KL 惩罚(每个 token 都有)+rscoreI(t=T)RM 奖励(仅最后一个 token)

其中 I(t=T) 是指示函数,表示奖励模型的打分只在序列末尾出现。这意味着中间 token 的奖励完全由 KL 惩罚决定——如果某个 token 的生成概率远高于参考模型给出的概率,该 token 就会收到一个负的即时奖励。

以下代码展示了这一计算逻辑的核心实现:

python
import torch

def compute_rewards(log_probs, ref_log_probs, reward_score,
                    prompt_length, kl_coef=0.1, clip_value=5.0):
    """
    计算 RLHF 中每个 token 的总奖励。

    Args:
        log_probs:      Actor 模型对 response 每个 token 的 log 概率, [B, seq_len]
        ref_log_probs:  Reference 模型对 response 每个 token 的 log 概率, [B, seq_len]
        reward_score:   奖励模型对整句的打分, [B]
        prompt_length:  prompt 的长度(response 从此处开始)
        kl_coef:        KL 惩罚系数 beta
        clip_value:     奖励裁剪范围

    Returns:
        rewards: 每个 token 的总奖励, [B, seq_len]
    """
    # ---- Step 1: 计算每个 token 的 KL 惩罚 ----
    # kl_t = log pi_theta(a_t|s_t) - log pi_ref(a_t|s_t)
    # reward_t = -beta * kl_t = beta * (ref_log_probs - log_probs)
    kl_reward = -kl_coef * (log_probs - ref_log_probs)

    # ---- Step 2: 裁剪 RM 奖励,防止极端值 ----
    reward_clip = torch.clamp(reward_score, -clip_value, clip_value)

    # ---- Step 3: 在最后一个 token 处叠加 RM 奖励 ----
    rewards = kl_reward.clone()
    batch_size = log_probs.shape[0]
    start = prompt_length - 1  # response 开始位置

    # 找到每个序列的最后一个有效 token(EOS 位置)
    # 简化实现:假设 response 占满剩余序列
    last_pos = log_probs.shape[1] - 1
    for j in range(batch_size):
        rewards[j, last_pos] += reward_clip[j]

    return rewards

# ---- 使用示例 ----
B, seq_len, prompt_len = 2, 20, 5
log_probs = torch.randn(B, seq_len) * 0.5 - 3.0
ref_log_probs = torch.randn(B, seq_len) * 0.5 - 3.0
reward_score = torch.tensor([0.8, -0.3])

rewards = compute_rewards(log_probs, ref_log_probs, reward_score, prompt_len)
print(f"rewards shape: {rewards.shape}")         # [2, 20]
print(f"中间 token 奖励示例: {rewards[0, 10]:.4f}")  # 纯 KL 惩罚
print(f"最后 token 奖励示例: {rewards[0, -1]:.4f}")  # KL 惩罚 + RM 分数

KL 散度的估计方式。 上面的 log_probs - ref_log_probs 实际上是 KL 散度的单样本蒙特卡洛估计。更精确的估计器(Schulman 等人提出)使用:

D^KL=elogπreflogπθ(logπreflogπθ)1

这个估计器具有更低的方差,在 GRPO 等算法的实现中被广泛采用。


15.6.6 与变分推断的深层联系 [挑战]

Reverse KL 在 RLHF 中的使用并非偶然——它与统计学习中最重要的框架之一变分推断(Variational Inference, VI) 有着深刻的数学联系。理解这一联系,能帮助我们从更高的视角审视 RLHF 的优化目标。

变分推断的核心问题。 在贝叶斯推断中,我们希望计算后验分布 p(zx)=p(x,z)p(x),但分母 p(x)=p(x,z)dz(边缘似然/证据)通常是一个难以计算的高维积分。变分推断的策略是:找一个来自简单分布族的分布 qϕ(z) 来近似 p(zx)

为什么选择 Reverse KL? 变分推断最小化的是:

DKL(qϕ(z)p(zx))=Eqϕ(z)[logqϕ(z)p(zx)]

这是 q 在前、p 在后——正是 Reverse KL。关键原因是可以绕开难以计算的归一化常数。展开这个 KL 散度:

DKL(qp(z|x))=Eq(z)[logq(z)logp(x,z)]+logp(x)

注意 logp(x) 对于优化 q 来说是常数!因此,最小化 Reverse KL 等价于最小化:

Eq(z)[logq(z)logp(x,z)]

取负号并翻转优化方向,就得到了著名的 ELBO(Evidence Lower Bound,证据下界)

ELBO(q)=Eq(z)[logp(x,z)logq(z)]

核心等式为:

logp(x)=ELBO(q)+DKL(q(z)p(z|x))

由于 DKL0,ELBO 是 logp(x) 的下界。最大化 ELBO 等价于最小化 Reverse KL——这就是 VAE(变分自编码器)等模型的理论根基。

与 RLHF 的类比。 这个联系不仅仅是形式上的类似。回顾 RLHF 的带 KL 约束的目标函数,其最优策略的解析解为:

π(yx)=1Z(x)πref(yx)exp(r(x,y)β)

其中 Z(x)=yπref(yx)exp(r(x,y)β) 是配分函数。如果我们做如下对应:

变分推断RLHF
qϕ(z)(近似分布)πθ(yx)(当前策略)
p(zx)(真实后验)π(yx)(最优策略)
logp(x,z)(联合对数似然)r(x,y)/β+logπref(yx)
最大化 ELBO最大化 Eπθ[r(x,y)]βDKL(πθ|πref)

表 15-8:变分推断与 RLHF 的结构对应。

就可以发现:RLHF 的优化目标在数学结构上等价于用当前策略 πθ 去变分逼近最优策略 π。最小化 Reverse KL DKL(πθπ) 等价于最大化 RLHF 目标函数——二者本质上是同一个优化问题的不同表述。

这一联系也解释了 DPO 的理论根基:既然最优策略有解析形式,就可以直接从策略反推奖励、代入 Bradley-Terry 模型,绕开配分函数 Z(x)(在偏好概率中相消),得到不需要显式奖励模型的 DPO 损失函数。


15.6.7 KL 系数 β 的实践考量 [选读]

β 是 RLHF 中最重要的超参数之一,它控制着探索奖励保持稳定之间的权衡:

  • β 过小:KL 约束松弛,策略可以大幅偏离参考模型。代理奖励(RM 分数)可能持续上升,但真实质量先升后降——这就是reward hacking(奖励欺骗)。模型可能生成看起来高分但实际无意义的输出。
  • β 过大:KL 约束过强,策略几乎等于参考模型,训练形同虚设。模型无法从奖励信号中学到有意义的改进。
  • 最优 β:需要通过实验调参。InstructGPT 中使用 β=0.1 左右,不同任务和模型规模需要不同的值。

在工程实现中,TRL 库的 PPOTrainer 会记录 objective/kl 指标(当前策略与参考策略的平均 KL 散度),这是监控训练健康度的关键信号。如果 KL 持续增大,说明策略正在"跑飞";如果 KL 几乎不动,说明策略没有有效学习。


15.6.8 小结

本节从 KL 散度的基本定义出发,深入剖析了 Forward KL 与 Reverse KL 的行为差异:

  • Forward KL 是 mode-covering 的——它迫使近似分布覆盖真实分布的所有模式,适合经典策略优化(如 TRPO)中"不丢弃已有行为"的需求。
  • Reverse KL 是 mode-seeking 的——它允许近似分布只锁定真实分布的某个模式,这恰好匹配 RLHF 中"聚焦高奖励回答、不跑出参考模型范围"的需求。
  • RLHF 选择 Reverse KL 有三重理由:采样便利性(从当前策略采样直接可估计)、防止策略逃逸(强惩罚离开参考模型支撑集的行为)、mode-seeking 的实用价值(集中到高奖励模式)。
  • 变分推断的视角看,RLHF 的优化等价于用当前策略去变分逼近最优策略,最大化 RLHF 目标函数等价于最小化 DKL(πθπ)
  • 在工程实现中,KL 惩罚被分解到每个 token 上形成密集奖励信号,与奖励模型的句子级打分互补配合。