Skip to content

16.2 PPO(近端策略优化)

在 15.2 节中,我们从 RLHF 全流程的视角介绍了 PPO 的基本角色——它是连接奖励模型与策略优化的"最后一公里"。但当时对 PPO 算法本身的讨论仅停留在公式层面。本节将深入 PPO 的内部机理:从 TRPO 到 PPO 的演进动机、裁剪机制的分类讨论、Actor-Critic 在 LLM 场景下的工程实现,以及一个从零构建的完整 PPO 训练循环。


16.2.1 从 TRPO 到 PPO:信赖域的简化

策略梯度的根本矛盾。 朴素的策略梯度算法(如 REINFORCE)有一个致命缺陷——它是纯在策略(On-Policy) 的:每次采样只能更新一次参数,然后必须丢弃数据重新采样。这导致了极低的样本效率。

为了复用已采集的数据,TRPO(Trust Region Policy Optimization) 引入了重要性采样(Importance Sampling),允许用旧策略 πθold 采集的数据去更新新策略 πθ。其目标函数为:

LCPI(θ)=Et[πθ(at|st)πθold(at|st)A^t]=Et[rt(θ)A^t]

其中 rt(θ)=πθ(at|st)/πθold(at|st)概率比(Probability Ratio)A^t 是优势函数。

但直接最大化 LCPI 会导致灾难性更新——当 rt 偏离 1 过远时,重要性采样的方差会爆炸。TRPO 的解决方案是添加硬约束

maxθLCPI(θ)s.t.DKL(πθoldπθ)δ

这需要计算 KL 散度的二阶近似(Fisher 信息矩阵),并用共轭梯度法求解,实现复杂且计算开销大。

PPO 的核心洞察。 Schulman 等人在 2017 年提出了一个优雅的替代方案:用裁剪(Clipping)代替 KL 约束。与其费力计算二阶信息来限制 KL 散度,不如直接对概率比 rt 施加硬边界。这就是 PPO-Clip:

LCLIP(θ)=Et[min(rt(θ)A^t,clip(rt(θ),1ϵ,1+ϵ)A^t)]

ϵ 通常取 0.1 或 0.2。这个看似简单的公式,实际上蕴含了精巧的分类讨论逻辑。


16.2.2 裁剪机制的深度剖析

PPO-Clip 的精妙之处在于:只在"有利方向"上限制更新幅度,在"纠错方向"上不设上限。让我们逐一分析所有情形。

PPO-Clip 裁剪机制:正优势(左)限制概率增长,负优势(右)限制概率下降,纠错方向不受限

图 16-1:PPO-Clip 的核心机制。左图:好动作(A>0)概率增长被截断在 1+epsilon,防止过度乐观;右图:坏动作(A<0)概率下降被截断在 1-epsilon,防止矫枉过正。纠错方向(好动作概率降低了、坏动作概率升高了)不受限制。

情形 1:A^t>0(好动作,应增大概率)。

  • rt>1+ϵ:概率已经增长过多。clip(rt)=1+ϵ,取 min 后使用裁剪项,梯度被截断,阻止继续增大
  • rt<1ϵ:好动作的概率反而降低了,这是"错误",不应限制修复。取 min 后使用原始项 rtA^t(值更小),梯度不受限,自由恢复

情形 2:A^t<0(坏动作,应降低概率)。

  • rt<1ϵ:概率已经下降过多。clip(rt)=1ϵ,此时 clip(rt)A^t>rtA^t(两个负数比较),取 min 使用裁剪项,梯度被截断
  • rt>1+ϵ:坏动作的概率反而升高了,这是"错误"。取 min 使用原始项(值更小即更负),梯度不受限,自由修正

核心设计哲学: PPO 采取了一种"悲观更新"策略——对策略改进方向设置天花板(防止过度优化),对纠错方向则放开限制(允许快速修复错误)。

下面用代码验证上述分析:

python
import torch

def ppo_clip_loss(log_probs, old_log_probs, advantages, epsilon=0.2):
    """
    PPO-Clip 损失函数。

    Args:
        log_probs: 当前策略的 log 概率, shape [batch, seq_len]
        old_log_probs: 旧策略的 log 概率, shape [batch, seq_len]
        advantages: 优势函数值, shape [batch, seq_len]
        epsilon: 裁剪范围

    Returns:
        标量损失值(取负号,因为 PyTorch 最小化 loss)
    """
    # 计算概率比: r_t = pi_new / pi_old = exp(log_new - log_old)
    ratio = torch.exp(log_probs - old_log_probs)

    # 未裁剪的代理目标
    surr1 = ratio * advantages
    # 裁剪后的代理目标
    surr2 = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages

    # 取 min 实现"悲观更新",取负号转为最小化问题
    policy_loss = -torch.min(surr1, surr2).mean()
    return policy_loss

