第 13 章:知识蒸馏
本章导读
知识蒸馏(Knowledge Distillation, KD)是将大模型的能力迁移到小模型的核心技术。当 DeepSeek-R1(671B)、Qwen3-235B 等超大模型展现出强大的推理能力时,如何让 1.5B~32B 的轻量模型继承这些能力,就是蒸馏要解决的问题。
本章围绕四个层次展开:
- 基础原理——蒸馏的优化目标、三种范式及其差异;
- 在线蒸馏——分布偏移问题的根因,GOLD 框架如何实现跨 Tokenizer 的序列级蒸馏,以及 Qwen3 的两阶段工业实践;
- GKD 实践——TRL 框架中 Generalized Knowledge Distillation Trainer 的设计思路与代码实操;
- 跨阶段蒸馏——GLM-5 的 on-policy cross-stage distillation,将蒸馏从模型间扩展到训练阶段间。
13.1 知识蒸馏基础
蒸馏的三种范式(数据蒸馏、Logits 蒸馏、RLAIF 蒸馏)及其优化目标的数学基础。
13.1.1 为什么需要蒸馏
部署成本与推理延迟:大模型部署成本极高——一个 671B 参数的模型即使用 INT4 量化,仍需数百 GB 显存,推理延迟可达数秒甚至更长。在实时对话、端侧推理、高并发 API 服务等场景中,直接部署大模型在成本和延迟上都不可行。蒸馏的首要目标是让轻量模型(1.5B~32B)在推理速度和资源消耗上达到可部署标准,同时尽可能保留大模型的能力。
模型能力迁移:蒸馏的本质是知识迁移——将大模型通过海量数据和计算获得的能力"搬运"到小模型中。这种迁移不仅包括表层的输入-输出映射,更包括大模型在概率分布层面编码的"暗知识"(Dark Knowledge):哪些词是合理的候选但最终没被选中、不同选项之间的相对置信度等。这些信息在 one-hot 标签中完全丢失,但在 Teacher 的 softmax 分布中得以保留。
知识压缩与教师-学生范式:蒸馏与量化、剪枝的根本区别在于:量化和剪枝压缩的是同一个模型的参数表示,而蒸馏是在不同模型之间迁移知识。Student 模型可以有完全不同的架构(不同层数、不同注意力头数、甚至不同的 Tokenizer),这带来了更大的灵活性。教师-学生范式(Teacher-Student Paradigm)的核心假设是:一个强大但昂贵的 Teacher 可以通过其输出(文本、logits 或偏好判断)来指导一个轻量但高效的 Student,使 Student 获得超越其独立训练能力上限的表现。
蒸馏的目标很直接——用大模型生成的高质量数据或概率分布,训练出轻量级模型,使其继承大模型的推理能力。
13.1.2 三种蒸馏范式
| 范式 | 学习对象 | 代表方法 | 特点 |
|---|---|---|---|
| 数据蒸馏(SFT-only) | 大模型生成的文本 | DeepSeek-R1-Distill | 简单直接,无需访问 Teacher 权重 |
| Logits 蒸馏 | 大模型的概率分布 | Qwen3 第二阶段、GKD | 信息量最大,需要 Teacher 前向传播 |
| RLAIF 蒸馏 | 大模型的偏好判断 | Constitutional AI | 对齐价值观,不直接传递知识 |
三种范式传递的信息量递减但灵活性递增。数据蒸馏只需要 Teacher 的输出文本(甚至可以用 API),Logits 蒸馏需要访问 Teacher 模型的权重和前向传播,RLAIF 蒸馏则只需要 Teacher 的偏好判断。
13.1.3 数据蒸馏:DeepSeek-R1-Distill
DeepSeek 采用纯数据蒸馏路线,流程极为简洁:
- 用 DeepSeek-R1(671B)生成 800k 条高质量思维链数据;
- 在 Qwen 和 Llama 系列的多个规模(1.5B / 7B / 14B / 32B / 8B / 70B)上进行 SFT;
- 不包含任何 RL 阶段。
DeepSeek 团队在实验中发现:直接蒸馏大模型的输出,比让小模型从头进行大规模 RL 效果更好。例如,蒸馏后的 DeepSeek-R1-Distill-Qwen-32B 性能显著优于同等规模自行进行 RL 训练的 DeepSeek-R1-Zero-Qwen-32B。
这一结论与第 12 章的分析一致:RL 对于小模型更多是"分布削尖"(在已有能力范围内集中概率质量),而蒸馏可以真正"扩展推理边界"——从 pass@k 曲线可以观察到,蒸馏后的模型在更高 k 值处仍能保持较好的覆盖率。
13.1.4 蒸馏的优化目标
经典知识蒸馏使用 Forward KL 散度作为优化目标:
直觉理解:Forward KL 要求 Student 在 Teacher 有概率质量的地方也给出足够高的概率。换言之,Teacher "认为可能"的词,Student 都不能忽略——这是 Mode Covering 行为。
与 Reverse KL
| 方向 | 行为 | 效果 |
|---|---|---|
| Forward KL | Mode Covering | Student 覆盖 Teacher 的所有模式,可能过于分散 |
| Reverse KL | Mode Seeking | Student 集中于 Teacher 最强的模式,可能遗漏长尾 |
在 PyTorch 中实现 Forward KL:
import torch.nn.functional as F
# Student 和 Teacher 的 logits 形状:[Batch, Seq, Vocab]
log_q = F.log_softmax(student_logits / temperature, dim=-1) # Student 分布 Q
log_p = F.log_softmax(teacher_logits / temperature, dim=-1) # Teacher 分布 P
# F.kl_div(input=log_Q, target=log_P) 计算 D_KL(P || Q) = sum(P * (log_P - log_Q))
kl = F.kl_div(log_q, log_p, reduction="none", log_target=True)
# kl 形状: [Batch, Seq, Vocab]
# 词表维度求和(一个 token 位置的 KL 是在整个词表上的和)
kl_per_token = kl.sum(dim=-1) # [Batch, Seq]
# 对 batch 和 seq 取平均
kl_loss = kl_per_token.mean()实现细节:
log_target=True告诉 PyTorch target 参数(Teacher 分布)已经是 log 形式,避免内部做额外的 exp 再 log 操作,数值上更稳定。KL 在词表维度求和(不取平均),单位为 nats/token。
13.2 在线蒸馏(Online Distillation)
在线蒸馏解决分布偏移问题:GOLD 框架与 Qwen3 两阶段工业实践。
13.2.1 离线蒸馏的局限:分布偏移
传统离线蒸馏(Off-Policy Distillation)的数据由 Teacher 预先生成,Student 只在训练时见到 Teacher 产生的"标准答案"。这带来一个根本性问题——分布偏移(Distribution Shift)。
类比:学生只背过参考答案。到了考场上自己写错了一个字(比如把"太阳"写成了"太阴"),因为从未见过这种出错场景,不知道如何恢复,后续的回答可能完全跑偏。
具体来说,推理时 Student 每一步都基于自己上一步的输出继续生成(自回归),但训练时 Student 只见过 Teacher 的输出序列。一旦 Student 生成了训练数据中未出现过的 token,后续预测就进入了训练分布之外的区域,误差逐步累积,最终导致生成质量严重下降。
这就是 Exposure Bias(暴露偏差)的经典表现。
13.2.2 在线蒸馏的核心思路
在线策略蒸馏(On-Policy Distillation)的关键改变:让 Student 自己生成数据,然后由 Teacher 对这些生成结果进行逐 token 的纠正。
类比:学生自己做题,不管写得对还是错,老师都在旁边看着。学生写了"太阴",老师立即指出:"这里最大概率应该是'阳',你后面应该这样接……"。学生不仅学到了正确答案,还学会了如何从自己的错误中恢复。
核心流程分为三步:
步骤一:采样(Sampling)
由当前 Student 模型在线生成响应:
步骤二:对齐(Alignment)
将 Student 生成的序列
步骤三:优化(Optimization)
最小化 Student 与 Teacher 在该序列上的概率差距。
这一流程的本质是:Student 在自己的策略分布上采样,Teacher 在 Student 的分布上提供监督。训练数据始终来自 Student 的当前策略,从根本上消除了分布偏移。
13.2.3 跨 Tokenizer 蒸馏:GOLD 框架
当 Teacher 和 Student 使用不同的 Tokenizer 时(例如用 GPT-4 蒸馏 Llama),Token 序列无法逐位置对齐。Hugging Face 提出的 GOLD(General On-Policy Logit Distillation) 框架解决了这个问题。
核心思想:放弃 token 级别的一一对应,转而在**序列总对数似然(Total Log-Likelihood)**这个标量上进行比较。
具体过程:
Student 使用 Tokenizer
Teacher 使用 Tokenizer
虽然 token 序列
直觉:不管老师和学生怎么切分句子,最终我们只关心"对整句话的整体置信度"。好比两个翻译用不同的断句方式翻译同一段话,我们只比较最终翻译的整体质量,不纠结每个短语的对应关系。
GOLD 损失函数:
在实际操作中,常见的简化形式有两种:
形式一:均方误差——直接让 Student 的序列似然匹配 Teacher 的序列似然:
形式二:RL 框架——将 Teacher 的概率视作奖励信号:
其中
13.2.4 Qwen3 的两阶段蒸馏实践
Qwen3 采用 Strong-to-Weak 两阶段蒸馏,将 Qwen3-32B / Qwen3-235B-A22B 的能力迁移到 0.6B~30B 系列轻量模型。这是一个值得仔细分析的工业级蒸馏流水线。
第一阶段:离线蒸馏(Off-Policy Distillation)
- 数据来源:使用 Teacher 模型(Qwen3-32B 或 Qwen3-235B-A22B)预先生成训练数据;
- 数据特征:同时包含
/think(思考模式,长思维链推理)和/no_think(非思考模式,直接回答)两类响应; - 训练方式:对 Student 模型进行标准 SFT;
- 核心目标:让 Student 初步学会识别模式切换指令(根据 prompt 中的
/think或/no_think标签决定是否启动长思维链推理),并建立基本的推理能力分布。
第二阶段:在线蒸馏(On-Policy Logits Alignment)
- Student 生成:采样一批 prompts,让经过第一阶段训练的 Student 在线生成响应(同时覆盖
/think和/no_think模式); - Logits 获取:将 Student 生成的序列送入 Teacher 模型,获取 Teacher 对每个 token 位置的概率分布;
- 优化目标:最小化 KL 散度,迫使 Student 的输出概率分布逼近 Teacher:
两阶段的分工逻辑:
| 阶段 | 类型 | 数据来源 | 核心目标 |
|---|---|---|---|
| 第一阶段 | 离线 SFT | Teacher 预生成 | 初始化分布,学会指令格式和基本推理模式 |
| 第二阶段 | 在线 Logits 对齐 | Student 实时生成 | 精细对齐概率分布,消除分布偏移,提升生成质量 |
为什么需要两阶段? 第一阶段的离线 SFT 让 Student 快速获得一个合理的初始分布——如果 Student 的初始生成质量太差,在线蒸馏阶段 Teacher 给出的修正信号会过于强烈,训练不稳定。第一阶段相当于"预热",第二阶段再精细打磨。
13.3 TRL 中的 GKD(Generalized Knowledge Distillation)实践
TRL 框架中 GKD Trainer 的设计思路与代码实操。
13.3.1 GKD 概述
TRL(Transformer Reinforcement Learning)库提供了 GKDTrainer,实现了 Generalized Knowledge Distillation(广义知识蒸馏)。
- 文档:https://huggingface.co/docs/trl/main/en/gkd_trainer
- 核心定位:将在线蒸馏统一到一个通用框架,支持多种蒸馏损失和灵活的数据混合策略。
GKD 的"广义"体现在两个维度:
- 损失函数可选:Forward KL、Reverse KL、JSD 均支持;
- 数据来源可混合:Student 在线生成的数据和数据集中预存的数据可以按比例混合。
13.3.2 训练流程
GKD 的核心训练循环可以概括为以下四步:
for batch in dataloader:
# 1. 数据选择:按概率 beta 决定使用 Student 在线生成还是数据集中的序列
if random() < beta:
# Student 在线采样(On-Policy Generation)
sequences = student_model.generate(batch["prompts"], ...)
else:
# 使用数据集中预存的 Teacher 输出(Off-Policy)
sequences = batch["completions"]
# 2. 获取 Teacher 在相同序列上的 Logits(Teacher 不参与梯度计算)
with torch.no_grad():
teacher_logits = teacher_model(sequences).logits
# 3. 获取 Student 的 Logits
student_logits = student_model(sequences).logits
# 4. 计算蒸馏损失并更新 Student
loss = distillation_loss(teacher_logits, student_logits)
loss.backward()
optimizer.step()关键设计点:参数 beta 控制在线生成与离线数据的混合比例。beta=1.0 表示完全在线蒸馏,beta=0.0 表示完全离线蒸馏,中间值则是两者的插值。这使得 GKD 可以平滑地在离线和在线模式之间过渡。
13.3.3 支持的损失函数
GKD Trainer 支持三种蒸馏损失:
| 损失类型 | 公式 | 行为 | 适用场景 |
|---|---|---|---|
| Forward KL | Mode Covering | 经典 KD,Student 覆盖 Teacher 所有模式 | |
| Reverse KL | Mode Seeking | Student 集中于 Teacher 最强模式,与 RL 目标一致 | |
| JSD | 折中 | 对称损失,训练更稳定 |
其中 JSD 的中间分布
如何选择? 如果目标是让 Student 尽可能全面地继承 Teacher 的能力分布,选 Forward KL;如果希望 Student 在少数任务上达到 Teacher 的峰值水平(可以牺牲长尾覆盖),选 Reverse KL;不确定时选 JSD,它兼顾两端且对称性带来更好的训练稳定性。
13.3.4 关键实现细节
词表空间的 KL 计算
在 LLM 中,每个 token 位置的 KL 散度是在整个词表
# 正确做法:词表维度 sum,batch 和 seq 维度 mean
kl = F.kl_div(log_q, log_p, reduction="none", log_target=True)
# shape: [Batch, Seq, Vocab]
kl_per_token = kl.sum(dim=-1) # [Batch, Seq] — 每个 token 位置一个标量
kl_loss = kl_per_token.mean() # 对 batch 和 seq 取平均常见错误是在词表维度用 mean 而非 sum,这会使损失值被
显存估算
蒸馏需要同时存储 Teacher 和 Student 的完整词表 Logits,这是主要的显存瓶颈:
# 以 Batch=32, Seq=2048, Vocab=128000, fp32 为例
memory_gb = 32 * 2048 * 128000 * 4 / (1024**3) # ≈ 31.25 GB(单侧)
# Teacher + Student 两侧 Logits:约 62.5 GB应对策略:
- 使用 bf16/fp16 存储 Logits,显存减半;
- 开启梯度检查点(Gradient Checkpointing),用计算换显存;
- 减小 batch size 或序列长度;
- 对 Teacher 使用量化推理(INT8/INT4)。
温度参数
蒸馏时通常使用温度
较高的温度使分布更平滑,暴露 Teacher 在低概率区域的偏好信息,帮助 Student 学到更丰富的知识。GKD 中通过 temperature 参数控制。
13.3.5 代码示例
以下是使用 TRL GKDTrainer 的完整代码框架:
from trl import GKDConfig, GKDTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
# -------- 加载模型 --------
teacher_model_name = "Qwen/Qwen3-32B"
student_model_name = "Qwen/Qwen3-7B"
teacher_model = AutoModelForCausalLM.from_pretrained(
teacher_model_name, torch_dtype=torch.bfloat16
)
student_model = AutoModelForCausalLM.from_pretrained(
student_model_name, torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(student_model_name)
# -------- 配置 GKD --------
gkd_config = GKDConfig(
output_dir="./gkd-output",
temperature=0.9, # 蒸馏温度,>1 使分布更平滑
lmbda=0.5, # SFT 损失与蒸馏损失的混合比例
# loss = lmbda * L_distill + (1 - lmbda) * L_sft
beta=0.5, # Student 在线生成 vs 数据集离线数据的比例
max_new_tokens=512, # Student 在线生成的最大 token 数
per_device_train_batch_size=4,
gradient_accumulation_steps=8,
learning_rate=1e-5,
num_train_epochs=3,
bf16=True,
)
# -------- 训练 --------
trainer = GKDTrainer(
teacher_model=teacher_model,
model=student_model,
args=gkd_config,
train_dataset=dataset, # 包含 prompts 和可选的 completions
processing_class=tokenizer,
)
trainer.train()参数解读:
lmbda:控制蒸馏损失与标准 SFT 损失的混合。lmbda=1.0为纯蒸馏,lmbda=0.0为纯 SFT。混合训练有助于 Student 在学习 Teacher 分布的同时保持对原始数据的拟合。beta:控制在线/离线数据比例。beta=0.5表示一半 batch 使用 Student 在线生成,另一半使用数据集中的预存序列。temperature:蒸馏温度。经验上 0.7~1.0 较常用,具体取决于 Teacher 和 Student 的能力差距。
13.3.6 方法对比
| 方法 | 数据来源 | 损失类型 | 跨 Tokenizer | 适用场景 |
|---|---|---|---|---|
| 离线数据蒸馏(SFT) | Teacher 预生成文本 | Cross-Entropy | 天然支持 | 资源受限,无需 Teacher 权重 |
| 在线 Logits 蒸馏(GKD) | Student 在线生成 | Token-level KL/JSD | 需相同 Tokenizer | 相同架构族,追求最高质量 |
| 序列级在线蒸馏(GOLD) | Student 在线生成 | Sequence-level KL | 支持 | 跨架构蒸馏,缓解分布偏移 |
| RLAIF 蒸馏 | Student 在线生成 | 偏好对比 | 天然支持 | 对齐主观价值观和安全性 |
13.4 On-Policy 跨阶段蒸馏:GLM-5 的实践
GLM-5 的跨阶段蒸馏实践:如何在 RL 训练中持续从 Teacher 获取 On-Policy 指导。
13.4.1 动机:跨训练阶段的蒸馏
前面介绍的在线蒸馏(13.2)和 GKD(13.3)都聚焦于一个固定的 Teacher-Student 对。GLM-5 提出了 On-Policy Cross-Stage Distillation(在线跨阶段蒸馏),将蒸馏的视角从"模型间"扩展到"训练阶段间"——用模型在后期训练阶段(如 RL 阶段)的能力,反向蒸馏回早期阶段(如 SFT 阶段),形成跨阶段的知识回传。
13.4.2 与传统在线蒸馏的区别
| 维度 | 传统在线蒸馏 | 跨阶段蒸馏 |
|---|---|---|
| Teacher | 独立的大模型 | 同一模型的后期训练阶段版本 |
| Student | 独立的小模型 | 同一模型的早期训练阶段版本 |
| 知识流向 | 大 | 后期 |
| 核心目标 | 模型压缩 | 跨阶段知识融合,提升整体训练效率 |
13.4.3 跨阶段蒸馏的流程
跨阶段蒸馏的核心流程如下:
- 阶段 A(SFT):模型完成 SFT 训练,获得
; - 阶段 B(RL):在
基础上进行 RL 训练,获得 ——此时 在推理任务上更强,但可能丧失部分通用能力; - 反向蒸馏:用
作为 Teacher,对 阶段的模型进行 on-policy 蒸馏——Student 在自身分布上采样,Teacher( )提供 logits 级别的监督; - 迭代:蒸馏后的模型重新进入下一轮 RL 训练,形成"SFT
RL 蒸馏回 SFT RL"的迭代循环。
这一流程的关键洞察是:RL 阶段习得的推理能力可以通过蒸馏"固化"到 SFT 阶段的模型中,使得下一轮 RL 从一个更强的起点出发。这与第 11 章讨论的 RFT 抗遗忘性相呼应——RL 通过自我轨迹采样保护已有能力,而跨阶段蒸馏则通过将 RL 增益回注到基础模型中,进一步巩固这种能力保留。
13.4.4 On-Policy 的必要性
跨阶段蒸馏同样采用 on-policy 策略(Student 自己生成数据),原因与 13.2 节相同:避免分布偏移。如果用 Teacher(
本章小结
蒸馏优于小模型自主 RL:DeepSeek-R1 的实验表明,直接蒸馏大模型输出比让小模型自行 RL 效果更好。RL 只能在已有能力范围内"削尖"分布,而 Teacher 的 Logits 蒸馏可以真正"扩展"Student 的推理边界。
离线 vs 在线的本质差异:离线蒸馏简单高效,但 Student 只见过 Teacher 的输出分布,推理时面临分布偏移;在线蒸馏让 Student 在自身策略分布上采样并接受 Teacher 纠正,从根本上消除 Exposure Bias,代价是需要 Teacher 在线推理。
跨 Tokenizer 的解决方案:GOLD 框架将对齐粒度从 token 级提升到序列级,通过比较序列总对数似然这一标量,绕开了 Tokenizer 不匹配的问题。
Forward KL vs Reverse KL 的选择:经典 KD 使用 Forward KL(Mode Covering),保证 Student 覆盖 Teacher 所有模式;RL 框架下的蒸馏倾向 Reverse KL(Mode Seeking),与探索驱动的训练更匹配。JSD 是两者的对称折中。
工程挑战:词表 Logits 的显存开销(两侧合计可达数十 GB)、蒸馏温度的调节、在线生成的吞吐瓶颈,是蒸馏从论文到生产的主要障碍。GKD Trainer 通过
beta参数在离线和在线之间灵活切换,提供了一条渐进式的落地路径。跨阶段蒸馏:GLM-5 的 on-policy cross-stage distillation 将蒸馏从"模型间"扩展到"训练阶段间",通过将 RL 阶段的推理增益蒸馏回 SFT 阶段的模型,形成迭代提升的训练循环。
延伸阅读
- DeepSeek-R1 技术报告:https://arxiv.org/abs/2501.12948 — 数据蒸馏路线的完整实验分析
- Qwen3 技术博客(蒸馏章节):Qwen 官方博客 — Strong-to-Weak 两阶段蒸馏的工业实践
- GOLD 在线蒸馏:https://thinkingmachines.ai/blog/on-policy-distillation/ — 跨 Tokenizer 蒸馏的理论与实验
- Hugging Face On-Policy Distillation Demo:https://huggingface.co/spaces/HuggingFaceH4/on-policy-distillation
- TRL GKD Trainer 文档:https://huggingface.co/docs/trl/main/en/gkd_trainer — API 参考与配置详解
- GLM-5 Cross-Stage Distillation:在线跨阶段蒸馏的前沿探索(详见 13.4 节)