Skip to content

10.8 训练加速技巧

前面章节讨论了数据并行、张量并行、流水线并行等分布式并行策略,它们通过将计算分散到多个设备上来突破单卡的显存和算力限制。然而,即使在单卡层面,仍然存在大量"隐藏的性能税":冗余的激活值显存占用、低精度硬件能力的闲置、算子间频繁的内存读写、以及自回归解码中逐 token 预测的串行瓶颈。本节介绍四种正交于分布式并行的训练加速技巧——梯度检查点、FP8 混合精度训练、通算融合和多 Token 预测并行——它们分别从显存、精度、访存和解码四个维度挖掘加速空间。


10.8.1 梯度检查点(激活值重计算)

问题背景。 标准的反向传播算法要求在前向传播时保存所有中间层的激活值,以便反向传播计算梯度时使用。对于一个 L 层的网络,若每层激活值的显存开销为 a,则总激活显存为 O(La)。当模型达到数十亿参数、序列长度增至数万 token 时,激活值的显存占用往往超过模型参数本身,成为限制训练批大小和模型规模的主要瓶颈。

核心思想。 梯度检查点(Gradient Checkpointing),也称激活检查点(Activation Checkpointing)或激活值重计算(Activation Recomputation),是一种经典的以算换存技术。其基本策略是:前向传播时不再保存所有中间层的激活值,而是只保存若干精心选择的"检查点"层的激活值,丢弃其余层的激活值;反向传播需要某个被丢弃的激活值时,从最近的检查点重新执行一段前向传播来即时重算。

以算换存的定量分析。 设网络共有 L 层,在其中均匀地设置 k 个检查点(即每隔 L/k 层保存一次)。此时:

  • 显存开销:只需存储 k 个检查点的激活值,加上反向传播时当前段内重计算产生的临时激活值(最多 L/k 个层),总显存为 O(k+L/k)a
  • 最优检查点数:对 k+L/k 求最小值,令导数为零得 k=L,此时显存开销为 O(2L)a=O(La)
  • 计算代价:每个检查点段内的前向传播被额外执行一次。原始前向传播的总计算量为 F,反向传播约为 2F(每层需要计算关于输入和参数的两组梯度),加上重计算的额外前向传播 F,总计算量从 3F 增加到约 4F,即增加约 33% 的计算量

这意味着:一个 L=100 层的网络,从存储全部 100 层激活值降低到只存储约 2100=20 层的激活值,显存减少约 5×,仅以 33% 的额外计算为代价。在 GPU 算力过剩但显存紧张的现实场景下,这是一笔极为划算的交易。

实践要点。 现代训练框架(如 PyTorch 的 torch.utils.checkpoint)已将梯度检查点封装为易用的 API。实际使用时,检查点的粒度通常以 Transformer 层为单位——每隔若干层设置一个检查点。更精细的策略(如选择性重计算)会根据每层激活值的大小和重计算成本做差异化决策:对于占用显存大但重计算代价低的操作(如 LayerNorm、Dropout)优先丢弃,对于计算昂贵的操作(如矩阵乘法的输出)优先保留。

梯度检查点不仅能节省显存,在某些场景下甚至能提升端到端速度。当某个操作是内存密集型(memory-bound) 时——GPU 的计算单元大量时间在等待数据从 HBM 中读取——重新计算该操作可能比从 HBM 读取已保存的激活值更快。FlashAttention 的反向传播就利用了这一原理:它不存储 O(N2) 的注意力矩阵,而是在反向传播时从 Q、K、V 和少量统计量即时重计算,既省下了巨量显存,又因减少了 HBM 访问而加快了速度。


10.8.2 FP8 混合精度训练

从 BF16 到 FP8 的精度演进。 混合精度训练的核心思想是用低精度格式执行大部分计算和存储,用高精度格式维护关键状态(如主权重和梯度累加器)。BF16 混合精度已经成为大模型训练的事实标准:它保留了与 FP32 相同的 8 位指数,拥有相同的动态范围,有效避免了 FP16 容易发生的上溢和下溢问题,同时将内存占用和带宽需求减半。

FP8 将精度进一步压缩到 8 位,理论上可以在 BF16 的基础上再将内存和带宽需求减半,并且在 NVIDIA Hopper(H100)及后续架构上,Tensor Core 的 FP8 吞吐量是 BF16 的 2×。然而,8 位浮点数的表示空间极为有限,如何在如此有限的位宽中平衡动态范围与精度,是 FP8 训练面临的核心挑战。

E4M3 与 E5M2 双格式策略。 FP8 标准定义了两种互补的编码格式:

格式符号位指数位尾数位动态范围精度适用场景
E4M3143较小(±448)较高前向传播的权重与激活值
E5M2152较大(±57344)较低反向传播的梯度

