4.1 正弦位置编码:频率与外推的直觉
4.1.1 编码公式
4.1.2 为什么使用正弦函数
4.1.3 频率分解的可视化
import torch
import matplotlib.pyplot as plt
import math
d_model = 64 # 编码维度
max_pos = 100 # 位置数量
# 计算正弦位置编码
pe = torch.zeros(max_pos, d_model)
position = torch.arange(0, max_pos).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
-(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# 左图:位置编码热力图
im = axes[0].imshow(pe.numpy().T, aspect="auto", cmap="RdBu_r",
origin="lower")
axes[0].set_xlabel("位置 (pos)")
axes[0].set_ylabel("维度 (i)")
axes[0].set_title("正弦位置编码热力图(PE 矩阵)")
plt.colorbar(im, ax=axes[0])
# 右图:选取 4 个不同频率通道的波形
channels = [0, 10, 20, 30] # 从高频到低频
for ch in channels:
freq = 1.0 / (10000 ** (ch / d_model))
axes[1].plot(pe[:, ch].numpy(),
label=f"维度 {ch}(频率 {freq:.4f})")
axes[1].set_xlabel("位置 (pos)")
axes[1].set_ylabel("编码值")
axes[1].set_title("不同频率通道的波形对比")
axes[1].legend(fontsize=8)
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("sinusoidal_pe_visualization.png", dpi=150)
plt.show()
4.1.4 相对位置的线性表示
4.1.5 外推能力
4.1.6 固定 vs 可学习:原始论文的对比
最后更新于
