在自回归解码中,生成第 $t$ 个词元时,需要计算它与之前所有 $t-1$ 个词元的注意力。如果每次都从头计算所有词元的 Key 和 Value,计算量将随生成长度呈平方增长。
KV 缓存的核心观察是:之前词元的 Key 和 Value 在后续步骤中不会改变(因为因果掩码确保它们不受未来词元的影响)。因此,只需在首次计算后将它们缓存到显存中,后续步骤直接复用即可。
每生成一个新词元,只需:
计算新词元的 Q、K、V
将新的 K、V 追加到缓存
用新的 Q 与所有缓存的 K 计算注意力权重
用权重对所有缓存的 V 加权求和
这将每步的计算从 $O(t^2 \cdot d)$ 降至 $O(t \cdot d)$。
KV 缓存虽然节省了计算,但引入了显著的显存开销。对于一个具有 $L$ 层、$H$ 个头、head 维度 $d_h$ 的模型,缓存 $t$ 个词元需要:
KV 缓存大小=2×L×H×dh×t×bytes/element\text{KV 缓存大小} = 2 \times L \times H \times d_h \times t \times \text{bytes/element}KV 缓存大小=2×L×H×dh×t×bytes/element
对于 Llama 2-70B(80 层、64 头、128 维),缓存 4096 个词元需要约 5 GB(FP16)。在高并发场景下,多个用户请求的 KV 缓存会迅速填满显存。这正是 GQA 和 PagedAttention 等技术的优化目标。
第二章介绍了多头注意力中每个头有独立的 Q、K、V 投影。分组查询注意力(Grouped-Query Attention,GQA)让多个查询头共享同一组 K 和 V 头。
例如,Llama 2-70B 使用 64 个查询头但只有 8 个 KV 头(每 8 个查询头共享一组 KV)。这将 KV 缓存减小为原始多头注意力的 1/8,且对模型质量的影响极小。
GQA 是介于多头注意力(MHA,每个查询头一组独立 KV)和多查询注意力(MQA,所有查询头共享一组 KV)之间的折中方案。它在推理效率和模型质量之间取得了很好的平衡,成为现代大语言模型的标准配置。
最后更新于1天前