5.4 预训练数据:规模定律与数据质量的博弈
5.4.1 规模定律的发现
import torch
import matplotlib.pyplot as plt
# 模拟 Scaling Law 的幂律关系 L(x) = a * x^(-alpha) + L_inf
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# 左图:L(N)——损失随模型参数量的变化
N = torch.logspace(7, 11, 200) # 10M 到 100B 参数
alpha_N = 0.076
L_inf_N = 1.69
a_N = 8.8
L_N = a_N * N ** (-alpha_N) + L_inf_N
axes[0].loglog(N.numpy(), L_N.numpy(), "b-", linewidth=2)
axes[0].set_xlabel("模型参数量 N")
axes[0].set_ylabel("交叉熵损失 L(N)")
axes[0].set_title("(a)损失随模型规模的幂律下降")
axes[0].grid(True, alpha=0.3, which="both")
axes[0].set_xlim(1e7, 1e11)
# 右图:L(D)——损失随训练数据量的变化
D = torch.logspace(8, 13, 200) # 100M 到 10T 词元
alpha_D = 0.095
L_inf_D = 1.69
a_D = 5.0
L_D = a_D * D ** (-alpha_D) + L_inf_D
axes[1].loglog(D.numpy(), L_D.numpy(), "r-", linewidth=2)
axes[1].set_xlabel("训练数据量 D(词元数)")
axes[1].set_ylabel("交叉熵损失 L(D)")
axes[1].set_title("(b)损失随数据规模的幂律下降")
axes[1].grid(True, alpha=0.3, which="both")
axes[1].set_xlim(1e8, 1e13)
plt.tight_layout()
plt.savefig("scaling_law_curves.png", dpi=150)
plt.show()
5.4.2 数据来源与构建
5.4.3 数据质量的关键影响
5.4.4 词元化的影响
最后更新于
