梯度峰值从何而来
一个平静的小批量 & 一个冲击性的小批量
大多数小批量产生的梯度幅度合理。对于已经大致拟合数据的模型,交叉熵损失保持在狭窄范围内;反向传播将该信号以相似大小的梯度形式回传。
有些小批量则不然。梯度峰值的三个来源:
1. 异常值示例。 一个具有极度罕见令牌组合的单一序列会产生远超平均值的损失 & 远超平均值的梯度。
2. 数值边缘情况。 接近零的softmax分母、产生NaN的layernorm、FP16溢出。每一种都可能产生比典型值大几个数量级的梯度。
3. 分布偏移。 在单次训练运行中切换数据源会用新分布震撼模型。ANDREA的bandit每7到42步重新调整源权重。每次切换都是一个小分布偏移。
ANDREA-120M v1:尖峰级联
v1没有梯度裁剪。bandit每7到42步的源转换向模型提供短暂的repo-docs(列表结构)、然后gutenberg(长篇散文)、然后hermes3-general(问答)爆发。每次转换都会产生梯度尖峰:每个尖峰将权重推入120M规模的退化吸引子。
关键实证事实。 ANDREA-12M 在不使用裁剪的情况下存活了相同的匪徒攻击。小型权重矩阵对梯度冲击保持鲁棒;一个坏批次无法将 12M 参数推入失控吸引子,而对于 120M 参数则可以。模型规模越大,裁剪的重要性越高。
全局 L2 范数裁剪
两种选择:逐张量或全局
两种限制梯度幅度的方法:
逐张量裁剪。 独立裁剪每个梯度张量。嵌入梯度被裁剪到其自身范数;注意力梯度被裁剪到其自身范数。简单,但会扭曲相对尺度:一个张量中的小峰值(现在梯度为零)与另一个张量中的巨大梯度(未受影响)配对。
全局 L2 范数裁剪。 将所有梯度视为一个大向量。计算每个参数的总 L2 范数。如果范数超过 max_norm,则按相同因子缩放每个梯度。保留张量间的相对幅度。
ANDREA 使用全局裁剪。Pascanu 等人(2013)通过实验证明,在 Transformer 训练中,全局裁剪优于每个张量的裁剪。
数学公式
计算全局 L2 范数:
norm = sqrt(sum over all params of g_i^2)
如果 norm <= max_norm,梯度原样通过。如果 norm > max_norm,则将每个梯度按 max_norm / norm 缩放:
g_i_clipped = g_i * (max_norm / norm)
缩放后,新范数恰好等于 max_norm。ANDREA 使用 max_norm = 1.0。
计算缩放因子
为什么梯度范数计算需要三个内核
朴素算法无法在 GPU 上运行
全局 L2 范数计算的伪代码:
total = 0
对于每个参数 p:
对于 p.grad 中的每个元素 g:
total += g * g
norm = sqrt(total)
在 GPU 上,这个朴素的循环失败有两个原因:
1. 顺序累积。 单个 total 累加器迫使每个线程等待其他所有线程,破坏了 GPU 并行性。
2. 异构张量。 ANDREA-120M 具有形状差异很大的张量:嵌入 (8449 x 768)、注意力 QKV (768 x 768)、layernorm (768)。一个内核无法高效迭代所有形状。
ANDREA 的三内核流水线
将工作拆分为 microgpt_cuda.cu 中的三个 CUDA 内核:
内核 1: k_grad_norm_partial。 对于每个参数张量,计算平方和的部分和。每个线程块归约张量的一个块;结果写入一个小临时缓冲区。并行性:每个块一个块,所有张量跨越数百个块。
内核 2: k_grad_norm_final。 将暂存缓冲区归约到一个标量。然后取其平方根。一个小型内核,运行时间在微秒级别。
内核 3: k_grad_scale。 如果 norm > max_norm,计算 scale = max_norm / norm 并将每个梯度元素乘以 scale。一次遍历所有梯度张量,完全并行。
顺序很重要:Pre-Adam
裁剪管道在 AdamW 更新 m、v 或任何参数之前运行。为什么?
裁剪的梯度会输入到 AdamW 的指数移动平均中。如果允许峰值流入 m 和 v,它会破坏这些运行平均值,并在峰值后许多步骤中减缓恢复。Adam 前的裁剪将峰值的影响限制在单个坏步骤中。
为什么是三个内核,而不是一个?
无剪切如何杀死 v1
强盗源每 7 到 42 步转换一次
ANDREA 的强盗以阶段运作。每个阶段持续 7、14、21、28 或 42 步(随机选择)。在每个阶段边界,源权重会转移:也许 repo-docs 从 0.1 跳到 0.6,gutenberg 从 0.4 降到 0.1,hermes3-general 从 0.5 升到 0.7。
每次转换都是对模型的分布冲击。损失会短暂飙升。梯度随之飙升:一个原本针对 gutenberg 风格散文的损失最小化模型,现在看到 repo-docs 风格的列表结构,梯度携带的修正信号可能达到典型幅值的 10 倍或 100 倍。
v1 失败模式
没有裁剪,那些 10-100 倍的梯度峰值会流入 AdamW 的 m 和 v 平均值。AdamW 的平滑机制意味着峰值效应在实际坏批次之后会持续许多步。结合没有权重衰减(v1 中的 vanilla Adam),峰值驱动的权重更新会在阶段中累积,直到权重漂移到一个退化的吸引子:一个 token 的 logit 主导 softmax,采样输出就是那个 token,训练上下文包含那个 token,梯度强化了那个 token。重复锁定。
v2 稳定性
v2 添加了裁剪,max_norm = 1.0,同时引入 AdamW 和 LR 预热。峰值对 m 和 v 的影响被限制;权重无法以超过 lr max_norm = 0.0003 1.0 = 0.0003 的速度每参数每步漂移。在峰值时,阶段转换仍会产生峰值,但这些峰值在到达优化器之前就被限制住了。
结果:v2(在数据过滤 v2.5 和 v3 润色之后)达到了事实回忆、多段落连贯性,以及生物学和信号处理样本上 9.5/10 的外部评分。
容量-脆性耦合
相同的强盗。相同的数据。相同的超参数,除了裁剪。为什么 12M 在没有裁剪的情况下存活,而 120M 崩溃了?
两个复合因素:
1. 更大的权重矩阵存储更多吸引子。 一个 768x768 的注意力投影有 590K 参数;即使是每个参数的小漂移也会产生注意力行为的有意义的改变。一个 384x384 的注意力投影有 147K 参数,并停留在更受约束的子空间中。
2. 更多层意味着更多乘性交互。 v3 有 12 个 transformer 层(相比 12M 的 6 个)。峰值通过 12 层复合非线性传播;每一层都可以放大前一层的漂移。
脆弱性随着容量复合。超过某个规模阈值后,裁剪变得强制;ANDREA 将该阈值置于 12M 与 120M 参数之间。
诊断 v1 级联效应
裁剪还在哪里适用?
相关活动
三个兄弟活动与裁剪相关联:
- 活动 10: AdamW。 裁剪保护 AdamW 的 m 和 v 免受激增污染。没有裁剪,一个坏批次会使优化器状态在 50+ 步内损坏。
- 活动 11: LR 预热。 预热抑制 lr;裁剪抑制 g。结合使用:在第 1 步,最坏情况参数更新为 lr_after_warmup max_norm = 1.5e-7 1.0 = 1.5e-7,而没有任一保护措施则为 0.0003 * 50 = 0.015。最坏情况早期更新幅度的 100,000 倍减少。
- 活动 14:多臂赌博机。 赌博机阶段长度(7 到 42 步)特别短,以防止任何单一来源主导;裁剪使得那些频繁转换安全。
裁剪是 Transformer 训练中最便宜的稳定性收益:3 个小型 CUDA 内核,每步微秒级,决定 120M+ 模型是收敛还是崩溃。