Skip to content

第 13 章:知识蒸馏

本章导读

知识蒸馏(Knowledge Distillation, KD)是将大模型的能力迁移到小模型的核心技术。当 DeepSeek-R1(671B)、Qwen3-235B 等超大模型展现出强大的推理能力时,如何让 1.5B~32B 的轻量模型继承这些能力,就是蒸馏要解决的问题。

本章围绕四个层次展开:

  1. 基础原理——蒸馏的优化目标、三种范式及其差异;
  2. 在线蒸馏——分布偏移问题的根因,GOLD 框架如何实现跨 Tokenizer 的序列级蒸馏,以及 Qwen3 的两阶段工业实践;
  3. GKD 实践——TRL 框架中 Generalized Knowledge Distillation Trainer 的设计思路与代码实操;
  4. 跨阶段蒸馏——GLM-5 的 on-policy cross-stage distillation,将蒸馏从模型间扩展到训练阶段间。

13.1 知识蒸馏基础

蒸馏的三种范式(数据蒸馏、Logits 蒸馏、RLAIF 蒸馏)及其优化目标的数学基础。

13.1.1 为什么需要蒸馏

部署成本与推理延迟:大模型部署成本极高——一个 671B 参数的模型即使用 INT4 量化,仍需数百 GB 显存,推理延迟可达数秒甚至更长。在实时对话、端侧推理、高并发 API 服务等场景中,直接部署大模型在成本和延迟上都不可行。蒸馏的首要目标是让轻量模型(1.5B~32B)在推理速度和资源消耗上达到可部署标准,同时尽可能保留大模型的能力。

模型能力迁移:蒸馏的本质是知识迁移——将大模型通过海量数据和计算获得的能力"搬运"到小模型中。这种迁移不仅包括表层的输入-输出映射,更包括大模型在概率分布层面编码的"暗知识"(Dark Knowledge):哪些词是合理的候选但最终没被选中、不同选项之间的相对置信度等。这些信息在 one-hot 标签中完全丢失,但在 Teacher 的 softmax 分布中得以保留。

知识压缩与教师-学生范式:蒸馏与量化、剪枝的根本区别在于:量化和剪枝压缩的是同一个模型的参数表示,而蒸馏是在不同模型之间迁移知识。Student 模型可以有完全不同的架构(不同层数、不同注意力头数、甚至不同的 Tokenizer),这带来了更大的灵活性。教师-学生范式(Teacher-Student Paradigm)的核心假设是:一个强大但昂贵的 Teacher 可以通过其输出(文本、logits 或偏好判断)来指导一个轻量但高效的 Student,使 Student 获得超越其独立训练能力上限的表现。

蒸馏的目标很直接——用大模型生成的高质量数据或概率分布,训练出轻量级模型,使其继承大模型的推理能力

13.1.2 三种蒸馏范式

范式学习对象代表方法特点
数据蒸馏(SFT-only)大模型生成的文本DeepSeek-R1-Distill简单直接,无需访问 Teacher 权重
Logits 蒸馏大模型的概率分布Qwen3 第二阶段、GKD信息量最大,需要 Teacher 前向传播
RLAIF 蒸馏大模型的偏好判断Constitutional AI对齐价值观,不直接传递知识

三种范式传递的信息量递减但灵活性递增。数据蒸馏只需要 Teacher 的输出文本(甚至可以用 API),Logits 蒸馏需要访问 Teacher 模型的权重和前向传播,RLAIF 蒸馏则只需要 Teacher 的偏好判断。

13.1.3 数据蒸馏:DeepSeek-R1-Distill

DeepSeek 采用纯数据蒸馏路线,流程极为简洁:

  1. 用 DeepSeek-R1(671B)生成 800k 条高质量思维链数据
  2. 在 Qwen 和 Llama 系列的多个规模(1.5B / 7B / 14B / 32B / 8B / 70B)上进行 SFT;
  3. 不包含任何 RL 阶段

