LG · ML · Nov 1, 2024

Normalization Layer Per-Example Gradients are Sufficient to Predict Gradient Noise Scale in Transformers

arXiv:2411.00999v1 · 5 citations · h-index: 18 · Has Code · NIPS
Originality: Incremental advance
AI Analysis

This work addresses the computational bottleneck of gradient noise scale estimation for training large transformer models, offering a practical speed-up.

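The paper tackles the problem of efficiently estimating gradient noise scale (GNS) in transformer models. It shows that per-example gradient norms from the normalization layers alone are sufficient to predict the total GNS, and it develops a custom kernel that computes those norms during the LayerNorm backward pass with zero throughput overhead. A batch size schedule guided by the resulting GNS measurements reduces training time by 18% on a Chinchilla-optimal language model.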

Per-example gradient norms are a vital ingredient for estimating gradient noise scale (GNS) with minimal variance. Observing the tensor contractions required to compute them, we propose a method with minimal FLOPs in 3D or greater tensor regimes by simultaneously computing the norms while computing the parameter gradients. Using this method we are able to observe the GNS of different layers at higher accuracy than previously possible. We find that the total GNS of contemporary transformer models is predicted well by the GNS of only the normalization layers. As a result, focusing only on the normalization layer, we develop a custom kernel to compute the per-example gradient norms while performing the LayerNorm backward pass with zero throughput overhead. Tracking GNS on only those layers, we are able to guide a practical batch size schedule that reduces training time by 18% on a Chinchilla-optimal language model.
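The idea in the abstract can be illustrated with a minimal pure-Python sketch. It is not the paper's kernel or code. It assumes a LayerNorm of the form y = gamma * x_hat + beta, and it uses the standard unbiased two-batch-size estimators from the gradient noise scale literature with the small batch size set to 1 (the per-example case). The function names `layernorm_per_example_grads` and `estimate_gns` are hypothetical and chosen only for this illustration.

```python
def layernorm_per_example_grads(x_hat, dy):
    """Per-example gradients of a LayerNorm's gamma and beta.

    x_hat, dy: nested lists shaped [batch][tokens][features], holding the
    normalized activations and the upstream gradient of each example's loss.
    Returns one flat vector per example: [d_gamma..., d_beta...].
    """
    grads = []
    for xb, db in zip(x_hat, dy):
        n_feat = len(xb[0])
        d_gamma = [0.0] * n_feat
        d_beta = [0.0] * n_feat
        for x_tok, d_tok in zip(xb, db):
            for k in range(n_feat):
                d_gamma[k] += d_tok[k] * x_tok[k]  # reduce over tokens only
                d_beta[k] += d_tok[k]
        grads.append(d_gamma + d_beta)
    return grads


def estimate_gns(per_example_grads):
    """Estimate |G|^2, the noise term S, and the simple noise scale S / |G|^2.

    Assumption: the batch gradient is the mean of the per-example gradients,
    so B_small = 1 and B_big = len(per_example_grads).
    """
    b = len(per_example_grads)
    dim = len(per_example_grads[0])
    mean_grad = [sum(g[k] for g in per_example_grads) / b for k in range(dim)]
    big_sq = sum(m * m for m in mean_grad)  # |G_B|^2
    small_sq = sum(sum(v * v for v in g) for g in per_example_grads) / b  # mean |G_1|^2
    g_sq = (b * big_sq - small_sq) / (b - 1)     # unbiased estimate of |G|^2
    s = (small_sq - big_sq) / (1.0 - 1.0 / b)    # unbiased estimate of tr(Sigma)
    return g_sq, s, s / g_sq


if __name__ == "__main__":
    # Toy input: 2 examples, 1 token each, 2 features.
    x_hat = [[[1.0, -1.0]], [[1.0, -1.0]]]
    dy = [[[3.0, 0.0]], [[1.0, 0.0]]]
    grads = layernorm_per_example_grads(x_hat, dy)
    print(estimate_gns(grads))
```

The sketch materializes the per-example gradients for clarity, whereas the abstract describes computing only their norms alongside the parameter gradients. The ratio in a single step is noisy (and undefined when the |G|^2 estimate is zero), so in practice the two estimates are typically smoothed over many steps before dividing.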

Code Implementations: 1 repo
