First-Passage Prediction of Grokking Delay: A Calibrated Law under AdamW with Causal Validation
Provides the first quantitative prediction of grokking delay for mechanistic interpretability researchers, though empirical scope is limited to small algorithmic tasks and transfer to large-scale models is unverified.
The authors derive a closed-form law predicting grokking delay under AdamW, achieving 17.7% MAPE on held-out runs over a 41x delay range. They also establish a quantile-margin theorem linking delay to norm separation and angular reachability, with causal interventions confirming the mechanism.
We give the first quantitative prediction of grokking delay under AdamW. Treating the delay as a first-passage time, we derive a closed-form law T_grok - T_mem = (1 / (2 kappa_LL eta lambda)) log(V_mem / V_star), where V_t = ||theta_t||^2 is the squared parameter norm, V_star is an architecture-dependent threshold, and kappa_LL absorbs the AdamW correction to the clean-SGD contraction rate 2 eta lambda. Calibrating (kappa_LL, V_star) on a single hyperparameter cell predicts grokking delays on 26 held-out runs with MAPE 17.7% over a 41x delay range; the law generalises to MLPs (MAPE 18.0%, N=34) and degrades to 23.3% on cross-task extension (N=46, 43.5x range), with a structured residual in which V_star / V_mem stays comparatively stable within architecture (CV about 14% on the 1L transformer). First passage of V_t is necessary but not sufficient. A quantile-margin theorem establishes that positive delay requires both norm separation V_mem > V_post and angular reachability of a threshold alpha_star = arcsin(C / V_T_mem^(1/2)), where C is computable from the empirical NTK feature map and the validation-margin quantile. Calibrating C on modulus p=89 predicts alpha_star = 47.2 degrees at p=97 (observed 47.8 degrees, error 1.3%) as an a priori cross-cell prediction. Causal interventions that freeze the norm or remove weight decay at memorisation eliminate grokking (0/6 vs. 3/3 baseline), trapping the angular displacement near 12 degrees. kappa_LL is empirically measured per architecture rather than derived from (beta_1, beta_2, epsilon); within-architecture CV stays at most 15% across four architectures, but values differ by about 2x between architectural variants beyond depth alone. Empirical scope is algorithmic tasks (modular arithmetic, sparse parity) under AdamW; whether the law transfers to natural-language scale models is open.
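As a reading aid, the closed-form law and the angular threshold can be evaluated directly. The sketch below is a minimal illustration, not the authors' code: the function names are hypothetical, the numerical inputs are made-up placeholders rather than values from the paper, and the per-step recursion V <- (1 - 2 kappa_LL eta lambda) V is an assumed discretisation of the contraction that the law describes. V_T_mem is read here as the squared norm at memorisation.

```python
import math


def predicted_grokking_delay(v_mem, v_star, kappa_ll, eta, lam):
    """Evaluate T_grok - T_mem = log(V_mem / V_star) / (2 kappa_LL eta lambda)."""
    if not (v_mem > v_star > 0):
        raise ValueError("the law needs V_mem > V_star > 0 for a positive delay")
    rate = 2.0 * kappa_ll * eta * lam  # effective contraction rate of V_t
    return math.log(v_mem / v_star) / rate


def simulated_first_passage(v_mem, v_star, kappa_ll, eta, lam, max_steps=10_000_000):
    """Count steps until V_t first drops to V_star.

    Assumption (not from the source): V_t shrinks geometrically,
    V <- (1 - 2 kappa_LL eta lambda) * V, once memorisation is reached.
    """
    factor = 1.0 - 2.0 * kappa_ll * eta * lam
    v, steps = v_mem, 0
    while v > v_star and steps < max_steps:
        v *= factor
        steps += 1
    return steps


def angular_threshold_deg(c, v_at_t_mem):
    """Evaluate alpha_star = arcsin(C / sqrt(V_T_mem)), returned in degrees."""
    return math.degrees(math.asin(c / math.sqrt(v_at_t_mem)))


if __name__ == "__main__":
    # Placeholder inputs chosen only for illustration.
    args = dict(v_mem=400.0, v_star=100.0, kappa_ll=0.5, eta=1e-3, lam=1.0)
    print("closed form:", predicted_grokking_delay(**args))
    print("simulated  :", simulated_first_passage(**args))
    print("alpha_star :", angular_threshold_deg(c=1.0, v_at_t_mem=4.0))
```

For small eta lambda the simulated step count and the closed form should agree closely, since the logarithm of the per-step factor is approximately -2 kappa_LL eta lambda. The sketch covers only the norm first-passage and the threshold formula; it does not check the angular reachability condition that the abstract states is also required.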