10.3 Flash Attention:IO 感知的算法设计

10.3.1 为什么标准注意力慢

标准注意力实现的瓶颈不在于浮点计算量,而在于内存访问模式。计算 $QK^T$ 会在 GPU 的高带宽内存(HBM)中生成一个 $n \times n$ 的注意力矩阵。这个矩阵的读写构成了大量的非必要内存访问。

10.3.2 Flash Attention 的核心思想

Flash Attention 由 Dao 等人提出,通过分块计算(Tiling)和核内重计算(Recomputation)避免将完整的注意力矩阵写入 HBM:

  1. 将 Q、K、V 分成小块,适配 GPU 的片上 SRAM

  2. 在 SRAM 中完成注意力计算(块级 Softmax)

  3. 直接输出结果到 HBM,从不在 HBM 中存储完整的 $n \times n$ 注意力矩阵

这使得 Flash Attention 的 IO 复杂度从 $O(n^2)$ 降至 $O(n^2 d / M)$($M$ 为 SRAM 大小),在实践中带来 2-4 倍的速度提升和显著的显存节省。

10.3.3 Flash Attention 2:并行度与效率的提升

Flash Attention 2 在 Flash Attention 的基础上进行了多项关键优化:

  1. 减少非矩阵乘 FLOPs:重新组织了在线 Softmax 的计算流程,消除了大量缩放和边界检查等非矩阵乘操作,使更多的计算时间花在 GPU Tensor Core 擅长的矩阵乘法上

  2. 优化 Warp 级并行:在 Warp 之间并行处理不同的注意力头,而非让多个 Warp 协作处理同一个头,大幅减少了 Warp 间的同步和通信开销

  3. 序列长度维度并行:在 Q 的序列长度维度上增加并行度,使得长序列场景下 GPU 的 SM 占用率更高

这些优化使 Flash Attention 2 在 A100 上达到了理论 FLOPs 的 50-73% 利用率(取决于序列长度),相比 Flash Attention 1 提升约 2 倍,并已成为所有主流推理和训练框架的标准组件。

10.3.4 Flash Attention 3:面向 Hopper 架构的深度优化

尽管 Flash Attention 2 表现出色,但在 NVIDIA H100(Hopper 架构)上的利用率却仅约 35%。这说明 Hopper 引入的新硬件特性没有被充分利用。Flash Attention 3 由 Tri Dao 等人于 2024 年提出,通过三项关键技术实现了对 Hopper 架构的深度适配。

异步流水线

Hopper 架构引入了张量内存加速器(Tensor Memory Accelerator,TMA),可以独立于计算单元进行异步数据搬运。Flash Attention 3 利用 Warp 特化(Warp Specialization)技术,将 Warp 分为“生产者”和“消费者”两个角色:

  • 生产者 Warp:通过 TMA 将下一块 K、V 数据从 HBM 异步加载到 SRAM

  • 消费者 Warp:同时在 SRAM 中对当前数据块执行矩阵乘法

这种流水线设计使计算和数据搬运全程重叠,消除了 Flash Attention 2 中“加载-计算-加载-计算”的串行等待。

GEMM-Softmax 交错执行

在标准实现中,每个数据块的处理遵循严格的顺序:先执行 $QK^T$ 矩阵乘(GEMM),然后计算 Softmax,最后执行与 $V$ 的矩阵乘。Softmax 中的指数和归一化操作无法利用 Tensor Core,成为流水线的瓶颈。

Flash Attention 3 将当前块的 Softmax 与下一块的 GEMM 交错执行。当 Tensor Core 忙于计算下一块的矩阵乘时,CUDA Core 同步处理上一块的 Softmax——两种不同类型的计算单元并行工作,进一步提升了吞吐量。

FP8 低精度支持

Flash Attention 3 利用 Hopper 的第四代 Tensor Core 原生支持 FP8 精度。通过以下技术在性能翻倍的同时控制精度损失:

  • 块级量化(Block Quantization):对每个小块独立计算缩放因子,避免全局量化带来的动态范围问题

  • 非相干处理(Incoherent Processing):在量化前对输入施加随机正交变换,使异常值分散到更多维度,减少最大量化误差

性能表现

在 NVIDIA H100 上,Flash Attention 3 取得了显著的性能提升:

精度
吞吐量
H100 利用率
相比 FA2

FP16

~740 TFLOPS

~75%

1.5-2.0×

FP8

~1.2 PFLOPS

数值误差降低 2.6×

图 10-2:Flash Attention 3 在 H100 上的性能表现

Flash Attention 3 的演进揭示了一个重要趋势:算法设计必须与硬件架构协同演化。每一代新硬件都会引入新的计算原语和内存层级,只有深入理解这些硬件特性,才能充分释放算力。这一思想在 Blackwell 架构的 FP4 支持中得到了进一步延伸(详见 11.3 节)。

最后更新于