7.2 ZeRO 优化:如何突破单卡显存限制

ZeRO(Zero Redundancy Optimizer)由 Microsoft DeepSpeed 团队提出,是解决数据并行显存冗余问题的核心技术。

7.2.1 冗余分析

在标准数据并行中,每张 GPU 都存储了完整的模型状态——参数、梯度和优化器状态。但实际上,每张 GPU 在每个训练步中只需要自己那份数据对应的梯度来计算当前步的更新。其余都是冗余的。

以一个 15 亿参数的模型(约 GPT-2 规模)为例,使用 FP16 训练和 Adam 优化器:

组成部分
大小
说明

FP16 参数

3 GB

$1.5B \times 2$ 字节

FP16 梯度

3 GB

与参数同大小

FP32 优化器状态

18 GB

参数副本 + 一阶矩 + 二阶矩

总计

24 GB

每张 GPU

表 7-1:每张 GPU 的显存占用明细

在 64 张 GPU 上做数据并行时,这 24 GB 在每张卡上都有一份完整拷贝——总计 1536 GB 的显存中有 23/24 是冗余的

7.2.2 ZeRO 的三个阶段

ZeRO 通过将模型状态分片(Shard)到多张 GPU 上来消除冗余:

ZeRO-1(优化器状态分片):每张 GPU 只持有 $1/K$ 的优化器状态。更新参数时,各卡负责更新自己那部分参数,然后通过 AllGather 同步更新后的参数。显存减少约 4 倍。

ZeRO-2(梯度分片):在 ZeRO-1 基础上,梯度也分片存储。每张 GPU 只保留与自己负责的参数对应的梯度,其余在 Reduce-Scatter 后丢弃。显存进一步减少约 2 倍。

ZeRO-3(参数分片):最彻底的方案——连参数本身也分片存储。每张 GPU 只持有 $1/K$ 的参数,在前向和反向传播时通过 AllGather 临时获取完整参数。显存减少与 GPU 数量成正比。

7.2.3 通信与效率权衡

ZeRO 通过增加通信量来换取显存节省:

方案
显存节省
通信量
实际开销

标准 DDP

基准

ZeRO-1

~4×

几乎无开销

ZeRO-2

~8×

几乎无开销

ZeRO-3

线性于 GPU 数

~1.5×

有一定开销

ZeRO-1/2 的通信量与标准 DDP 相同(通过精巧的通信调度实现),因此几乎没有性能损失。ZeRO-3 需要额外的 AllGather 操作来获取参数,通信量增加约 50%,但 DeepSpeed 通过预取和流水线化等优化将实际开销控制在可接受的范围内。

ZeRO 的出现使得数据并行能够训练远超单卡显存容量的模型,成为了大模型训练的基础设施级技术。

最后更新于