# 验证: 好动作 + ratio 过大 -> 应被裁剪
adv_pos = torch.tensor([1.0])
ratio_high = torch.tensor([1.5])  # 超过 1+0.2
log_new = torch.log(ratio_high)
log_old = torch.zeros(1)
loss = ppo_clip_loss(log_new, log_old, adv_pos)
print(f"好动作 ratio=1.5: loss={loss.item():.3f}")
# 裁剪后 min(1.5*1, 1.2*1) = 1.2, loss = -1.2

16.2.3 四模型架构与工程实现

PPO 在 LLM 对齐中的实现需要同时维护四个模型,这也是其工程复杂度远高于 DPO 的根源。

PPO 训练数据流:从 batch prompts 出发,Actor 生成回答,四个模型协同计算 experience,最终更新 Actor 和 Critic

图 16-2:RLHF-PPO 的数据流。Actor 模型根据 prompt 生成回答,Reference 模型和 Reward 模型分别提供 KL 约束和奖励信号,Critic 模型估计状态价值,四者协同计算出 experience 后进行多轮策略更新。

模型符号初始化来源训练状态输出
Actor(策略模型)πθSFT 模型训练中给定 prompt,生成 response
Critic(价值模型)VϕSFT 或 RM 模型训练中估计每个 token 位置的状态价值
Reference(参考模型)πrefSFT 模型冻结提供 KL 约束基准
Reward(奖励模型)rψ偏好数据训练冻结对完整回答打分

Critic 的特殊设计。 Critic 模型的结构与 Actor 几乎相同,唯一区别是输出头:Actor 使用 lm_head(输出词表大小的 logits),Critic 则将其替换为一个输出标量的线性层 nn.Linear(hidden_size, 1),用于预测每个 token 位置的状态价值。

python
import torch
import torch.nn as nn

class CriticModel(nn.Module):
    """
    Critic 模型:基于语言模型骨架,替换输出头为标量价值预测。
    """
    def __init__(self, base_model, hidden_size):
        super().__init__()
        self.backbone = base_model  # Transformer 骨架(共享 Actor 的结构)
        # 替换 lm_head 为标量输出
        self.value_head = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, input_ids, attention_mask=None):
        # 获取最后一层隐状态
        hidden_states = self.backbone(
            input_ids, attention_mask=attention_mask
        ).last_hidden_state  # [batch, seq_len, hidden_size]
        # 映射为标量价值
        values = self.value_head(hidden_states).squeeze(-1)  # [batch, seq_len]
        return values

Reference 模型的作用。 Reference 模型是 SFT 模型的一份冻结副本,它的唯一职责是提供 logπref(at|st),用于计算 KL 散度惩罚。没有它,Actor 可能会为了追求高奖励而生成高分但无意义的输出(Reward Hacking)。

显存挑战。 以 7B 参数模型为例,四个模型在 FP16 下需要约 56 GB 显存,这还不包括激活值、梯度和优化器状态。常用的缓解策略包括:

  • DeepSpeed ZeRO-3:将参数分片到多卡,按需聚合
  • LoRA:仅训练 Actor 和 Critic 的低秩适配器,大幅减少可训练参数
  • 共享骨架:Actor 和 Critic 共享 Transformer 骨架,仅各自维护独立的输出头

16.2.4 奖励组装与优势计算

在 RLHF 中,PPO 使用的奖励信号并非直接来自 Reward Model,而是经过精心组装的复合信号。

逐 token 奖励的组装公式。 对于生成序列中的第 t 个 token,其奖励定义为:

Rt=β(logπθ(at|st)logπref(at|st))KL 惩罚(稠密,每个 token)+rscoreI(t=T)RM 打分(稀疏,仅末尾)
  • KL 惩罚是逐 token 的稠密信号,约束每一步都不偏离参考模型太远
  • RM 打分是句子级别的稀疏信号,仅在序列末尾给出整体评价

这种"稠密 KL + 稀疏 RM"的设计,既提供了逐步的方向引导(避免 Reward Hacking),又在全局层面追求高质量输出。

python
import torch

