6.4 批次与序列长度:效率与质量的平衡

批次大小(Batch Size)和序列长度(Sequence Length)是影响训练效率、模型质量和显存占用的两个关键变量。

6.4.1 批次大小的影响

大批次的优势:更稳定的梯度估计(减少噪声)、更高的 GPU 利用率(更好的并行效率)、更快的训练吞吐量。

大批次的风险:过大的批次可能导致泛化性能下降——梯度太“平滑”可能使模型更容易陷入尖锐的局部最优,而适度的梯度噪声实际上有助于找到更平坦(泛化更好)的最优区域。

原始 Transformer 使用约 25,000 词元/批次进行训练。现代大语言模型的批次规模要大得多——GPT-3 使用 320 万词元/批次,Llama 3 更是从初始阶段的 400 万词元/批次逐渐增大到 1600 万词元/批次。

6.4.2 序列长度的管理

序列长度直接影响注意力计算的内存占用($O(n^2)$)和模型能捕捉的最大上下文范围。

填充与打包:同一批次中的序列需要统一长度。传统做法是将短序列填充到最大长度,但这浪费了大量计算。序列打包(Sequence Packing)是一种更高效的方法——将多个短序列拼接在一起填满一个完整的训练样本,配合特殊的注意力掩码确保不同序列之间不交互。

动态长度策略:一些训练框架支持按长度分组(将相近长度的序列放入同一批次),或在训练过程中逐步增加序列长度——先用短序列快速迭代,后期切换到长序列以学习长距离依赖。

6.4.3 显存占用分析

训练 Transformer 时的显存主要被四部分占用:

组成部分
占用比例
说明

模型参数

~25%

权重矩阵

梯度

~25%

与参数同等大小

优化器状态

~40%

Adam 需要存储一阶/二阶矩(2倍参数量)

激活值

~10%+

用于反向传播,随序列长度和批次大小增长

表 6-1:训练时显存占用的典型分布

这个分析表明,优化器状态是最大的显存消耗者——这正是 ZeRO 等优化技术的主要优化目标(见第七章)。

梯度检查点(Gradient Checkpointing,也叫激活重计算)是一种以计算换显存的技术:不保存所有层的激活值,而是只保存部分层的激活值,其他层的激活值在反向传播时重新计算。这通常能将激活值的显存占用减少到 $O(\sqrt{L})$($L$ 为层数),代价是约增加 30% 的计算时间。

最后更新于