DeepSeek 团队在实验中发现:直接蒸馏大模型的输出,比让小模型从头进行大规模 RL 效果更好。例如,蒸馏后的 DeepSeek-R1-Distill-Qwen-32B 性能显著优于同等规模自行进行 RL 训练的 DeepSeek-R1-Zero-Qwen-32B。

这一结论与第 12 章的分析一致:RL 对于小模型更多是"分布削尖"(在已有能力范围内集中概率质量),而蒸馏可以真正"扩展推理边界"——从 pass@k 曲线可以观察到,蒸馏后的模型在更高 k 值处仍能保持较好的覆盖率。

13.1.4 蒸馏的优化目标

经典知识蒸馏使用 Forward KL 散度作为优化目标:

LKD=DKL(PTeacherPStudent)=vVPT(v)logPT(v)PS(v)

直觉理解:Forward KL 要求 Student 在 Teacher 有概率质量的地方也给出足够高的概率。换言之,Teacher "认为可能"的词,Student 都不能忽略——这是 Mode Covering 行为。

与 Reverse KL DKL(PSPT) 对比:

方向行为效果
Forward KL DKL(PTPS)Mode CoveringStudent 覆盖 Teacher 的所有模式,可能过于分散
Reverse KL DKL(PSPT)Mode SeekingStudent 集中于 Teacher 最强的模式,可能遗漏长尾

在 PyTorch 中实现 Forward KL:

python
import torch.nn.functional as F

# Student 和 Teacher 的 logits 形状:[Batch, Seq, Vocab]
log_q = F.log_softmax(student_logits / temperature, dim=-1)  # Student 分布 Q
log_p = F.log_softmax(teacher_logits / temperature, dim=-1)  # Teacher 分布 P

# F.kl_div(input=log_Q, target=log_P) 计算 D_KL(P || Q) = sum(P * (log_P - log_Q))
kl = F.kl_div(log_q, log_p, reduction="none", log_target=True)
# kl 形状: [Batch, Seq, Vocab]

# 词表维度求和(一个 token 位置的 KL 是在整个词表上的和)
kl_per_token = kl.sum(dim=-1)   # [Batch, Seq]
# 对 batch 和 seq 取平均
kl_loss = kl_per_token.mean()

实现细节log_target=True 告诉 PyTorch target 参数(Teacher 分布)已经是 log 形式,避免内部做额外的 exp 再 log 操作,数值上更稳定。KL 在词表维度求和(不取平均),单位为 nats/token。


13.2 在线蒸馏(Online Distillation)

在线蒸馏解决分布偏移问题:GOLD 框架与 Qwen3 两阶段工业实践。

13.2.1 离线蒸馏的局限:分布偏移

传统离线蒸馏(Off-Policy Distillation)的数据由 Teacher 预先生成,Student 只在训练时见到 Teacher 产生的"标准答案"。这带来一个根本性问题——分布偏移(Distribution Shift)

类比:学生只背过参考答案。到了考场上自己写错了一个字(比如把"太阳"写成了"太阴"),因为从未见过这种出错场景,不知道如何恢复,后续的回答可能完全跑偏。

具体来说,推理时 Student 每一步都基于自己上一步的输出继续生成(自回归),但训练时 Student 只见过 Teacher 的输出序列。一旦 Student 生成了训练数据中未出现过的 token,后续预测就进入了训练分布之外的区域,误差逐步累积,最终导致生成质量严重下降。

这就是 Exposure Bias(暴露偏差)的经典表现。

13.2.2 在线蒸馏的核心思路

在线策略蒸馏(On-Policy Distillation)的关键改变:让 Student 自己生成数据,然后由 Teacher 对这些生成结果进行逐 token 的纠正

类比:学生自己做题,不管写得对还是错,老师都在旁边看着。学生写了"太阴",老师立即指出:"这里最大概率应该是'阳',你后面应该这样接……"。学生不仅学到了正确答案,还学会了如何从自己的错误中恢复。

核心流程分为三步:

