2.5 注意力的代价:复杂度与局限
注意力机制赋予了 Transformer 强大的表示能力,但这种能力并非没有代价。理解注意力的复杂度特性和内在局限,对于后续理解各种优化技术至关重要。
2.5.1 计算复杂度分析
在之前的讨论中,为了便于理解,通常将 Q、K、V 视作与单个词对应的 向量。但在实际的计算硬件(如 GPU)上,为了实现高度的并行计算效率,序列中所有 $n$ 个词的查询向量会被按行拼接为一个完整的 查询矩阵 $Q \in \mathbb{R}^{n \times d_k}$(键矩阵 $K$ 和值矩阵 $V$ 同理)。由此,注意力机制的工作过程从“逐词的向量匹配”转变为了一次性的大规模矩阵运算。
缩放点积注意力的核心计算是两次矩阵乘法:$QK^T$ 和 $AV$。
$QK^T$ 的计算:$Q \in \mathbb{R}^{n \times d_k}$ 与 $K^T \in \mathbb{R}^{d_k \times n}$ 相乘,结果为 $n \times n$ 矩阵,计算量为 $O(n^2 d_k)$。
$AV$ 的计算:$A \in \mathbb{R}^{n \times n}$ 与 $V \in \mathbb{R}^{n \times d_v}$ 相乘,结果为 $n \times d_v$ 矩阵,计算量为 $O(n^2 d_v)$。
因此,单头注意力的计算复杂度为 $O(n^2 d)$(其中 $d = d_k = d_v$),对序列长度 $n$ 是平方的。
这个平方复杂度产生了两个实际影响:
计算时间:处理长度为 $n$ 的序列的计算量与 $n^2$ 成正比。长度翻倍,计算量增加四倍。
内存占用:注意力分数矩阵 $A \in \mathbb{R}^{n \times n}$ 需要存储 $n^2$ 个元素。对于长度 $n = 8192$ 的序列,仅一层一个头的注意力矩阵就需要约 $8K8K4B = 256 MB$(FP32 精度),而完整模型可能有数十层和数十个头。
2.5.2 与其他架构的复杂度对比
下表对比了不同架构在关键复杂度指标上的表现:
自注意力
$O(n^2 d)$
$O(1)$
$O(1)$
$O(n^2)$
循环层
$O(n d^2)$
$O(n)$
$O(n)$
$O(n)$
卷积层
$O(k n d^2)$
$O(1)$
$O(\log_k n)$
$O(kn)$
图 2-5:三种架构的复杂度对比
这个对比揭示了一个重要的权衡:
自注意力用 $O(n^2)$ 的计算成本换来了 $O(1)$ 的路径长度和完全并行
循环层内存效率高($O(n)$),但路径长度和顺序操作数都是 $O(n)$
卷积层在局部特征提取上高效,但需要多层才能覆盖长距离
当 $n < d$(即序列长度小于模型维度)时——这在 NLP 任务中很常见——自注意力的 $O(n^2 d)$ 实际上小于循环层的 $O(n d^2)$。这解释了为什么 Transformer 在典型 NLP 任务的序列长度(数百到数千词元)上比 RNN 更快。
2.5.3 平方复杂度的实际影响
平方复杂度在不同序列长度下的实际影响差异巨大:
512
262K
0.5 MB
1×
2,048
4.2M
8 MB
16×
8,192
67M
128 MB
256×
32,768
1.07B
2 GB
4,096×
131,072
17.2B
32 GB
65,536×
当序列长度从 512 增长到 131,072(256 倍)时,计算量增长了 65,536 倍。这就是为什么早期 Transformer 通常将序列长度限制在 512 或 1024——更长的序列在硬件上根本不可行。
2.5.4 应对平方复杂度的思路
平方复杂度的瓶颈催生了大量后续研究,主要方向包括:
稀疏注意力:不是让每个位置关注所有位置,而是只关注一个子集。例如,Longformer 使用局部窗口注意力加少量全局注意力,将复杂度降至 $O(n)$。BigBird 在此基础上加入随机注意力连接。
线性注意力:通过核函数近似或矩阵分解,将 $O(n^2)$ 的注意力计算简化为 $O(n)$。Performer 和 Linear Transformer 是这一方向的代表。
IO 感知优化:Flash Attention 不改变注意力的数学定义,而是优化计算的内存访问模式,避免在 GPU 的高带宽内存中存储完整的 $n \times n$ 注意力矩阵。
替代架构:状态空间模型(如 Mamba)和线性循环网络从根本上放弃了注意力的全连接模式,用线性递推代替,实现 $O(n)$ 的计算与内存复杂度。
最后更新于
