10.6 ZeRO 与 FSDP
前面章节讨论的张量并行(TP)、流水线并行(PP)和序列并行(SP)都属于模型并行范畴——它们通过切分模型的权重矩阵、层或序列维度,让每个 GPU 只承担模型计算的一部分。而在最基本的**数据并行(DDP)**中,每个 GPU 持有完整的模型副本,仅切分训练数据。DDP 实现简单、扩展性好,但存在一个根本性的问题:显存冗余。
以使用 Adam 优化器的混合精度训练为例,假设模型有
| 组成部分 | 精度 | 每参数字节数 |
|---|---|---|
| 模型参数(FP16) | 半精度 | 2 Bytes |
| 梯度(FP16) | 半精度 | 2 Bytes |
| Adam 一阶动量(FP32) | 全精度 | 4 Bytes |
| Adam 二阶动量(FP32) | 全精度 | 4 Bytes |
| FP32 主权重(Master Weights) | 全精度 | 4 Bytes |
每个参数的总显存开销为
更关键的是,在
10.6.1 ZeRO 的核心思想
ZeRO(Zero Redundancy Optimizer)由微软 DeepSpeed 团队提出,其核心思路非常直接:既然模型状态在每张 GPU 上都是重复的,那就把它们切分(Shard)到不同 GPU 上,每个 GPU 只存储
ZeRO 将模型状态分为三类——优化器状态、梯度、模型参数——并据此设计了三个递进的分片阶段(Stage),每个阶段在前一阶段的基础上分片更多的状态,换取更大的显存节省。
理解 ZeRO 需要两个关键的集合通信原语:
- Reduce-Scatter:对所有 GPU 上的数据先执行规约(如求和),再将结果切片分发,使得每张 GPU 只保留结果的
分片。 - All-Gather:将各 GPU 上的
分片收集拼接,使每张 GPU 都获得完整数据。
它们的组合关系是:
10.6.2 ZeRO Stage 1:优化器状态分片
在标准 DDP 中,每张 GPU 经过 All-Reduce 获得完整的全局平均梯度后,用本地的完整优化器状态更新完整的模型参数。ZeRO Stage 1 的改变是:每张 GPU 只维护
具体流程如下:
Step 1: Forward & Backward(各 GPU 独立计算,与 DDP 完全一致)
↓
Step 2: Reduce-Scatter 梯度
每张 GPU 得到自己负责的 1/N 梯度分片(已全局规约)
↓
Step 3: 局部优化器更新
每张 GPU 用本地 1/N 优化器状态 + 1/N 梯度分片
→ 更新本地 1/N 参数分片
↓
Step 4: All-Gather 参数
将各 GPU 更新后的参数分片拼装为完整参数
所有 GPU 重新持有一致的完整模型参数显存分析。 优化器状态从每张 GPU 存储
当
通信分析。 DDP 中的 All-Reduce 通信量为
10.6.3 ZeRO Stage 2:梯度分片
Stage 2 在 Stage 1 的基础上进一步分片梯度。既然每张 GPU 只负责更新
与 Stage 1 的区别体现在反向传播阶段的实现上:
反向传播中,每一层梯度计算完成后:
1. 立即对该层梯度执行 Reduce-Scatter
2. 每张 GPU 只保留自己负责的 1/N 梯度分片
3. 释放其余 (N-1)/N 的梯度内存
→ 整个反向传播结束后,每张 GPU 上只存在 1/N 的梯度
→ 随后的优化器更新和 All-Gather 与 Stage 1 一致显存分析。 梯度从
当
通信分析。 Stage 2 的梯度 Reduce-Scatter 与 Stage 1 完全相同(都是在反向传播后对梯度执行 Reduce-Scatter),All-Gather 也相同。因此 Stage 2 的通信量仍为
10.6.4 ZeRO Stage 3:参数分片
Stage 3 是最彻底的方案:连模型参数本身也进行分片,每张 GPU 在任何时刻默认只持有
前向传播第 l 层:
1. All-Gather:从所有 GPU 收集第 l 层完整参数
2. 执行第 l 层前向计算
3. 立即释放非本地的 (N-1)/N 参数
4. 进入第 l+1 层...
反向传播第 l 层:
1. All-Gather:再次收集第 l 层完整参数
2. 计算第 l 层梯度
3. 释放完整参数
4. Reduce-Scatter 梯度:每张 GPU 仅保留 1/N 梯度分片
5. 进入第 l-1 层...
参数更新:
每张 GPU 用本地 1/N 优化器状态 + 1/N 梯度 → 更新 1/N 参数
(无需 All-Gather,因为下一轮迭代的前向传播会按需收集)可以将 Stage 3 理解为一个滑动窗口:完整参数在 GPU 间像一个窗口一样逐层滑过,到哪一层就临时拼装哪一层的完整参数,计算完毕立即回收。这种"用时拼装、用后即弃"的策略,使得 GPU 的显存峰值被压缩到极低。
显存分析。 模型参数、梯度、优化器状态全部分片,总显存占用:
显存节省与 GPU 数量
通信分析。 与 Stage 1/2 相比,Stage 3 在前向和反向传播中各增加了一次 All-Gather(用于临时收集完整参数)。前向传播增加
10.6.5 三阶段对比
下表汇总了 ZeRO 三个阶段的关键指标(
| 分片内容 | 每 GPU 显存 | 通信量 | 显存节省倍数( | |
|---|---|---|---|---|
| DDP(基线) | 无 | 1x | ||
| ZeRO Stage 1 | 优化器状态 | ~4x | ||
| ZeRO Stage 2 | 优化器状态 + 梯度 | ~8x | ||
| ZeRO Stage 3 | 全部模型状态 | ~ |
表 10-2:ZeRO 三阶段与 DDP 的显存和通信开销对比。
可以清晰地看到 ZeRO 的递进逻辑:
- Stage 1 和 Stage 2 在通信量与 DDP 完全相同的前提下,分别实现了约 4 倍和 8 倍的显存节省。这是"免费午餐"级别的优化。
- Stage 3 以 1.5 倍通信量为代价,实现了与 GPU 数量成正比的线性显存缩减。当 GPU 数量较多时,这一代价被大规模的显存释放所抵消。
实践选择建议。
- Stage 1:适合显存略有不足的场景。由于对通信和计算流程的改动最小,几乎可以作为 DDP 的"免费升级"。
- Stage 2:大部分场景的推荐选项。在显存和速度之间取得了良好平衡,是 DeepSpeed 的默认推荐配置。
- Stage 3:训练超大模型(如千亿参数)的必选项。虽然通信开销增加,但它是在有限 GPU 资源下训练超大模型的唯一选择。
10.6.6 ZeRO-Offload 与 ZeRO-Infinity
当 GPU 显存仍然不够时,ZeRO 还提供了两种**卸载(Offload)**扩展:
ZeRO-Offload 将优化器状态和梯度计算卸载到 CPU 内存。在 GPU 上只执行前向和反向传播,梯度通过 PCIe 传回 CPU,由 CPU 完成参数更新后再将更新的参数传回 GPU。这使得在单张 GPU 上也能训练远超其显存容量的模型,代价是 CPU-GPU 之间的数据传输引入了额外延迟。
ZeRO-Infinity 进一步将卸载范围扩展到 NVMe SSD。利用高速固态硬盘作为"超大显存池",配合精心设计的预取机制(在 GPU 计算当前层时,CPU 从 SSD 预读下一层的数据),实现在极少量 GPU 上训练万亿级参数模型。
DeepSpeed 通过配置文件 ds_config.json 控制 ZeRO 的行为:
{
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto"
}
}stage 字段设为 0/1/2/3 分别对应关闭 ZeRO 和三个阶段;offload_optimizer 和 offload_param 控制是否将优化器状态或模型参数卸载到 CPU;overlap_comm 启用通信与计算的重叠以隐藏通信延迟。
10.6.7 FSDP:PyTorch 原生的全分片数据并行
FSDP(Fully Sharded Data Parallel) 是 PyTorch 官方推出的大模型训练显存优化方案。其核心思想与 ZeRO Stage 3 高度一致——对模型参数、梯度和优化器状态进行完全分片——但作为 PyTorch 的原生模块,它在 API 设计、生态兼容性和工程实现上具有独特优势。
FSDP 的工作流程 与 ZeRO Stage 3 在逻辑上完全对应:
初始化:模型参数在 N 张 GPU 间完全分片,每张 GPU 仅持有 1/N
┌─ 前向传播 ──────────────────────────────────────────────┐
│ 对每个 FSDP 单元(Unit): │
│ ① All-Gather:收集完整参数 │
│ ② 执行前向计算 │
│ ③ 释放非本地参数分片(仅保留本地 1/N) │
└──────────────────────────────────────────────────────────┘
┌─ 反向传播 ──────────────────────────────────────────────┐
│ 对每个 FSDP 单元(逆序): │
│ ① All-Gather:再次收集完整参数 │
│ ② 执行反向计算,得到完整梯度 │
│ ③ 释放完整参数 │
│ ④ Reduce-Scatter 梯度:每张 GPU 仅保留 1/N 梯度分片 │
└──────────────────────────────────────────────────────────┘
┌─ 参数更新 ──────────────────────────────────────────────┐
│ 每张 GPU 独立使用本地的 1/N 优化器状态 + 1/N 梯度 │
│ → 更新本地 1/N 参数 │
└──────────────────────────────────────────────────────────┘虽然原理与 ZeRO Stage 3 相同,FSDP 在工程层面有几个重要的差异化设计:
FSDP 单元(FSDP Unit)与 Auto-Wrapping。 FSDP 允许用户通过"包装策略(Wrapping Policy)"控制分片的粒度。模型被划分为若干个 FSDP 单元,每个单元内部的参数作为一个整体进行分片和 All-Gather。用户可以按层类型(如每个 Transformer Block 为一个单元)或参数量阈值(如参数量 > 100M 的模块独立成为一个单元)来自动划分。
这种分层包装的设计直接影响显存峰值:
- 粒度太粗(整个模型为一个 FSDP 单元):All-Gather 一次性收集全部参数,峰值显存等同于完整模型,失去了分片的意义。
- 粒度太细(每个线性层都是独立单元):All-Gather 的调用过于频繁,通信开销剧增。
- 合理粒度(每个 Transformer Block 为一个单元):每次 All-Gather 只收集一个 Block 的参数,显存峰值仅为"一个 Block 的完整参数 + 其余所有 Block 的
分片",在显存节省和通信效率之间取得平衡。
Sharding Strategy。 PyTorch FSDP 提供了多种分片策略,对应 ZeRO 的不同阶段:
| FSDP ShardingStrategy | 对应 ZeRO 阶段 | 说明 |
|---|---|---|
FULL_SHARD | Stage 3 | 参数、梯度、优化器状态全部分片 |
SHARD_GRAD_OP | Stage 2 | 梯度和优化器状态分片,参数前向后不释放 |
NO_SHARD | DDP | 不分片,退化为标准 DDP |
FSDP 代码示例。 以下展示使用 PyTorch FSDP 训练的基本代码框架:
# 教学示例:展示核心逻辑,省略了部分 import 和辅助函数定义
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
MixedPrecision,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from functools import partial
# 初始化分布式环境
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
# 构建模型
model = MyTransformerModel()
# 定义自动包装策略:每个 TransformerBlock 作为一个 FSDP 单元
auto_wrap_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={TransformerBlock}, # 指定要包装的层类型
)
# 混合精度配置
mixed_precision = MixedPrecision(
param_dtype=torch.float16, # 前向/反向使用 FP16
reduce_dtype=torch.float16, # 梯度通信使用 FP16
buffer_dtype=torch.float16,
)
# 用 FSDP 包装模型
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD, # 对应 ZeRO Stage 3
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mixed_precision,
device_id=local_rank,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# 训练循环
for batch in dataloader:
inputs, labels = batch
inputs = inputs.cuda(local_rank)
labels = labels.cuda(local_rank)
outputs = model(inputs) # FSDP 自动处理 All-Gather / 释放
loss = loss_fn(outputs, labels)
loss.backward() # FSDP 自动处理 All-Gather / Reduce-Scatter
optimizer.step() # 每张 GPU 更新本地 1/N 参数
optimizer.zero_grad()从代码可以看到,FSDP 的用户接口与标准的 DDP 训练几乎完全一致。核心差异仅在于将 DistributedDataParallel 替换为 FullyShardedDataParallel,并指定分片策略和包装策略。所有 All-Gather、Reduce-Scatter 和参数释放操作都被封装在 FSDP 的前向/反向钩子中,对用户透明。
10.6.8 ZeRO 与 FSDP 的对比
ZeRO(DeepSpeed)和 FSDP(PyTorch)在核心原理上是等价的,但在工程定位和适用场景上有所区分:
| 维度 | DeepSpeed ZeRO | PyTorch FSDP |
|---|---|---|
| 实现方式 | 独立库,通过训练引擎包装 | PyTorch 原生模块 |
| 分片粒度 | Stage 1/2/3 三级可选 | ShardingStrategy 枚举,对应 Stage 2/3/DDP |
| Offload 能力 | 成熟的 CPU/NVMe Offload | 支持 CPU Offload(FSDP2 改进中) |
| 3D 并行集成 | 原生支持 3D 并行(TP + PP + ZeRO) | 需与 DeviceMesh/DTensor 配合 |
| 生态兼容性 | 需适配 DeepSpeed 引擎 API | 与 PyTorch AMP、编译器、Profiler 无缝集成 |
| 配置方式 | JSON 配置文件 | Python API |
| 社区与维护 | 微软维护 | PyTorch 核心团队维护 |
选择建议。
- 如果已使用 DeepSpeed 生态(如 Megatron-DeepSpeed)进行 3D 并行训练,ZeRO 是自然选择,其 Offload 能力和配置灵活性更成熟。
- 如果希望保持纯 PyTorch 技术栈、减少外部依赖,或需要与 PyTorch 的编译器(
torch.compile)、分布式张量(DTensor)等新特性深度配合,FSDP 是更优方案。 - 在大多数场景下,两者的训练效率和显存节省效果非常接近。选择更多取决于工程栈的偏好而非性能差异。
10.6.9 ZeRO/FSDP 在混合并行中的角色
在第 10.5 节末尾提到的 4D 并行架构中,ZeRO/FSDP 扮演着数据并行维度的显存优化器角色。典型的混合并行配置如下:
┌───────────────────────────────────────┐
│ 4D 并行配置 │
│ │
节点内 │ TP(张量并行):节点内 NVLink 高带宽 │
─────── │ → 切分单层权重矩阵 │
│ │
跨节点 │ PP(流水线并行):跨节点 InfiniBand │
─────── │ → 切分模型层 │
│ │
DP 组内 │ ZeRO / FSDP:在 DP 副本之间 │
─────── │ → 分片优化器状态 / 梯度 / 参数 │
│ │
MoE 模型 │ EP(专家并行):跨 GPU 分布专家 │
─────── │ → All-to-All 路由 Token │
└───────────────────────────────────────┘在这种配置中,同一数据并行组(DP Group)内的 GPU 处理不同的数据分片、持有相同的模型参数副本。传统 DDP 要求每张 GPU 完整存储模型状态,而 ZeRO/FSDP 在 DP 组内部进行分片,使得即使在 TP + PP 已经切分模型的基础上,DP 维度也不再有显存冗余。
实际应用中,4D 并行通常使用 ZeRO Stage 1 或 Stage 2(而非 Stage 3),因为 TP 和 PP 已经分担了大部分模型切分工作,DP 维度只需分片优化器状态和梯度即可。Stage 3 的额外 All-Gather 通信在与 TP/PP 叠加时会产生较大的通信压力。
本节小结
ZeRO 和 FSDP 从显存状态切分的角度解决了数据并行中的冗余问题,是当前大模型训练基础设施的核心组件:
- ZeRO Stage 1 分片优化器状态,在零额外通信的前提下节省约 4 倍显存。
- ZeRO Stage 2 进一步分片梯度,通信量不变,显存节省约 8 倍。
- ZeRO Stage 3 / FSDP
FULL_SHARD对全部模型状态进行完全分片,显存节省与 GPU 数量成正比,代价是 1.5 倍于 DDP 的通信量。 - ZeRO-Offload / ZeRO-Infinity 通过 CPU/NVMe 卸载进一步突破 GPU 显存上限。
与模型并行(TP、PP)不同,ZeRO/FSDP 保留了数据并行易于使用的核心优势——用户无需手动切分模型结构,只需指定分片策略,框架即可自动处理所有通信和显存管理。正是这种"对用户透明的显存优化"特性,使得 ZeRO/FSDP 成为了从学术实验到工业训练的通用标配技术。