步骤一:采样(Sampling)

由当前 Student 模型在线生成响应:

y^πθ(|x)

步骤二:对齐(Alignment)

将 Student 生成的序列 y^ 输入 Teacher 模型,获取 Teacher 对同一序列每个 token 位置的概率分布。

步骤三:优化(Optimization)

最小化 Student 与 Teacher 在该序列上的概率差距。

这一流程的本质是:Student 在自己的策略分布上采样,Teacher 在 Student 的分布上提供监督。训练数据始终来自 Student 的当前策略,从根本上消除了分布偏移。

13.2.3 跨 Tokenizer 蒸馏:GOLD 框架

当 Teacher 和 Student 使用不同的 Tokenizer 时(例如用 GPT-4 蒸馏 Llama),Token 序列无法逐位置对齐。Hugging Face 提出的 GOLD(General On-Policy Logit Distillation) 框架解决了这个问题。

核心思想:放弃 token 级别的一一对应,转而在**序列总对数似然(Total Log-Likelihood)**这个标量上进行比较。

具体过程

Student 使用 Tokenizer S 将生成的文本 y^ 切分为 N 个 token (u1,u2,,uN),序列总对数似然为:

logPθ(y^|x)=i=1Nlogπθ(ui|u<i,x)

Teacher 使用 Tokenizer T 将同一段文本切分为 M 个 token (v1,v2,,vM)(注意 NM,token ID 也完全不同),序列总对数似然为:

logPϕ(y^|x)=j=1Mlogπϕ(vj|v<j,x)

虽然 token 序列 uv 无法一一对应,但两个序列级对数似然均为标量,可以直接比较。

直觉:不管老师和学生怎么切分句子,最终我们只关心"对整句话的整体置信度"。好比两个翻译用不同的断句方式翻译同一段话,我们只比较最终翻译的整体质量,不纠结每个短语的对应关系。

GOLD 损失函数

LGOLD(θ)=ExD, y^πθ[Dseq(πθ,πϕ)]

在实际操作中,常见的简化形式有两种:

形式一:均方误差——直接让 Student 的序列似然匹配 Teacher 的序列似然:

L(θ)logPθ(y^|x)logPϕ(y^|x)2

形式二:RL 框架——将 Teacher 的概率视作奖励信号:

R(x,y^)=logPϕ(y^|x)βlogPθ(y^|x)

其中 β 控制 KL 惩罚的强度,防止 Student 偏离自身分布过远。这一形式与 RLHF 中的 PPO 目标函数结构一致,可以复用 RL 训练基础设施。

13.2.4 Qwen3 的两阶段蒸馏实践

Qwen3 采用 Strong-to-Weak 两阶段蒸馏,将 Qwen3-32B / Qwen3-235B-A22B 的能力迁移到 0.6B~30B 系列轻量模型。这是一个值得仔细分析的工业级蒸馏流水线。

第一阶段:离线蒸馏(Off-Policy Distillation)

  • 数据来源:使用 Teacher 模型(Qwen3-32B 或 Qwen3-235B-A22B)预先生成训练数据;
  • 数据特征:同时包含 /think(思考模式,长思维链推理)和 /no_think(非思考模式,直接回答)两类响应;
  • 训练方式:对 Student 模型进行标准 SFT;
  • 核心目标:让 Student 初步学会识别模式切换指令(根据 prompt 中的 /think/no_think 标签决定是否启动长思维链推理),并建立基本的推理能力分布。

第二阶段:在线蒸馏(On-Policy Logits Alignment)

  • Student 生成:采样一批 prompts,让经过第一阶段训练的 Student 在线生成响应(同时覆盖 /think/no_think 模式);
  • Logits 获取:将 Student 生成的序列送入 Teacher 模型,获取 Teacher 对每个 token 位置的概率分布;
  • 优化目标:最小化 KL 散度,迫使 Student 的输出概率分布逼近 Teacher:
minθDKL(PTeacher(|x,y^)Pθ(|x,y^))

