Skip to content

10.6 ZeRO 与 FSDP

前面章节讨论的张量并行(TP)、流水线并行(PP)和序列并行(SP)都属于模型并行范畴——它们通过切分模型的权重矩阵、层或序列维度,让每个 GPU 只承担模型计算的一部分。而在最基本的**数据并行(DDP)**中,每个 GPU 持有完整的模型副本,仅切分训练数据。DDP 实现简单、扩展性好,但存在一个根本性的问题:显存冗余

以使用 Adam 优化器的混合精度训练为例,假设模型有 Ψ 个参数,每个 GPU 需要存储的模型状态(Model States)包括:

组成部分精度每参数字节数
模型参数(FP16)半精度2 Bytes
梯度(FP16)半精度2 Bytes
Adam 一阶动量(FP32)全精度4 Bytes
Adam 二阶动量(FP32)全精度4 Bytes
FP32 主权重(Master Weights)全精度4 Bytes

每个参数的总显存开销为 2+2+4+4+4=16 Bytes。其中模型参数和梯度合计 4Ψ Bytes,优化器状态合计 12Ψ Bytes。对于一个 7B 参数的模型,仅模型状态就至少占用 7×109×16=112 GB 显存——远超单张 80 GB 的 A100/H100 显卡容量。

更关键的是,在 N 卡的 DDP 训练中,这 16Ψ Bytes 在每张 GPU 上完全重复存储。N 越大,冗余越严重。ZeRO 和 FSDP 正是针对这一冗余问题提出的系统性解决方案。


10.6.1 ZeRO 的核心思想

ZeRO(Zero Redundancy Optimizer)由微软 DeepSpeed 团队提出,其核心思路非常直接:既然模型状态在每张 GPU 上都是重复的,那就把它们切分(Shard)到不同 GPU 上,每个 GPU 只存储 1/N 的份额。需要完整数据时,通过集合通信临时拼装;用完后立即释放。

ZeRO 将模型状态分为三类——优化器状态、梯度、模型参数——并据此设计了三个递进的分片阶段(Stage),每个阶段在前一阶段的基础上分片更多的状态,换取更大的显存节省。

理解 ZeRO 需要两个关键的集合通信原语:

  • Reduce-Scatter:对所有 GPU 上的数据先执行规约(如求和),再将结果切片分发,使得每张 GPU 只保留结果的 1/N 分片。
  • All-Gather:将各 GPU 上的 1/N 分片收集拼接,使每张 GPU 都获得完整数据。

它们的组合关系是:All-ReduceReduce-Scatter+All-Gather。ZeRO 正是利用这一等价关系,将原本 DDP 中不可分割的 All-Reduce 拆解为两步,在中间插入分片存储与局部更新,从而在几乎不增加通信量的前提下大幅降低显存占用。


10.6.2 ZeRO Stage 1:优化器状态分片

在标准 DDP 中,每张 GPU 经过 All-Reduce 获得完整的全局平均梯度后,用本地的完整优化器状态更新完整的模型参数。ZeRO Stage 1 的改变是:每张 GPU 只维护 1/N 的优化器状态,只负责更新对应的 1/N 参数分片

具体流程如下:

Step 1: Forward & Backward(各 GPU 独立计算,与 DDP 完全一致)

Step 2: Reduce-Scatter 梯度
        每张 GPU 得到自己负责的 1/N 梯度分片(已全局规约)

Step 3: 局部优化器更新
        每张 GPU 用本地 1/N 优化器状态 + 1/N 梯度分片
        → 更新本地 1/N 参数分片

Step 4: All-Gather 参数
        将各 GPU 更新后的参数分片拼装为完整参数
        所有 GPU 重新持有一致的完整模型参数

显存分析。 优化器状态从每张 GPU 存储 12Ψ 降为 12Ψ/N。模型参数(2Ψ)和梯度(2Ψ)仍在每张 GPU 上完整存储。总显存占用:

MStage1=4Ψ+12ΨN

