2.5 注意力的代价:复杂度与局限

注意力机制赋予了 Transformer 强大的表示能力,但这种能力并非没有代价。理解注意力的复杂度特性和内在局限,对于后续理解各种优化技术至关重要。

2.5.1 计算复杂度分析

在之前的讨论中,为了便于理解,通常将 Q、K、V 视作与单个词对应的 向量。但在实际的计算硬件(如 GPU)上,为了实现高度的并行计算效率,序列中所有 $n$ 个词的查询向量会被按行拼接为一个完整的 查询矩阵 $Q \in \mathbb{R}^{n \times d_k}$(键矩阵 $K$ 和值矩阵 $V$ 同理)。由此,注意力机制的工作过程从“逐词的向量匹配”转变为了一次性的大规模矩阵运算。

缩放点积注意力的核心计算是两次矩阵乘法:$QK^T$ 和 $AV$。

$QK^T$ 的计算:$Q \in \mathbb{R}^{n \times d_k}$ 与 $K^T \in \mathbb{R}^{d_k \times n}$ 相乘,结果为 $n \times n$ 矩阵,计算量为 $O(n^2 d_k)$。

$AV$ 的计算:$A \in \mathbb{R}^{n \times n}$ 与 $V \in \mathbb{R}^{n \times d_v}$ 相乘,结果为 $n \times d_v$ 矩阵,计算量为 $O(n^2 d_v)$。

因此,单头注意力的计算复杂度为 $O(n^2 d)$(其中 $d = d_k = d_v$),对序列长度 $n$ 是平方的

这个平方复杂度产生了两个实际影响:

计算时间:处理长度为 $n$ 的序列的计算量与 $n^2$ 成正比。长度翻倍,计算量增加四倍。

内存占用:注意力分数矩阵 $A \in \mathbb{R}^{n \times n}$ 需要存储 $n^2$ 个元素。对于长度 $n = 8192$ 的序列,仅一层一个头的注意力矩阵就需要约 $8K8K4B = 256 MB$(FP32 精度),而完整模型可能有数十层和数十个头。

2.5.2 与其他架构的复杂度对比

下表对比了不同架构在关键复杂度指标上的表现:

架构
每层计算量
顺序操作数
最大路径长度
内存

自注意力

$O(n^2 d)$

$O(1)$

$O(1)$

$O(n^2)$

循环层

$O(n d^2)$

$O(n)$

$O(n)$

$O(n)$

卷积层

$O(k n d^2)$

$O(1)$

$O(\log_k n)$

$O(kn)$

图 2-5:三种架构的复杂度对比

这个对比揭示了一个重要的权衡:

  • 自注意力用 $O(n^2)$ 的计算成本换来了 $O(1)$ 的路径长度和完全并行

  • 循环层内存效率高($O(n)$),但路径长度和顺序操作数都是 $O(n)$

  • 卷积层在局部特征提取上高效,但需要多层才能覆盖长距离

当 $n < d$(即序列长度小于模型维度)时——这在 NLP 任务中很常见——自注意力的 $O(n^2 d)$ 实际上小于循环层的 $O(n d^2)$。这解释了为什么 Transformer 在典型 NLP 任务的序列长度(数百到数千词元)上比 RNN 更快。

2.5.3 平方复杂度的实际影响

平方复杂度在不同序列长度下的实际影响差异巨大:

序列长度
注意力矩阵大小
内存(FP16)
相对计算量

512

262K

0.5 MB

2,048

4.2M

8 MB

16×

8,192

67M

128 MB

256×

32,768

1.07B

2 GB

4,096×

131,072

17.2B

32 GB

65,536×

当序列长度从 512 增长到 131,072(256 倍)时,计算量增长了 65,536 倍。这就是为什么早期 Transformer 通常将序列长度限制在 512 或 1024——更长的序列在硬件上根本不可行。

2.5.4 应对平方复杂度的思路

平方复杂度的瓶颈催生了大量后续研究,主要方向包括:

稀疏注意力:不是让每个位置关注所有位置,而是只关注一个子集。例如,Longformer 使用局部窗口注意力加少量全局注意力,将复杂度降至 $O(n)$。BigBird 在此基础上加入随机注意力连接。

线性注意力:通过核函数近似或矩阵分解,将 $O(n^2)$ 的注意力计算简化为 $O(n)$。Performer 和 Linear Transformer 是这一方向的代表。

IO 感知优化:Flash Attention 不改变注意力的数学定义,而是优化计算的内存访问模式,避免在 GPU 的高带宽内存中存储完整的 $n \times n$ 注意力矩阵。

替代架构:状态空间模型(如 Mamba)和线性循环网络从根本上放弃了注意力的全连接模式,用线性递推代替,实现 $O(n)$ 的计算与内存复杂度。

这些优化方向将在第十章第十四章中详细讨论。

最后更新于