两阶段的分工逻辑

阶段类型数据来源核心目标
第一阶段离线 SFTTeacher 预生成初始化分布,学会指令格式和基本推理模式
第二阶段在线 Logits 对齐Student 实时生成精细对齐概率分布,消除分布偏移,提升生成质量

为什么需要两阶段? 第一阶段的离线 SFT 让 Student 快速获得一个合理的初始分布——如果 Student 的初始生成质量太差,在线蒸馏阶段 Teacher 给出的修正信号会过于强烈,训练不稳定。第一阶段相当于"预热",第二阶段再精细打磨。


13.3 TRL 中的 GKD(Generalized Knowledge Distillation)实践

TRL 框架中 GKD Trainer 的设计思路与代码实操。

13.3.1 GKD 概述

TRL(Transformer Reinforcement Learning)库提供了 GKDTrainer,实现了 Generalized Knowledge Distillation(广义知识蒸馏)。

GKD 的"广义"体现在两个维度:

  1. 损失函数可选:Forward KL、Reverse KL、JSD 均支持;
  2. 数据来源可混合:Student 在线生成的数据和数据集中预存的数据可以按比例混合。

13.3.2 训练流程

GKD 的核心训练循环可以概括为以下四步:

python
for batch in dataloader:
    # 1. 数据选择:按概率 beta 决定使用 Student 在线生成还是数据集中的序列
    if random() < beta:
        # Student 在线采样(On-Policy Generation)
        sequences = student_model.generate(batch["prompts"], ...)
    else:
        # 使用数据集中预存的 Teacher 输出(Off-Policy)
        sequences = batch["completions"]

    # 2. 获取 Teacher 在相同序列上的 Logits(Teacher 不参与梯度计算)
    with torch.no_grad():
        teacher_logits = teacher_model(sequences).logits

    # 3. 获取 Student 的 Logits
    student_logits = student_model(sequences).logits

    # 4. 计算蒸馏损失并更新 Student
    loss = distillation_loss(teacher_logits, student_logits)
    loss.backward()
    optimizer.step()

关键设计点:参数 beta 控制在线生成与离线数据的混合比例。beta=1.0 表示完全在线蒸馏,beta=0.0 表示完全离线蒸馏,中间值则是两者的插值。这使得 GKD 可以平滑地在离线和在线模式之间过渡。

13.3.3 支持的损失函数

GKD Trainer 支持三种蒸馏损失:

损失类型公式行为适用场景
Forward KLDKL(PTPS)Mode Covering经典 KD,Student 覆盖 Teacher 所有模式
Reverse KLDKL(PSPT)Mode SeekingStudent 集中于 Teacher 最强模式,与 RL 目标一致
JSD12[DKL(PTM)+DKL(PSM)]折中对称损失,训练更稳定

其中 JSD 的中间分布 M=12(PT+PS)

如何选择? 如果目标是让 Student 尽可能全面地继承 Teacher 的能力分布,选 Forward KL;如果希望 Student 在少数任务上达到 Teacher 的峰值水平(可以牺牲长尾覆盖),选 Reverse KL;不确定时选 JSD,它兼顾两端且对称性带来更好的训练稳定性。

13.3.4 关键实现细节

词表空间的 KL 计算

在 LLM 中,每个 token 位置的 KL 散度是在整个词表 |V| 上求和:

python
# 正确做法:词表维度 sum,batch 和 seq 维度 mean
kl = F.kl_div(log_q, log_p, reduction="none", log_target=True)
# shape: [Batch, Seq, Vocab]
kl_per_token = kl.sum(dim=-1)   # [Batch, Seq] — 每个 token 位置一个标量
kl_loss = kl_per_token.mean()   # 对 batch 和 seq 取平均

常见错误是在词表维度用 mean 而非 sum,这会使损失值被 |V| 稀释,梯度过小,训练几乎没有效果。

显存估算

