6.4 批次与序列长度:效率与质量的平衡
批次大小(Batch Size)和序列长度(Sequence Length)是影响训练效率、模型质量和显存占用的两个关键变量。
6.4.1 批次大小的影响
大批次的优势:更稳定的梯度估计(减少噪声)、更高的 GPU 利用率(更好的并行效率)、更快的训练吞吐量。
大批次的风险:过大的批次可能导致泛化性能下降——梯度太“平滑”可能使模型更容易陷入尖锐的局部最优,而适度的梯度噪声实际上有助于找到更平坦(泛化更好)的最优区域。
原始 Transformer 使用约 25,000 词元/批次进行训练。现代大语言模型的批次规模要大得多——GPT-3 使用 320 万词元/批次,Llama 3 更是从初始阶段的 400 万词元/批次逐渐增大到 1600 万词元/批次。
6.4.2 序列长度的管理
序列长度直接影响注意力计算的内存占用($O(n^2)$)和模型能捕捉的最大上下文范围。
填充与打包:同一批次中的序列需要统一长度。传统做法是将短序列填充到最大长度,但这浪费了大量计算。序列打包(Sequence Packing)是一种更高效的方法——将多个短序列拼接在一起填满一个完整的训练样本,配合特殊的注意力掩码确保不同序列之间不交互。
动态长度策略:一些训练框架支持按长度分组(将相近长度的序列放入同一批次),或在训练过程中逐步增加序列长度——先用短序列快速迭代,后期切换到长序列以学习长距离依赖。
6.4.3 显存占用分析
训练 Transformer 时的显存主要被四部分占用:
模型参数
~25%
权重矩阵
梯度
~25%
与参数同等大小
优化器状态
~40%
Adam 需要存储一阶/二阶矩(2倍参数量)
激活值
~10%+
用于反向传播,随序列长度和批次大小增长
表 6-1:训练时显存占用的典型分布
这个分析表明,优化器状态是最大的显存消耗者——这正是 ZeRO 等优化技术的主要优化目标(见第七章)。
梯度检查点(Gradient Checkpointing,也叫激活重计算)是一种以计算换显存的技术:不保存所有层的激活值,而是只保存部分层的激活值,其他层的激活值在反向传播时重新计算。这通常能将激活值的显存占用减少到 $O(\sqrt{L})$($L$ 为层数),代价是约增加 30% 的计算时间。
最后更新于
