10.8 训练加速技巧
前面章节讨论了数据并行、张量并行、流水线并行等分布式并行策略,它们通过将计算分散到多个设备上来突破单卡的显存和算力限制。然而,即使在单卡层面,仍然存在大量"隐藏的性能税":冗余的激活值显存占用、低精度硬件能力的闲置、算子间频繁的内存读写、以及自回归解码中逐 token 预测的串行瓶颈。本节介绍四种正交于分布式并行的训练加速技巧——梯度检查点、FP8 混合精度训练、通算融合和多 Token 预测并行——它们分别从显存、精度、访存和解码四个维度挖掘加速空间。
10.8.1 梯度检查点(激活值重计算)
问题背景。 标准的反向传播算法要求在前向传播时保存所有中间层的激活值,以便反向传播计算梯度时使用。对于一个
核心思想。 梯度检查点(Gradient Checkpointing),也称激活检查点(Activation Checkpointing)或激活值重计算(Activation Recomputation),是一种经典的以算换存技术。其基本策略是:前向传播时不再保存所有中间层的激活值,而是只保存若干精心选择的"检查点"层的激活值,丢弃其余层的激活值;反向传播需要某个被丢弃的激活值时,从最近的检查点重新执行一段前向传播来即时重算。
以算换存的定量分析。 设网络共有
- 显存开销:只需存储
个检查点的激活值,加上反向传播时当前段内重计算产生的临时激活值(最多 个层),总显存为 。 - 最优检查点数:对
求最小值,令导数为零得 ,此时显存开销为 。 - 计算代价:每个检查点段内的前向传播被额外执行一次。原始前向传播的总计算量为
,反向传播约为 (每层需要计算关于输入和参数的两组梯度),加上重计算的额外前向传播 ,总计算量从 增加到约 ,即增加约 33% 的计算量。
这意味着:一个
实践要点。 现代训练框架(如 PyTorch 的 torch.utils.checkpoint)已将梯度检查点封装为易用的 API。实际使用时,检查点的粒度通常以 Transformer 层为单位——每隔若干层设置一个检查点。更精细的策略(如选择性重计算)会根据每层激活值的大小和重计算成本做差异化决策:对于占用显存大但重计算代价低的操作(如 LayerNorm、Dropout)优先丢弃,对于计算昂贵的操作(如矩阵乘法的输出)优先保留。
梯度检查点不仅能节省显存,在某些场景下甚至能提升端到端速度。当某个操作是内存密集型(memory-bound) 时——GPU 的计算单元大量时间在等待数据从 HBM 中读取——重新计算该操作可能比从 HBM 读取已保存的激活值更快。FlashAttention 的反向传播就利用了这一原理:它不存储
10.8.2 FP8 混合精度训练
从 BF16 到 FP8 的精度演进。 混合精度训练的核心思想是用低精度格式执行大部分计算和存储,用高精度格式维护关键状态(如主权重和梯度累加器)。BF16 混合精度已经成为大模型训练的事实标准:它保留了与 FP32 相同的 8 位指数,拥有相同的动态范围,有效避免了 FP16 容易发生的上溢和下溢问题,同时将内存占用和带宽需求减半。
FP8 将精度进一步压缩到 8 位,理论上可以在 BF16 的基础上再将内存和带宽需求减半,并且在 NVIDIA Hopper(H100)及后续架构上,Tensor Core 的 FP8 吞吐量是 BF16 的
E4M3 与 E5M2 双格式策略。 FP8 标准定义了两种互补的编码格式:
| 格式 | 符号位 | 指数位 | 尾数位 | 动态范围 | 精度 | 适用场景 |
|---|---|---|---|---|---|---|
| E4M3 | 1 | 4 | 3 | 较小(±448) | 较高 | 前向传播的权重与激活值 |
| E5M2 | 1 | 5 | 2 | 较大(±57344) | 较低 | 反向传播的梯度 |
这一双格式策略的设计逻辑是:
- 前向传播使用 E4M3。前向传播中的权重和激活值经过训练后数值分布相对集中,对精度的需求高于对动态范围的需求。E4M3 的 3 位尾数提供了 FP8 中最高的精度,足以维持前向计算的准确性。
- 反向传播使用 E5M2。梯度的数值分布范围极大——从接近零的小梯度到偶尔出现的梯度尖峰——对动态范围的需求远高于对精度的需求。E5M2 的 5 位指数提供了与 FP16 相同的动态范围,可以有效避免梯度的上溢和下溢。
在这两种格式之下,Tensor Core 内部的矩阵乘法累加器仍然使用 FP32,确保累加过程中不发生精度损失。整体计算链路为:FP8 输入 → FP8 乘法 → FP32 累加 → FP8/BF16 输出。
逐张量缩放(Per-Tensor Scaling)。 由于 FP8 的表示范围极其有限,直接将 BF16 数值截断到 FP8 会导致大量信息丢失。因此,FP8 训练必须配合缩放因子(Scaling Factor) 使用:在将张量转换为 FP8 之前,先除以一个缩放因子将数值映射到 FP8 的有效表示范围内;在使用 FP8 结果时再乘回缩放因子恢复原始量级。
缩放因子的选择通常基于延迟缩放(Delayed Scaling) 策略:使用前一个训练步中该张量的最大绝对值来计算当前步的缩放因子。这避免了在当前步额外遍历整个张量来统计最大值的开销,代价是缩放因子滞后一步,但实践中对训练稳定性的影响可以忽略。
实际收益与局限。 在 H100 及更新的架构上,FP8 训练可以将矩阵乘法的吞吐量提升至 BF16 的近
10.8.3 通算融合
算子融合的动机。 在 PyTorch 等框架的默认执行模式下,每个数学操作(加、乘、激活函数、归一化等)都会启动一个独立的 CUDA 核函数。每个核函数都需要从 HBM 读取输入、将结果写回 HBM,供下一个核函数读取。对于逐元素操作(如 ReLU、残差加法)和归约操作(如 LayerNorm),计算量极小但内存读写量大——它们是典型的内存密集型操作。此时,GPU 的计算单元大量空闲,性能瓶颈完全在内存带宽上。
算子融合(Operator Fusion/Kernel Fusion) 通过将多个连续操作合并为一个 CUDA 核函数来消除中间内存读写。融合后的核函数只需一次从 HBM 读取输入,所有中间结果保留在 SM 的寄存器或共享内存(SRAM)中完成计算,最终结果一次写回 HBM。这样做有两个直接收益:(1)大幅减少 HBM 访问次数,提升有效内存带宽利用率;(2)消除核函数启动开销——每次启动 CUDA 核函数都有固定的驱动程序开销,融合后启动次数减少。
FlashAttention:通算融合的标杆案例。 标准注意力机制的计算链路为
通算融合的新方向:DeepSpeed-Domini 与 FLUX。 FlashAttention 证明了手写融合核函数的巨大潜力,但也暴露了传统融合方法的局限——它依赖专家手动编写高度优化的 CUDA 核函数,开发周期长、维护成本高。新一代通算融合框架追求更通用、更自动化的融合策略:
- 计算与通信融合:在分布式训练中,All-Reduce、All-Gather 等集合通信操作通常独立于计算执行,GPU 要么在计算、要么在等待通信完成。FLUX 等框架将通信操作与矩阵乘法在 CUDA 核函数级别进行融合——在 Tensor Core 执行矩阵乘法的同时,利用 GPU 的拷贝引擎(Copy Engine)和网络硬件并行执行数据传输。这种细粒度的计算-通信重叠可以将通信延迟几乎完全隐藏在计算之下。
- 编译器驱动的自动融合:PyTorch 2.0 的
torch.compile通过 Triton 后端自动识别和融合可合并的操作序列,降低了手写核函数的门槛。对于逐元素操作链、矩阵乘法后接激活函数等常见模式,编译器生成的融合核函数已接近手工优化的性能。
通算融合的本质是消除计算图中不必要的内存物化——无论是中间激活值的物化,还是通信等待期间的计算空闲。随着模型规模和集群规模的增大,通算融合的收益也越发显著。
10.8.4 多 Token 预测并行(MTP)
自回归解码的串行瓶颈。 标准的自回归语言模型每步只预测下一个 token:给定前缀
Multi-Token Prediction(MTP)的基本思想。 MTP 让模型在每个位置同时预测未来的多个 token,而非仅仅预测下一个 token。设模型在位置
其中
架构实现。 MTP 通常通过在共享的 Transformer 主干之上添加
- 共享主干 vs 独立头:所有预测头共享同一个 Transformer 主干的隐藏表示,仅在最后的预测层上进行分化。这样 MTP 的额外参数量和计算量很小,主要增量来自
个预测头的前向传播。 - 顺序依赖 vs 独立预测:DeepSeek-V3 采用的方案是让预测头之间存在顺序依赖——第
个预测头的输入不仅包含主干的隐藏状态,还包含第 个预测头的输出嵌入。这种设计使得后续 token 的预测可以利用前序预测的信息,提升多步预测的一致性。
训练阶段的收益。 MTP 的训练收益来自两个方面:
- 更丰富的监督信号:每个训练样本提供
倍的预测目标,等效于在相同数据量下获得更多的梯度信息。实验表明,MTP 可以在相同的训练 token 数下达到更低的验证损失,或在更少的 token 上达到相同的损失水平,即提升了样本效率。 - 推理加速的前置条件:MTP 训练出的模型天然具备多步预测能力,可以直接用于推理阶段的投机解码(Speculative Decoding)——模型自身的预测头充当草稿模型,无需额外训练一个独立的小模型。
推理阶段的投机解码。 在推理时,MTP 的多个预测头并行生成
训练中的并行化。 在训练阶段,
本节小结
本节讨论的四种训练加速技巧作用于不同的性能维度,彼此正交且可叠加使用:
| 技巧 | 核心机制 | 主要收益 | 代价 |
|---|---|---|---|
| 梯度检查点 | 以算换存:丢弃中间激活值,反向传播时重计算 | 激活显存从 | 约 33% 额外计算量 |
| FP8 混合精度 | E4M3/E5M2 双格式 + 逐张量缩放 | 矩阵乘法吞吐量 | 需要 Hopper+ 硬件;精度敏感操作仍需高精度 |
| 通算融合 | 合并算子消除中间内存物化;计算与通信重叠 | 减少 HBM 访问,隐藏通信延迟 | 实现复杂度高,部分场景依赖手写核函数 |
| MTP 并行 | 多预测头同时预测多个未来 token | 提升样本效率,支持无损投机解码加速推理 | 额外预测头的参数与计算开销 |
这些技巧与前面章节讨论的分布式并行策略构成互补关系:分布式并行解决的是"如何将工作分配到多个设备上",而本节的加速技巧解决的是"如何让每个设备上的工作跑得更快、用得更省"。在实际的大模型训练系统中,工程师通常会将这些技巧全部启用——在 3D/4D 并行的基础上叠加梯度检查点、FP8 精度、融合核函数和 MTP 训练目标——以在给定的硬件预算下最大化训练效率。