7.5 激活重计算:用时间换空间的艺术

在讨论混合精度等技术时,我们反复提到“显存瓶颈”。在大模型训练中,显存不仅被模型参数、优化器状态和梯度占用,还有一个不可忽视的消耗大户——激活值(Activations)

7.5.1 为什么激活值占用大量显存

在前向传播过程中,每一层的输出(激活值)都必须保存在显存中。这是因为反向传播计算梯度时需要重用这些前向激活值。

对于标准的 Transformer 层,随着批次大小(Batch Size)和序列长度(Sequence Length)的增加,激活值的显存占用会线性增长。在长序列训练或使用了大批次的情况下,激活值占用的显存甚至会远远超过模型参数本身的显存占用。如果不加干预,这就成为训练更大模型或更长上下文的致命瓶颈。

7.5.2 梯度检查点机制

激活重计算(Activation Checkpointing,也常被称为 Gradient Checkpointing)是解决激活显存瓶颈的标准方案。其核心思想非常直观:放弃保存部分前向传播的激活值,在反向传播需要用到时,再重新计算一遍。

具体做法是:

  1. 前向传播:将模型划分为多个“段”(Segments)。在层与层的边界处保存少量的“检查点”激活值,而丢弃段内部所有的中间激活值。

  2. 反向传播:当反向传播到达某个段时,从该段保存的“检查点”出发,额外执行一次局部的前向传播,重新计算出所需的中间激活值,然后计算梯度,最后再次丢弃这些临时激活值。

7.5.3 内存与计算的权衡

通过激活重计算,显存占用可以从与层数 $L$ 成正比($O(L)$)大幅降低至 $O(\sqrt{L})$(当采用均匀检查点策略时)。

这种显存节省的代价是增加了约 33% 的额外计算量(因为相当于执行了 1 次完整前向 + 1 次局部前向 + 1 次完整反向)。但在实际的大规模训练中,这通常是完全值得的:因为节省下来的巨大显存可以被用来大幅提高批次大小。而更大的批次大小不仅有利于提高梯度的统计稳定性,更能提高现代 GPU 上矩阵运算的算力利用率(MFU),这种局部效率的提升往往能抵消甚至超过额外计算带来的时间开销。

7.5.4 现代重计算策略

如今的模型训练并不满足于简单的“全量重计算”。Megatron-LM 等框架引入了选择性激活重计算(Selective Activation Checkpointing):只对那些“显存占用极大但重新计算很快”的操作(如注意力机制中的 Softmax 和 Dropout)进行重计算,而对那些“显存占用小但计算耗时”的操作(如稠密矩阵乘法)进行正常保存。这种精细化的策略在几乎不增加计算时间的情况下,能实现显著的显存节省。

最后更新于