这一双格式策略的设计逻辑是:

  • 前向传播使用 E4M3。前向传播中的权重和激活值经过训练后数值分布相对集中,对精度的需求高于对动态范围的需求。E4M3 的 3 位尾数提供了 FP8 中最高的精度,足以维持前向计算的准确性。
  • 反向传播使用 E5M2。梯度的数值分布范围极大——从接近零的小梯度到偶尔出现的梯度尖峰——对动态范围的需求远高于对精度的需求。E5M2 的 5 位指数提供了与 FP16 相同的动态范围,可以有效避免梯度的上溢和下溢。

在这两种格式之下,Tensor Core 内部的矩阵乘法累加器仍然使用 FP32,确保累加过程中不发生精度损失。整体计算链路为:FP8 输入 → FP8 乘法 → FP32 累加 → FP8/BF16 输出。

逐张量缩放(Per-Tensor Scaling)。 由于 FP8 的表示范围极其有限,直接将 BF16 数值截断到 FP8 会导致大量信息丢失。因此,FP8 训练必须配合缩放因子(Scaling Factor) 使用:在将张量转换为 FP8 之前,先除以一个缩放因子将数值映射到 FP8 的有效表示范围内;在使用 FP8 结果时再乘回缩放因子恢复原始量级。

缩放因子的选择通常基于延迟缩放(Delayed Scaling) 策略:使用前一个训练步中该张量的最大绝对值来计算当前步的缩放因子。这避免了在当前步额外遍历整个张量来统计最大值的开销,代价是缩放因子滞后一步,但实践中对训练稳定性的影响可以忽略。

实际收益与局限。 在 H100 及更新的架构上,FP8 训练可以将矩阵乘法的吞吐量提升至 BF16 的近 2×,同时将激活值的显存占用进一步压缩。DeepSeek-V3 的训练实践表明,FP8 混合精度训练在千亿级模型上可以达到与 BF16 几乎相同的模型质量,同时显著降低训练成本。需要注意的是,FP8 的优势主要体现在矩阵乘法密集的计算中(如线性层);对于 LayerNorm、Softmax 等对精度敏感的操作,仍需保持 BF16 或 FP32 精度。


10.8.3 通算融合

算子融合的动机。 在 PyTorch 等框架的默认执行模式下,每个数学操作(加、乘、激活函数、归一化等)都会启动一个独立的 CUDA 核函数。每个核函数都需要从 HBM 读取输入、将结果写回 HBM,供下一个核函数读取。对于逐元素操作(如 ReLU、残差加法)和归约操作(如 LayerNorm),计算量极小但内存读写量大——它们是典型的内存密集型操作。此时,GPU 的计算单元大量空闲,性能瓶颈完全在内存带宽上。

算子融合(Operator Fusion/Kernel Fusion) 通过将多个连续操作合并为一个 CUDA 核函数来消除中间内存读写。融合后的核函数只需一次从 HBM 读取输入,所有中间结果保留在 SM 的寄存器或共享内存(SRAM)中完成计算,最终结果一次写回 HBM。这样做有两个直接收益:(1)大幅减少 HBM 访问次数,提升有效内存带宽利用率;(2)消除核函数启动开销——每次启动 CUDA 核函数都有固定的驱动程序开销,融合后启动次数减少。

FlashAttention:通算融合的标杆案例。 标准注意力机制的计算链路为 S=QK/dkP=softmax(S)O=PV,朴素实现需要物化 N×N 的注意力矩阵 SP,对 HBM 进行多次读写。FlashAttention 将整个注意力计算融合为一个核函数,配合分块(Tiling)技术将 Q、K、V 按块加载到 SRAM 中,利用在线 Softmax 算法逐块计算并累积结果,全程避免物化完整的 N×N 矩阵。其效果是:内存占用从 O(N2) 降至 O(N),HBM 访问量大幅减少,注意力计算从内存密集型转变为计算密集型,端到端速度提升数倍。

通算融合的新方向:DeepSpeed-Domini 与 FLUX。 FlashAttention 证明了手写融合核函数的巨大潜力,但也暴露了传统融合方法的局限——它依赖专家手动编写高度优化的 CUDA 核函数,开发周期长、维护成本高。新一代通算融合框架追求更通用、更自动化的融合策略:

  • 计算与通信融合:在分布式训练中,All-Reduce、All-Gather 等集合通信操作通常独立于计算执行,GPU 要么在计算、要么在等待通信完成。FLUX 等框架将通信操作与矩阵乘法在 CUDA 核函数级别进行融合——在 Tensor Core 执行矩阵乘法的同时,利用 GPU 的拷贝引擎(Copy Engine)和网络硬件并行执行数据传输。这种细粒度的计算-通信重叠可以将通信延迟几乎完全隐藏在计算之下。
  • 编译器驱动的自动融合:PyTorch 2.0 的 torch.compile 通过 Triton 后端自动识别和融合可合并的操作序列,降低了手写核函数的门槛。对于逐元素操作链、矩阵乘法后接激活函数等常见模式,编译器生成的融合核函数已接近手工优化的性能。

