Numerical Fragility in Transformers: A Layer-wise Theory for Explaining, Forecasting, and Mitigating Instability
This work addresses numerical fragility in Transformers and is aimed at machine learning practitioners; it offers incremental improvements in stability diagnostics and mitigation.
The paper tackles numerical instability in low-precision Transformer training by developing a layer-wise theory that predicts error growth. Results include a combined predictor that tracks FP32-vs-low-precision mismatches across seeds, widths, and precisions; an early-warning signal that leads error spikes by 16-24 steps; and a LayerNorm-ε tweak that reduces mean tail loss by approximately 0.010.
Transformers trained in low precision can suffer forward-error amplification. We give a first-order, module-wise theory that predicts when and where errors grow. For self-attention, we derive a per-layer bound that factorizes into three interpretable diagnostics: a score-scale ratio $κ_{\rm score}$, a rowwise softmax sensitivity $κ_{\rm softmax}$, and value conditioning $κ(V)$. We prove a residual relaxation inequality showing that residual blocks attenuate depth-wise error accumulation, and we introduce a precision- and width-aware LayerNorm indicator $ρ_{\rm LN}$ with a matching first-order bound in the $ε$-dominated regime. Together, these pieces yield a unified forward-stability bound whose right-hand side can be estimated directly during training. On Tiny-ViT/CIFAR-10, we evaluate the bound and its components. (1) The combined predictor $κ_{\rm softmax}\,(1+κ_{\rm score})\,κ(V)\,\|W_O\|_2+κ_{\rm eff}+C_{\rm LN}$ tracks FP32$\leftrightarrow$LP (low-precision) mismatches across seeds, widths, and precisions; scaling by $ε_{\rm mach}$ collapses the mixed-precision points. (2) The time-series maximum of $κ_{\rm softmax}$ acts as an early-warning signal, leading error spikes by 16-24 steps (correlation 0.65-0.82; permutation $p\approx 10^{-3}$; Precision@K 0.89-1.00). (3) Guided by $ρ_{\rm LN}$, a small LayerNorm-$ε$ adjustment targeting $ρ_\star$ gives consistent stabilization (mean tail loss reduced by $\approx 0.010$ at $ρ_\star=0.6$, cap $=10^{-2}$) with negligible overhead. Overall, our theory supplies actionable, unitless diagnostics that (i) explain when self-attention is fragile, (ii) forecast instability, and (iii) motivate a minimally invasive mitigation.
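To make the shape of the combined predictor concrete, here is a minimal NumPy sketch. The abstract does not give exact definitions, so several choices below are assumptions. $κ(V)$ is taken as the 2-norm condition number of the value matrix, and $\|W_O\|_2$ is taken as the spectral norm of the output projection. `kappa_softmax_proxy` is a hypothetical stand-in for $κ_{\rm softmax}$: the largest spectral norm of the rowwise softmax Jacobian $\mathrm{diag}(p)-pp^\top$. The remaining terms $κ_{\rm score}$, $κ_{\rm eff}$ and $C_{\rm LN}$ are passed in as precomputed scalars. All function names are illustrative, not the authors' code.

```python
import numpy as np

# Machine epsilon (gap between 1.0 and the next representable float) per format.
# bfloat16 is not a NumPy dtype, so its epsilon 2**-7 is hard-coded.
EPS_MACH = {
    "fp32": float(np.finfo(np.float32).eps),  # 2**-23
    "fp16": float(np.finfo(np.float16).eps),  # 2**-10
    "bf16": 2.0 ** -7,
}


def softmax_rows(scores: np.ndarray) -> np.ndarray:
    """Row-wise softmax with the usual max-shift for numerical stability."""
    z = scores - scores.max(axis=-1, keepdims=True)
    p = np.exp(z)
    return p / p.sum(axis=-1, keepdims=True)


def kappa_softmax_proxy(scores: np.ndarray) -> float:
    """HYPOTHETICAL proxy for kappa_softmax: the largest spectral norm of the
    row-wise softmax Jacobian J = diag(p) - p p^T. By Gershgorin, it is <= 1/2."""
    worst = 0.0
    for p in softmax_rows(scores):
        jac = np.diag(p) - np.outer(p, p)
        worst = max(worst, float(np.linalg.norm(jac, 2)))
    return worst


def cond2(mat: np.ndarray) -> float:
    """2-norm condition number sigma_max / sigma_min, used here for kappa(V)."""
    s = np.linalg.svd(mat, compute_uv=False)  # singular values, descending
    return float(s[0] / s[-1])


def combined_predictor(scores, v, w_o, kappa_score, kappa_eff, c_ln) -> float:
    """kappa_softmax * (1 + kappa_score) * kappa(V) * ||W_O||_2 + kappa_eff + C_LN.
    kappa_score, kappa_eff and c_ln are assumed to be precomputed scalars."""
    return (kappa_softmax_proxy(scores) * (1.0 + kappa_score) * cond2(v)
            * float(np.linalg.norm(w_o, 2)) + kappa_eff + c_ln)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    scores = rng.standard_normal((16, 16))     # one head's attention logits
    v = rng.standard_normal((16, 32))          # values: tokens x head dim
    w_o = 0.1 * rng.standard_normal((32, 32))  # output projection
    p_val = combined_predictor(scores, v, w_o, kappa_score=0.5, kappa_eff=0.1, c_ln=0.2)
    for fmt, eps in EPS_MACH.items():
        # One reading of the eps_mach scaling: first-order error ~ eps_mach * P.
        print(f"{fmt}: P = {p_val:.3f}, eps_mach * P = {eps * p_val:.3e}")
```

Multiplying the same unitless predictor by each format's $ε_{\rm mach}$ is one way to put FP32, FP16 and BF16 runs on a common scale. The paper's exact normalization may differ.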
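The early-warning claim is a lead-lag statement: a warning series rises some number of steps before the error series spikes. One generic way to estimate such a lead, and to attach a permutation p-value to it, is sketched below. Here the warning series is assumed to be the per-step maximum of $κ_{\rm softmax}$, and the error series the FP32-vs-LP mismatch. The function names and the plain-shuffle permutation scheme are illustrative assumptions, not the paper's protocol.

```python
import numpy as np


def lagged_corr(warning: np.ndarray, error: np.ndarray, lag: int) -> float:
    """Pearson correlation between warning[t] and error[t + lag]."""
    if lag > 0:
        a, b = warning[:-lag], error[lag:]
    else:
        a, b = warning, error
    return float(np.corrcoef(a, b)[0, 1])


def best_lead(warning: np.ndarray, error: np.ndarray, max_lag: int):
    """Return (lag, corr) maximizing the lagged correlation over lags 0..max_lag."""
    scores = [(lag, lagged_corr(warning, error, lag)) for lag in range(max_lag + 1)]
    return max(scores, key=lambda item: item[1])


def permutation_pvalue(warning, error, lag, n_perm=999, seed=0) -> float:
    """One-sided permutation p-value for the correlation at a fixed lag.
    Shuffling the warning series destroys its temporal alignment with the error."""
    rng = np.random.default_rng(seed)
    observed = lagged_corr(warning, error, lag)
    hits = sum(
        lagged_corr(rng.permutation(warning), error, lag) >= observed
        for _ in range(n_perm)
    )
    return (hits + 1) / (n_perm + 1)


if __name__ == "__main__":
    # Synthetic data only: the error series copies the warning 20 steps later.
    rng = np.random.default_rng(1)
    warning = rng.standard_normal(400)
    error = np.roll(warning, 20) + 0.1 * rng.standard_normal(400)
    lag, r = best_lead(warning, error, max_lag=40)
    print(lag, round(r, 3), permutation_pvalue(warning, error, lag))
```

With `n_perm=999`, the smallest attainable p-value is $10^{-3}$. Real training curves are autocorrelated, so a block or circular-shift permutation would be a more conservative null than the plain shuffle used here.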
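The mitigation targets $ρ_\star$ subject to a cap, but the abstract does not give the formula for $ρ_{\rm LN}$. The sketch below therefore treats it as a user-supplied function of the LayerNorm $ε$ and assumes that function is non-increasing in $ε$. That assumption is plausible because a larger $ε$ damps the $1/\sqrt{\sigma^2+ε}$ factor. Under it, the smallest $ε$ that reaches the target can be found by bisection in log space between the current value and the cap. `tune_ln_eps` and the toy indicator are hypothetical.

```python
import math
from typing import Callable


def tune_ln_eps(rho: Callable[[float], float], eps_current: float,
                rho_star: float = 0.6, cap: float = 1e-2, iters: int = 60) -> float:
    """Smallest eps in [eps_current, cap] with rho(eps) <= rho_star.
    ASSUMES rho is non-increasing in eps; returns cap if the target is unreachable."""
    if rho(eps_current) <= rho_star:
        return eps_current      # already at or below target: leave LayerNorm alone
    if rho(cap) > rho_star:
        return cap              # target unreachable without exceeding the cap
    lo, hi = math.log(eps_current), math.log(cap)
    for _ in range(iters):      # invariant: rho(exp(lo)) > rho_star >= rho(exp(hi))
        mid = 0.5 * (lo + hi)
        if rho(math.exp(mid)) <= rho_star:
            hi = mid
        else:
            lo = mid
    return math.exp(hi)


if __name__ == "__main__":
    # Toy, HYPOTHETICAL indicator that decreases as eps grows.
    def toy_rho(eps: float) -> float:
        return 1e-4 / (1e-4 + eps)

    # 1e-5 is the default eps of PyTorch's nn.LayerNorm.
    print(tune_ln_eps(toy_rho, eps_current=1e-5))
```

Searching in log space keeps the step count small across the several orders of magnitude between typical LayerNorm $ε$ values and the cap. Returning the current $ε$ when the target is already met keeps the change minimally invasive.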