N 足够大时,优化器状态的开销趋近于零,总显存约为 4Ψ(仅参数 + 梯度),相比 DDP 的 16Ψ 节省约 4 倍

通信分析。 DDP 中的 All-Reduce 通信量为 2Ψ(Reduce-Scatter Ψ + All-Gather Ψ)。ZeRO Stage 1 的通信也是 Reduce-Scatter(梯度)+ All-Gather(参数),总量同样为 2Ψ。因此 Stage 1 在通信量上与 DDP 完全相同,这也是 ZeRO 论文称之为"零开销"的原因。


10.6.3 ZeRO Stage 2:梯度分片

Stage 2 在 Stage 1 的基础上进一步分片梯度。既然每张 GPU 只负责更新 1/N 的参数,那么它也只需要保留对应那部分参数的梯度,其余梯度在 Reduce-Scatter 完成后可以立即释放。

与 Stage 1 的区别体现在反向传播阶段的实现上:

反向传播中,每一层梯度计算完成后:
  1. 立即对该层梯度执行 Reduce-Scatter
  2. 每张 GPU 只保留自己负责的 1/N 梯度分片
  3. 释放其余 (N-1)/N 的梯度内存

→ 整个反向传播结束后,每张 GPU 上只存在 1/N 的梯度
→ 随后的优化器更新和 All-Gather 与 Stage 1 一致

显存分析。 梯度从 2Ψ 降为 2Ψ/N,总显存占用:

MStage2=2Ψ+2Ψ+12ΨN=2Ψ+14ΨN

N 足够大时,总显存约为 2Ψ(仅模型参数),相比 DDP 节省约 8 倍

通信分析。 Stage 2 的梯度 Reduce-Scatter 与 Stage 1 完全相同(都是在反向传播后对梯度执行 Reduce-Scatter),All-Gather 也相同。因此 Stage 2 的通信量仍为 2Ψ,与 DDP 一致。


10.6.4 ZeRO Stage 3:参数分片

Stage 3 是最彻底的方案:连模型参数本身也进行分片,每张 GPU 在任何时刻默认只持有 1/N 的参数。需要某一层参数时动态通过 All-Gather 拼装,用完后立即丢弃。

前向传播第 l 层:
  1. All-Gather:从所有 GPU 收集第 l 层完整参数
  2. 执行第 l 层前向计算
  3. 立即释放非本地的 (N-1)/N 参数
  4. 进入第 l+1 层...

反向传播第 l 层:
  1. All-Gather:再次收集第 l 层完整参数
  2. 计算第 l 层梯度
  3. 释放完整参数
  4. Reduce-Scatter 梯度:每张 GPU 仅保留 1/N 梯度分片
  5. 进入第 l-1 层...

参数更新:
  每张 GPU 用本地 1/N 优化器状态 + 1/N 梯度 → 更新 1/N 参数
  (无需 All-Gather,因为下一轮迭代的前向传播会按需收集)

可以将 Stage 3 理解为一个滑动窗口:完整参数在 GPU 间像一个窗口一样逐层滑过,到哪一层就临时拼装哪一层的完整参数,计算完毕立即回收。这种"用时拼装、用后即弃"的策略,使得 GPU 的显存峰值被压缩到极低。

显存分析。 模型参数、梯度、优化器状态全部分片,总显存占用:

MStage3=2Ψ+2Ψ+12ΨN=16ΨN

显存节省与 GPU 数量 N 成正比。N=64 时,每张 GPU 仅需存储原始显存的 1/64

通信分析。 与 Stage 1/2 相比,Stage 3 在前向和反向传播中各增加了一次 All-Gather(用于临时收集完整参数)。前向传播增加 Ψ,反向传播增加 Ψ,加上原有的梯度 Reduce-Scatter(Ψ),总通信量为 3Ψ——是 DDP 的 1.5 倍。这是 Stage 3 用通信换显存的代价。在实践中,可以通过预取(Prefetch)和计算-通信重叠来隐藏大部分额外延迟。