def compute_rewards(actor_log_probs, ref_log_probs, rm_scores,
                    action_mask, kl_coef=0.1):
    """
    组装 PPO 训练用的逐 token 奖励。

    Args:
        actor_log_probs: Actor 的 log 概率, [batch, seq_len]
        ref_log_probs: Reference 的 log 概率, [batch, seq_len]
        rm_scores: Reward Model 打分, [batch] (标量)
        action_mask: response 部分的 mask, [batch, seq_len]
        kl_coef: KL 惩罚系数 beta

    Returns:
        rewards: 逐 token 的奖励, [batch, seq_len]
    """
    # 1) 稠密 KL 惩罚: 每个 token 都有
    kl_penalty = -kl_coef * (actor_log_probs - ref_log_probs)
    rewards = kl_penalty

    # 2) 找到每个序列的最后一个有效 token 位置
    eos_indices = (action_mask.sum(dim=1) - 1).long()  # [batch]

    # 3) 在 EOS 位置叠加 RM 打分(稀疏信号)
    for i in range(rewards.size(0)):
        rewards[i, eos_indices[i]] += rm_scores[i]

    return rewards

GAE 优势计算。 有了逐 token 的奖励后,下一步是利用 GAE 计算优势函数(15.1 节已推导公式,这里聚焦实现)。在 LLM 场景中,GAE 的计算仅在 response 部分进行,prompt 部分不参与:

python
import torch

def compute_gae_for_llm(rewards, values, action_mask,
                        gamma=1.0, lam=0.95):
    """
    为 LLM 的 PPO 训练计算 GAE 优势和回报值。
    仅计算 response 部分(由 action_mask 标记)。

    Args:
        rewards: 逐 token 奖励, [batch, seq_len]
        values: Critic 预测的价值, [batch, seq_len]
        action_mask: response 部分 mask, [batch, seq_len]
        gamma: 折扣因子(LLM 中通常取 1.0)
        lam: GAE lambda

    Returns:
        advantages: 优势值, [batch, seq_len]
        returns: 回报值(用于训练 Critic), [batch, seq_len]
    """
    batch_size, seq_len = rewards.shape
    advantages = torch.zeros_like(rewards)

    for b in range(batch_size):
        last_gae = 0.0
        # 从后往前递归计算
        for t in reversed(range(seq_len)):
            if action_mask[b, t] == 0:
                continue
            next_value = values[b, t + 1] if t + 1 < seq_len else 0.0
            delta = rewards[b, t] + gamma * next_value - values[b, t]
            last_gae = delta + gamma * lam * last_gae
            advantages[b, t] = last_gae

    returns = advantages + values
    return advantages.detach(), returns

LLM 的 PPO 训练中 γ 通常取 1.0(不打折),因为我们关注的是整个回答的最终质量,不需要对未来奖励做衰减。而 λ 仍取 0.95 左右以平衡偏差-方差。


16.2.5 完整 PPO 训练循环

将上述所有组件组装在一起,就构成了一轮完整的 PPO 迭代。以下代码展示了核心训练逻辑(简化自 MiniMind 的 train_ppo.py 实现):

PPO 训练流程中四个模型的协同关系

图 16-3:InstructGPT 三阶段流程。PPO 训练阶段(第三阶段)需要 Actor、Critic、Reward、Reference 四个模型协同工作。

python
# 教学示例:展示核心逻辑,省略了部分 import 和辅助函数定义
import torch
import torch.nn.functional as F