通算融合的本质是消除计算图中不必要的内存物化——无论是中间激活值的物化,还是通信等待期间的计算空闲。随着模型规模和集群规模的增大,通算融合的收益也越发显著。


10.8.4 多 Token 预测并行(MTP)

自回归解码的串行瓶颈。 标准的自回归语言模型每步只预测下一个 token:给定前缀 x1,,xt,模型输出 xt+1 的概率分布。推理时,每个 token 必须等前一个 token 生成后才能开始预测,形成严格的串行依赖链。这使得推理延迟与生成长度成正比,成为大模型部署的核心瓶颈。

Multi-Token Prediction(MTP)的基本思想。 MTP 让模型在每个位置同时预测未来的多个 token,而非仅仅预测下一个 token。设模型在位置 t 同时预测 xt+1,xt+2,,xt+k,则训练目标从标准的 next-token loss 扩展为:

LMTP=i=1kLCE(xt+ix1,,xt)

其中 k 为预测窗口大小,LCE 为交叉熵损失。

架构实现。 MTP 通常通过在共享的 Transformer 主干之上添加 k 个独立的预测头(Prediction Head)来实现。每个预测头是一个轻量级模块(通常包含一个 Transformer 层和一个输出投影层),负责预测对应位置偏移的 token。关键设计决策包括:

  • 共享主干 vs 独立头:所有预测头共享同一个 Transformer 主干的隐藏表示,仅在最后的预测层上进行分化。这样 MTP 的额外参数量和计算量很小,主要增量来自 k 个预测头的前向传播。
  • 顺序依赖 vs 独立预测:DeepSeek-V3 采用的方案是让预测头之间存在顺序依赖——第 i 个预测头的输入不仅包含主干的隐藏状态,还包含第 i1 个预测头的输出嵌入。这种设计使得后续 token 的预测可以利用前序预测的信息,提升多步预测的一致性。

训练阶段的收益。 MTP 的训练收益来自两个方面:

  1. 更丰富的监督信号:每个训练样本提供 k 倍的预测目标,等效于在相同数据量下获得更多的梯度信息。实验表明,MTP 可以在相同的训练 token 数下达到更低的验证损失,或在更少的 token 上达到相同的损失水平,即提升了样本效率
  2. 推理加速的前置条件:MTP 训练出的模型天然具备多步预测能力,可以直接用于推理阶段的投机解码(Speculative Decoding)——模型自身的预测头充当草稿模型,无需额外训练一个独立的小模型。

推理阶段的投机解码。 在推理时,MTP 的多个预测头并行生成 k 个候选 token,然后通过主模型的一次前向传播验证这些候选 token 的正确性。如果前 j 个候选 token 全部通过验证,则一步解码 j 个 token,解码速度提升约 j×。即使部分候选 token 被拒绝,回退机制也保证生成质量与标准自回归解码完全一致——MTP 投机解码是无损的加速方法。

训练中的并行化。 在训练阶段,k 个预测头的前向和反向传播可以并行执行(它们共享主干的隐藏表示,且各自的损失独立计算),这使得 MTP 的训练开销远低于 k 倍。在 DeepSeek-V3 的实现中,MTP 仅增加了约 4.3% 的训练计算量(使用 k=1 个额外预测头),但在推理阶段通过投机解码实现了 1.8 倍的解码速度提升。


本节小结

本节讨论的四种训练加速技巧作用于不同的性能维度,彼此正交且可叠加使用:

技巧核心机制主要收益代价
梯度检查点以算换存:丢弃中间激活值,反向传播时重计算激活显存从 O(L) 降至 O(L)约 33% 额外计算量
FP8 混合精度E4M3/E5M2 双格式 + 逐张量缩放矩阵乘法吞吐量 2×,显存进一步压缩需要 Hopper+ 硬件;精度敏感操作仍需高精度
通算融合合并算子消除中间内存物化;计算与通信重叠减少 HBM 访问,隐藏通信延迟实现复杂度高,部分场景依赖手写核函数
MTP 并行多预测头同时预测多个未来 token提升样本效率,支持无损投机解码加速推理额外预测头的参数与计算开销

这些技巧与前面章节讨论的分布式并行策略构成互补关系:分布式并行解决的是"如何将工作分配到多个设备上",而本节的加速技巧解决的是"如何让每个设备上的工作跑得更快、用得更省"。在实际的大模型训练系统中,工程师通常会将这些技巧全部启用——在 3D/4D 并行的基础上叠加梯度检查点、FP8 精度、融合核函数和 MTP 训练目标——以在给定的硬件预算下最大化训练效率。