10.6.5 三阶段对比

下表汇总了 ZeRO 三个阶段的关键指标(N 为数据并行度,Ψ 为模型参数量):

分片内容每 GPU 显存通信量显存节省倍数(N
DDP(基线)16Ψ2Ψ1x
ZeRO Stage 1优化器状态4Ψ+12Ψ/N2Ψ~4x
ZeRO Stage 2优化器状态 + 梯度2Ψ+14Ψ/N2Ψ~8x
ZeRO Stage 3全部模型状态16Ψ/N3Ψ~Nx

表 10-2:ZeRO 三阶段与 DDP 的显存和通信开销对比。

可以清晰地看到 ZeRO 的递进逻辑:

  • Stage 1 和 Stage 2 在通信量与 DDP 完全相同的前提下,分别实现了约 4 倍和 8 倍的显存节省。这是"免费午餐"级别的优化。
  • Stage 3 以 1.5 倍通信量为代价,实现了与 GPU 数量成正比的线性显存缩减。当 GPU 数量较多时,这一代价被大规模的显存释放所抵消。

实践选择建议。

  • Stage 1:适合显存略有不足的场景。由于对通信和计算流程的改动最小,几乎可以作为 DDP 的"免费升级"。
  • Stage 2:大部分场景的推荐选项。在显存和速度之间取得了良好平衡,是 DeepSpeed 的默认推荐配置。
  • Stage 3:训练超大模型(如千亿参数)的必选项。虽然通信开销增加,但它是在有限 GPU 资源下训练超大模型的唯一选择。

10.6.6 ZeRO-Offload 与 ZeRO-Infinity

当 GPU 显存仍然不够时,ZeRO 还提供了两种**卸载(Offload)**扩展:

ZeRO-Offload 将优化器状态和梯度计算卸载到 CPU 内存。在 GPU 上只执行前向和反向传播,梯度通过 PCIe 传回 CPU,由 CPU 完成参数更新后再将更新的参数传回 GPU。这使得在单张 GPU 上也能训练远超其显存容量的模型,代价是 CPU-GPU 之间的数据传输引入了额外延迟。

ZeRO-Infinity 进一步将卸载范围扩展到 NVMe SSD。利用高速固态硬盘作为"超大显存池",配合精心设计的预取机制(在 GPU 计算当前层时,CPU 从 SSD 预读下一层的数据),实现在极少量 GPU 上训练万亿级参数模型。

DeepSpeed 通过配置文件 ds_config.json 控制 ZeRO 的行为:

json
{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto"
  }
}

stage 字段设为 0/1/2/3 分别对应关闭 ZeRO 和三个阶段;offload_optimizeroffload_param 控制是否将优化器状态或模型参数卸载到 CPU;overlap_comm 启用通信与计算的重叠以隐藏通信延迟。


10.6.7 FSDP:PyTorch 原生的全分片数据并行

FSDP(Fully Sharded Data Parallel) 是 PyTorch 官方推出的大模型训练显存优化方案。其核心思想与 ZeRO Stage 3 高度一致——对模型参数、梯度和优化器状态进行完全分片——但作为 PyTorch 的原生模块,它在 API 设计、生态兼容性和工程实现上具有独特优势。

FSDP 的工作流程 与 ZeRO Stage 3 在逻辑上完全对应:

初始化:模型参数在 N 张 GPU 间完全分片,每张 GPU 仅持有 1/N

┌─ 前向传播 ──────────────────────────────────────────────┐
│  对每个 FSDP 单元(Unit):                                │
│    ① All-Gather:收集完整参数                             │
│    ② 执行前向计算                                        │
│    ③ 释放非本地参数分片(仅保留本地 1/N)                    │
└──────────────────────────────────────────────────────────┘

┌─ 反向传播 ──────────────────────────────────────────────┐
│  对每个 FSDP 单元(逆序):                                │
│    ① All-Gather:再次收集完整参数                          │
│    ② 执行反向计算,得到完整梯度                             │
│    ③ 释放完整参数                                        │
│    ④ Reduce-Scatter 梯度:每张 GPU 仅保留 1/N 梯度分片     │
└──────────────────────────────────────────────────────────┘