def ppo_train_step(actor, critic, ref_model, reward_model,
                   prompts, tokenizer,
                   clip_epsilon=0.2, vf_coef=0.5, kl_coef=0.02,
                   ppo_epochs=4):
    """
    PPO 单步训练(一个 batch)。

    Args:
        actor: 策略模型(可训练)
        critic: 价值模型(可训练)
        ref_model: 参考模型(冻结)
        reward_model: 奖励模型(冻结)
        prompts: 输入 prompt 的 token ids, [batch, prompt_len]
        tokenizer: 分词器
        clip_epsilon: PPO 裁剪范围
        vf_coef: 价值损失权重
        kl_coef: KL 惩罚系数
        ppo_epochs: 每批数据的训练轮数
    """
    actor.eval()

    # ===== Phase 1: 采样(Rollout)=====
    with torch.no_grad():
        # Actor 生成回答
        generated = actor.generate(prompts, max_new_tokens=256,
                                   do_sample=True, temperature=0.7)
        response_ids = generated[:, prompts.size(1):]

        # 计算 old_log_probs(作为后续 ratio 的分母)
        full_ids = generated
        actor_output = actor(full_ids)
        old_log_probs = get_per_token_log_probs(
            actor_output.logits, full_ids
        ).detach()

        # Reference 的 log_probs
        ref_output = ref_model(full_ids)
        ref_log_probs = get_per_token_log_probs(
            ref_output.logits, full_ids
        )

        # Reward Model 打分
        rm_scores = reward_model.score(full_ids)  # [batch]

        # Critic 估值
        values = critic(full_ids).detach()  # [batch, seq_len]

    # ===== Phase 2: 计算优势 =====
    action_mask = build_response_mask(full_ids, prompts.size(1))
    rewards = compute_rewards(old_log_probs, ref_log_probs,
                              rm_scores, action_mask, kl_coef)
    advantages, returns = compute_gae_for_llm(
        rewards, values, action_mask, gamma=1.0, lam=0.95
    )
    # 优势标准化
    adv_mean = advantages[action_mask.bool()].mean()
    adv_std = advantages[action_mask.bool()].std() + 1e-8
    advantages = (advantages - adv_mean) / adv_std

    # ===== Phase 3: 多轮更新 =====
    actor.train()
    for epoch in range(ppo_epochs):
        # 当前策略的 log_probs
        new_output = actor(full_ids)
        new_log_probs = get_per_token_log_probs(
            new_output.logits, full_ids
        )

        # Critic 当前估值
        new_values = critic(full_ids)

        # --- Policy Loss (PPO-Clip) ---
        ratio = torch.exp(new_log_probs - old_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - clip_epsilon,
                            1 + clip_epsilon) * advantages
        policy_loss = -torch.min(surr1, surr2)
        policy_loss = (policy_loss * action_mask).sum() \
                      / action_mask.sum()

        # --- Value Loss (MSE with optional clip) ---
        value_loss = F.mse_loss(
            new_values * action_mask, returns * action_mask
        )

        # --- 总损失 ---
        loss = policy_loss + vf_coef * value_loss
        loss.backward()
        # ... optimizer.step(), scheduler.step() ...

        # 早停: 若 KL 过大则退出内循环
        approx_kl = (old_log_probs - new_log_probs).mean()
        if approx_kl > 0.02:
            break


def get_per_token_log_probs(logits, input_ids):
    """提取每个 token 对应的 log 概率。"""
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    labels = input_ids[:, 1:]
    per_token = torch.gather(
        log_probs, dim=-1, index=labels.unsqueeze(-1)
    ).squeeze(-1)
    return per_token


def build_response_mask(full_ids, prompt_len):
    """构建 response 部分的 mask。"""
    mask = torch.zeros_like(full_ids, dtype=torch.float)
    mask[:, prompt_len:] = 1.0
    return mask

关键实现细节解读:

  1. 采样用当前策略,分母固定不变。 标准 PPO 实践中,先用 Actor 生成数据,立刻计算一遍 log_probs 并 detach 作为 old_log_probs。这保证了 ratio 初始值为 1,后续多轮更新中 ratio 逐渐偏离 1,clip 机制限制偏离幅度
  2. 优势标准化。 在计算完 GAE 后,对整个 batch 的优势值做 z-score 标准化(减均值除标准差),可以稳定训练
  3. 多轮更新(PPO Epochs)。 一批采样数据通常被复用 3-5 轮,每轮重新计算 ratio 和梯度。KL 早停机制防止复用过多导致策略偏移
  4. 仅计算 response 部分。 通过 action_mask 确保 loss 只作用在 response token 上,prompt 部分不参与梯度计算

16.2.6 Value Loss 的裁剪

与 Policy Loss 类似,Critic 的 Value Loss 也可以引入裁剪来稳定训练。这一技术被称为 Clipped Value Loss

LCLIPVF=12max[(Vθ(st)Rt)2,(Vθclip(st)Rt)2]

其中 Vθclip(st)=Vθold(st)+clip(Vθ(st)Vθold(st),ϵ,ϵ)

这里取 max 而非 min,与 Policy Loss 的逻辑恰好相反——Policy Loss 目标是最大化奖励(取 min 做悲观估计),而 Value Loss 目标是最小化误差(取 max 做保守估计)。两者的哲学一致:限制参数更新幅度,防止单次更新过大导致训练不稳定

python
def clipped_value_loss(values, old_values, returns, clip_eps=0.2):
    """带裁剪的 Critic 损失函数。"""
    # 将新 value 限制在旧 value 的 ±epsilon 邻域内
    values_clipped = old_values + torch.clamp(
        values - old_values, -clip_eps, clip_eps
    )
    # 两个候选 loss
    loss_unclipped = (values - returns) ** 2
    loss_clipped = (values_clipped - returns) ** 2
    # 取较大者: 即使 clip 后更准也不奖励,防止 value 更新过快
    loss = 0.5 * torch.max(loss_unclipped, loss_clipped).mean()
    return loss

16.2.7 KL 散度的蒙特卡洛估计