蒸馏需要同时存储 Teacher 和 Student 的完整词表 Logits,这是主要的显存瓶颈:

python
# 以 Batch=32, Seq=2048, Vocab=128000, fp32 为例
memory_gb = 32 * 2048 * 128000 * 4 / (1024**3)  # ≈ 31.25 GB(单侧)
# Teacher + Student 两侧 Logits:约 62.5 GB

应对策略:

  • 使用 bf16/fp16 存储 Logits,显存减半;
  • 开启梯度检查点(Gradient Checkpointing),用计算换显存;
  • 减小 batch size 或序列长度;
  • 对 Teacher 使用量化推理(INT8/INT4)。

温度参数

蒸馏时通常使用温度 T>1 对 Logits 进行 softmax:

PT(v)=exp(zv/T)vexp(zv/T)

较高的温度使分布更平滑,暴露 Teacher 在低概率区域的偏好信息,帮助 Student 学到更丰富的知识。GKD 中通过 temperature 参数控制。

13.3.5 代码示例

以下是使用 TRL GKDTrainer 的完整代码框架:

python
from trl import GKDConfig, GKDTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer

# -------- 加载模型 --------
teacher_model_name = "Qwen/Qwen3-32B"
student_model_name = "Qwen/Qwen3-7B"

teacher_model = AutoModelForCausalLM.from_pretrained(
    teacher_model_name, torch_dtype=torch.bfloat16
)
student_model = AutoModelForCausalLM.from_pretrained(
    student_model_name, torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(student_model_name)

# -------- 配置 GKD --------
gkd_config = GKDConfig(
    output_dir="./gkd-output",
    temperature=0.9,        # 蒸馏温度,>1 使分布更平滑
    lmbda=0.5,              # SFT 损失与蒸馏损失的混合比例
                            # loss = lmbda * L_distill + (1 - lmbda) * L_sft
    beta=0.5,               # Student 在线生成 vs 数据集离线数据的比例
    max_new_tokens=512,     # Student 在线生成的最大 token 数
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=1e-5,
    num_train_epochs=3,
    bf16=True,
)

# -------- 训练 --------
trainer = GKDTrainer(
    teacher_model=teacher_model,
    model=student_model,
    args=gkd_config,
    train_dataset=dataset,       # 包含 prompts 和可选的 completions
    processing_class=tokenizer,
)
trainer.train()

参数解读

  • lmbda:控制蒸馏损失与标准 SFT 损失的混合。lmbda=1.0 为纯蒸馏,lmbda=0.0 为纯 SFT。混合训练有助于 Student 在学习 Teacher 分布的同时保持对原始数据的拟合。
  • beta:控制在线/离线数据比例。beta=0.5 表示一半 batch 使用 Student 在线生成,另一半使用数据集中的预存序列。
  • temperature:蒸馏温度。经验上 0.7~1.0 较常用,具体取决于 Teacher 和 Student 的能力差距。

13.3.6 方法对比

方法数据来源损失类型跨 Tokenizer适用场景
离线数据蒸馏(SFT)Teacher 预生成文本Cross-Entropy天然支持资源受限,无需 Teacher 权重
在线 Logits 蒸馏(GKD)Student 在线生成Token-level KL/JSD需相同 Tokenizer相同架构族,追求最高质量
序列级在线蒸馏(GOLD)Student 在线生成Sequence-level KL支持跨架构蒸馏,缓解分布偏移
RLAIF 蒸馏Student 在线生成偏好对比天然支持对齐主观价值观和安全性

13.4 On-Policy 跨阶段蒸馏:GLM-5 的实践

GLM-5 的跨阶段蒸馏实践:如何在 RL 训练中持续从 Teacher 获取 On-Policy 指导。

13.4.1 动机:跨训练阶段的蒸馏

前面介绍的在线蒸馏(13.2)和 GKD(13.3)都聚焦于一个固定的 Teacher-Student 对。GLM-5 提出了 On-Policy Cross-Stage Distillation(在线跨阶段蒸馏),将蒸馏的视角从"模型间"扩展到"训练阶段间"——用模型在后期训练阶段(如 RL 阶段)的能力,反向蒸馏回早期阶段(如 SFT 阶段),形成跨阶段的知识回传。

13.4.2 与传统在线蒸馏的区别

维度传统在线蒸馏跨阶段蒸馏
Teacher独立的大模型同一模型的后期训练阶段版本
Student独立的小模型同一模型的早期训练阶段版本
知识流向后期 早期(同一模型的不同训练时刻)
核心目标模型压缩跨阶段知识融合,提升整体训练效率

13.4.3 跨阶段蒸馏的流程

跨阶段蒸馏的核心流程如下:

  1. 阶段 A(SFT):模型完成 SFT 训练,获得 πSFT
  2. 阶段 B(RL):在 πSFT 基础上进行 RL 训练,获得 πRL——此时 πRL 在推理任务上更强,但可能丧失部分通用能力;
  3. 反向蒸馏:用 πRL 作为 Teacher,对 πSFT 阶段的模型进行 on-policy 蒸馏——Student 在自身分布上采样,Teacher(πRL)提供 logits 级别的监督;
  4. 迭代:蒸馏后的模型重新进入下一轮 RL 训练,形成"SFT RL 蒸馏回 SFT RL"的迭代循环。

这一流程的关键洞察是:RL 阶段习得的推理能力可以通过蒸馏"固化"到 SFT 阶段的模型中,使得下一轮 RL 从一个更强的起点出发。这与第 11 章讨论的 RFT 抗遗忘性相呼应——RL 通过自我轨迹采样保护已有能力,而跨阶段蒸馏则通过将 RL 增益回注到基础模型中,进一步巩固这种能力保留。

13.4.4 On-Policy 的必要性

跨阶段蒸馏同样采用 on-policy 策略(Student 自己生成数据),原因与 13.2 节相同:避免分布偏移。如果用 Teacher(πRL)生成的数据做离线蒸馏,Student 在推理时生成的序列可能偏离训练分布,导致误差累积。On-policy 采样确保 Student 在自身策略分布上接受 Teacher 的纠正,训练和推理时的数据分布一致。


本章小结

  1. 蒸馏优于小模型自主 RL:DeepSeek-R1 的实验表明,直接蒸馏大模型输出比让小模型自行 RL 效果更好。RL 只能在已有能力范围内"削尖"分布,而 Teacher 的 Logits 蒸馏可以真正"扩展"Student 的推理边界。

  2. 离线 vs 在线的本质差异:离线蒸馏简单高效,但 Student 只见过 Teacher 的输出分布,推理时面临分布偏移;在线蒸馏让 Student 在自身策略分布上采样并接受 Teacher 纠正,从根本上消除 Exposure Bias,代价是需要 Teacher 在线推理。

  3. 跨 Tokenizer 的解决方案:GOLD 框架将对齐粒度从 token 级提升到序列级,通过比较序列总对数似然这一标量,绕开了 Tokenizer 不匹配的问题。

  4. Forward KL vs Reverse KL 的选择:经典 KD 使用 Forward KL(Mode Covering),保证 Student 覆盖 Teacher 所有模式;RL 框架下的蒸馏倾向 Reverse KL(Mode Seeking),与探索驱动的训练更匹配。JSD 是两者的对称折中。

  5. 工程挑战:词表 Logits 的显存开销(两侧合计可达数十 GB)、蒸馏温度的调节、在线生成的吞吐瓶颈,是蒸馏从论文到生产的主要障碍。GKD Trainer 通过 beta 参数在离线和在线之间灵活切换,提供了一条渐进式的落地路径。

  6. 跨阶段蒸馏:GLM-5 的 on-policy cross-stage distillation 将蒸馏从"模型间"扩展到"训练阶段间",通过将 RL 阶段的推理增益蒸馏回 SFT 阶段的模型,形成迭代提升的训练循环。


延伸阅读