┌─ 参数更新 ──────────────────────────────────────────────┐
│  每张 GPU 独立使用本地的 1/N 优化器状态 + 1/N 梯度           │
│  → 更新本地 1/N 参数                                     │
└──────────────────────────────────────────────────────────┘

虽然原理与 ZeRO Stage 3 相同,FSDP 在工程层面有几个重要的差异化设计:

FSDP 单元(FSDP Unit)与 Auto-Wrapping。 FSDP 允许用户通过"包装策略(Wrapping Policy)"控制分片的粒度。模型被划分为若干个 FSDP 单元,每个单元内部的参数作为一个整体进行分片和 All-Gather。用户可以按层类型(如每个 Transformer Block 为一个单元)或参数量阈值(如参数量 > 100M 的模块独立成为一个单元)来自动划分。

这种分层包装的设计直接影响显存峰值:

  • 粒度太粗(整个模型为一个 FSDP 单元):All-Gather 一次性收集全部参数,峰值显存等同于完整模型,失去了分片的意义。
  • 粒度太细(每个线性层都是独立单元):All-Gather 的调用过于频繁,通信开销剧增。
  • 合理粒度(每个 Transformer Block 为一个单元):每次 All-Gather 只收集一个 Block 的参数,显存峰值仅为"一个 Block 的完整参数 + 其余所有 Block 的 1/N 分片",在显存节省和通信效率之间取得平衡。

Sharding Strategy。 PyTorch FSDP 提供了多种分片策略,对应 ZeRO 的不同阶段:

FSDP ShardingStrategy对应 ZeRO 阶段说明
FULL_SHARDStage 3参数、梯度、优化器状态全部分片
SHARD_GRAD_OPStage 2梯度和优化器状态分片,参数前向后不释放
NO_SHARDDDP不分片,退化为标准 DDP

FSDP 代码示例。 以下展示使用 PyTorch FSDP 训练的基本代码框架:

python
# 教学示例:展示核心逻辑,省略了部分 import 和辅助函数定义
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from functools import partial

# 初始化分布式环境
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# 构建模型
model = MyTransformerModel()

# 定义自动包装策略:每个 TransformerBlock 作为一个 FSDP 单元
auto_wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},  # 指定要包装的层类型
)

# 混合精度配置
mixed_precision = MixedPrecision(
    param_dtype=torch.float16,      # 前向/反向使用 FP16
    reduce_dtype=torch.float16,     # 梯度通信使用 FP16
    buffer_dtype=torch.float16,
)

# 用 FSDP 包装模型
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # 对应 ZeRO Stage 3
    auto_wrap_policy=auto_wrap_policy,
    mixed_precision=mixed_precision,
    device_id=local_rank,
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# 训练循环
for batch in dataloader:
    inputs, labels = batch
    inputs = inputs.cuda(local_rank)
    labels = labels.cuda(local_rank)

    outputs = model(inputs)          # FSDP 自动处理 All-Gather / 释放
    loss = loss_fn(outputs, labels)
    loss.backward()                  # FSDP 自动处理 All-Gather / Reduce-Scatter
    optimizer.step()                 # 每张 GPU 更新本地 1/N 参数
    optimizer.zero_grad()

从代码可以看到,FSDP 的用户接口与标准的 DDP 训练几乎完全一致。核心差异仅在于将 DistributedDataParallel 替换为 FullyShardedDataParallel,并指定分片策略和包装策略。所有 All-Gather、Reduce-Scatter 和参数释放操作都被封装在 FSDP 的前向/反向钩子中,对用户透明。


10.6.8 ZeRO 与 FSDP 的对比

ZeRO(DeepSpeed)和 FSDP(PyTorch)在核心原理上是等价的,但在工程定位和适用场景上有所区分:

维度DeepSpeed ZeROPyTorch FSDP
实现方式独立库,通过训练引擎包装PyTorch 原生模块
分片粒度Stage 1/2/3 三级可选ShardingStrategy 枚举,对应 Stage 2/3/DDP
Offload 能力成熟的 CPU/NVMe Offload支持 CPU Offload(FSDP2 改进中)
3D 并行集成原生支持 3D 并行(TP + PP + ZeRO)需与 DeviceMesh/DTensor 配合
生态兼容性需适配 DeepSpeed 引擎 API与 PyTorch AMP、编译器、Profiler 无缝集成
配置方式JSON 配置文件Python API
社区与维护微软维护PyTorch 核心团队维护

选择建议。

  • 如果已使用 DeepSpeed 生态(如 Megatron-DeepSpeed)进行 3D 并行训练,ZeRO 是自然选择,其 Offload 能力和配置灵活性更成熟。
  • 如果希望保持纯 PyTorch 技术栈、减少外部依赖,或需要与 PyTorch 的编译器(torch.compile)、分布式张量(DTensor)等新特性深度配合,FSDP 是更优方案。
  • 在大多数场景下,两者的训练效率和显存节省效果非常接近。选择更多取决于工程栈的偏好而非性能差异。

10.6.9 ZeRO/FSDP 在混合并行中的角色

在第 10.5 节末尾提到的 4D 并行架构中,ZeRO/FSDP 扮演着数据并行维度的显存优化器角色。典型的混合并行配置如下:

                  ┌───────────────────────────────────────┐
                  │           4D 并行配置                   │
                  │                                       │
    节点内         │  TP(张量并行):节点内 NVLink 高带宽       │
    ───────       │  → 切分单层权重矩阵                      │
                  │                                       │
    跨节点         │  PP(流水线并行):跨节点 InfiniBand       │
    ───────       │  → 切分模型层                           │
                  │                                       │
    DP 组内        │  ZeRO / FSDP:在 DP 副本之间             │
    ───────       │  → 分片优化器状态 / 梯度 / 参数            │
                  │                                       │
    MoE 模型      │  EP(专家并行):跨 GPU 分布专家            │
    ───────       │  → All-to-All 路由 Token                │
                  └───────────────────────────────────────┘

在这种配置中,同一数据并行组(DP Group)内的 GPU 处理不同的数据分片、持有相同的模型参数副本。传统 DDP 要求每张 GPU 完整存储模型状态,而 ZeRO/FSDP 在 DP 组内部进行分片,使得即使在 TP + PP 已经切分模型的基础上,DP 维度也不再有显存冗余。

实际应用中,4D 并行通常使用 ZeRO Stage 1 或 Stage 2(而非 Stage 3),因为 TP 和 PP 已经分担了大部分模型切分工作,DP 维度只需分片优化器状态和梯度即可。Stage 3 的额外 All-Gather 通信在与 TP/PP 叠加时会产生较大的通信压力。


本节小结

ZeRO 和 FSDP 从显存状态切分的角度解决了数据并行中的冗余问题,是当前大模型训练基础设施的核心组件:

  • ZeRO Stage 1 分片优化器状态,在零额外通信的前提下节省约 4 倍显存。
  • ZeRO Stage 2 进一步分片梯度,通信量不变,显存节省约 8 倍。
  • ZeRO Stage 3 / FSDP FULL_SHARD 对全部模型状态进行完全分片,显存节省与 GPU 数量成正比,代价是 1.5 倍于 DDP 的通信量。
  • ZeRO-Offload / ZeRO-Infinity 通过 CPU/NVMe 卸载进一步突破 GPU 显存上限。

与模型并行(TP、PP)不同,ZeRO/FSDP 保留了数据并行易于使用的核心优势——用户无需手动切分模型结构,只需指定分片策略,框架即可自动处理所有通信和显存管理。正是这种"对用户透明的显存优化"特性,使得 ZeRO/FSDP 成为了从学术实验到工业训练的通用标配技术。