在 PPO-RLHF 中,KL 散度惩罚是防止 Reward Hacking 的核心机制。然而精确计算 DKL(πθπref) 需要遍历整个词表(数万到十万个 token),这在实践中不可行。

蒙特卡洛近似。 由于 KL 散度的定义本身就是一个期望:

DKL(πθπref)=Eaπθ[logπθ(a|s)πref(a|s)]=Eaπθ[logπθ(a|s)logπref(a|s)]

而生成的 token 正好是从 πθ 中采样的,因此:

DKL1Tt=1T(logπθ(at|st)logπref(at|st))

这就是代码中 (actor_log_probs - ref_log_probs).mean() 的数学基础——它是 KL 散度的无偏估计量(Unbiased Estimator),无需遍历词表,计算量极小。

此外,trl 等工业级框架通常采用更精确的 Schulman KL 估计器

D^KL=exp(logπreflogπθ)(logπreflogπθ)1

该估计器的方差更低,且始终非负。


16.2.8 使用 trl PPOTrainer

HuggingFace 的 trl 库提供了生产级的 PPO 训练器,封装了上述所有细节。以下是一个典型的使用示例:

bash
# 安装: pip install trl
# 启动 PPO 训练
accelerate launch --config_file deepspeed_zero2.yaml \
    ppo_train.py \
    --model_name_or_path EleutherAI/pythia-1b-deduped \
    --sft_model_path <your_sft_model> \
    --reward_model_path <your_reward_model> \
    --learning_rate 3e-6 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 4 \
    --total_episodes 100000 \
    --num_ppo_epochs 4 \
    --num_mini_batches 1 \
    --missing_eos_penalty 1.0 \
    --stop_token eos

关键超参数解读:

参数默认值含义
--num_ppo_epochs4每批数据的 PPO 内循环轮数
--cliprange0.2裁剪范围 ϵ
--missing_eos_penalty1.0未生成 EOS 的惩罚(引导完整回答)
--learning_rate3e-6学习率(通常比 SFT 小一个数量级)
--total_episodes10000总训练 episode 数

训练监控。 PPO 训练需要关注以下指标判断是否健康:

  • objective/rlhf_reward:最终 RLHF 奖励(RM 分数 - KL 惩罚),应持续上升
  • val/ratio:概率比 rt 的均值,应浮动在 1.0 附近。若飙升到 2.0+,说明策略更新过猛
  • policy/clipfrac_avg:被裁剪的更新比例,过高意味着策略变化过大
  • objective/kl:策略与参考模型的 KL 散度,应缓慢增长,不能爆炸

PPO 训练在 MiniMind 上的实验结果

图 16-4:PPO 训练后模型的生成效果对比。经过 PPO 训练的模型能生成更符合人类偏好的回答。


16.2.9 On-Policy 与工程权衡

PPO 是严格的 On-Policy 算法——训练数据必须由当前策略生成。这带来了一个根本性的工程挑战:采样与训练必须交替进行,无法像 SFT 那样预先准备好全部数据。

标准流程 vs 滚动更新。 严格的 PPO 实现遵循以下流程:

  1. 用当前 Actor 生成一批数据
  2. 立即计算 log_probs 并 detach 为 old_log_probs(保证 ratio 初始为 1)
  3. 在这批数据上训练 K 个 epoch,期间 old_log_probs 不变
  4. 回到步骤 1,重新采样

而部分轻量级实现(如 MiniMind)采用了滚动更新(Rolling Update) 方式:维护一个 old_actor_model,每隔若干步同步一次 Actor 的参数。这种做法的 ratio 初始值可能不严格等于 1,但 clip 机制能兜底处理偏差,且实现更简单。

采样效率的关键数字。 在实际训练中,一个 7B 模型的 PPO 迭代中,采样(生成回答)通常占据 60-70% 的时间,训练只占 30-40%。因此工业级框架(如 DeepSpeed-Chat、veRL)会将采样和训练分配到不同的 GPU 组上并行执行,以提高整体吞吐。


16.2.10 小结

PPO 通过一个看似简单的裁剪操作,优雅地解决了策略优化中"步子太大容易扯着"的核心问题。其完整实现涉及四个模型的协同管理、复合奖励信号的组装、GAE 优势估计、以及多轮策略更新与早停机制。尽管工程复杂度较高(四个模型同时驻留 GPU、采样与训练交替进行),PPO 仍然是当前 RLHF 流程中最经典、最广泛验证的算法。理解 PPO 的每一个组件,是掌握后续 GRPO(去掉 Critic)和 DPO(去掉 RM + RL)等简化方案的前提——只有知道"什么都要"时每个部件的作用,才能判断"去掉哪个"是安全的。