Fast Neural Tangent Kernel Alignment, Norm and Effective Rank via Trace Estimation
This provides a faster analysis tool for researchers and practitioners working with the NTK in machine learning, though the contribution is incremental, since it builds on existing trace estimation techniques.
The paper addresses the often prohibitive cost of forming the full Neural Tangent Kernel (NTK) matrix, especially for recurrent architectures. It introduces a matrix-free method based on trace estimation that rapidly computes the NTK's trace, Frobenius norm, effective rank, and alignment, achieving speedups of many orders of magnitude.
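For reference, the four quantities can be written in terms of the Jacobian of the model state with respect to the parameters. The notation here ($f$, $J$, $m$, $p$, $y$) is ours, and the effective-rank and alignment formulas are common conventions assumed for illustration; the paper's exact definitions may differ.

```latex
\begin{aligned}
J &= \frac{\partial f(x;\theta)}{\partial \theta} \in \mathbb{R}^{m \times p},
\qquad K = J J^{\top} \in \mathbb{R}^{m \times m}, \\
\operatorname{tr}(K) &= \|J\|_F^2,
\qquad \|K\|_F = \sqrt{\operatorname{tr}(K^2)}, \\
r_{\mathrm{eff}}(K) &= \frac{\operatorname{tr}(K)^2}{\|K\|_F^2}
\quad \text{(assumed definition)}, \\
A(K, y) &= \frac{\langle K, y y^{\top} \rangle_F}{\|K\|_F \, \|y y^{\top}\|_F}
= \frac{y^{\top} K y}{\|K\|_F \, \|y\|_2^2}
\quad \text{(assumed kernel-target alignment)}.
\end{aligned}
```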
The Neural Tangent Kernel (NTK) characterizes how a model's state evolves under gradient descent. Computing the full NTK matrix is often infeasible, especially for recurrent architectures. Here, we introduce a matrix-free perspective, using trace estimation to rapidly analyze the empirical, finite-width NTK. This enables fast computation of the NTK's trace, Frobenius norm, effective rank, and alignment. We provide numerical recipes based on the Hutch++ trace estimator, which has provably fast convergence guarantees. In addition, we show that, due to the structure of the NTK, the trace can be computed using only forward-mode or only reverse-mode automatic differentiation, rather than both. We show that these so-called one-sided estimators can outperform Hutch++ in the low-sample regime, especially when the gap between the model-state dimension and the parameter count is large. Overall, our results demonstrate that matrix-free randomized approaches can yield speedups of many orders of magnitude, enabling faster analysis and applications of the NTK.
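A minimal NumPy sketch of the estimators described above, on a toy problem and under several assumptions. The Jacobian is a random dense matrix stored explicitly only so the results can be checked; `jvp` and `vjp` are hypothetical stand-ins for forward- and reverse-mode AD (in a real model they would come from an AD framework). Probes are Rademacher vectors, and effective rank and alignment use the assumed definitions tr(K)^2/||K||_F^2 and y^T K y/(||K||_F ||y||^2). The one-sided estimators rely on tr(K) = tr(J J^T) = tr(J^T J), so either E||J^T v||^2 (vjp only) or E||J u||^2 (jvp only) is unbiased; Hutch++ follows the standard sketch-then-Hutchinson scheme and needs full NTK products K v = J (J^T v).

```python
import numpy as np

# Toy stand-in for a network (assumption): the Jacobian J (m outputs x p params)
# is stored explicitly only so results can be checked. The estimators touch it
# solely through the hypothetical jvp (forward mode) and vjp (reverse mode).
m, p = 20, 500  # model-state dimension m, parameter count p
J = np.random.default_rng(0).standard_normal((m, p)) / np.sqrt(p)


def jvp(U):
    """Hypothetical forward-mode products J @ U; columns of U are in parameter space."""
    return J @ U


def vjp(V):
    """Hypothetical reverse-mode products J.T @ V; columns of V are in output space."""
    return J.T @ V


def ntk_matvec(V):
    """NTK products K @ V = J (J^T V); each column costs one vjp and one jvp."""
    return jvp(vjp(V))


def rademacher(rows, cols, rng):
    """Random +/-1 probe matrix, so that E[v v^T] = I."""
    return rng.choice([-1.0, 1.0], size=(rows, cols))


def trace_reverse(num_samples, rng):
    """One-sided estimate of tr(K) = E||J^T v||^2, v in R^m: vjp only."""
    V = rademacher(m, num_samples, rng)
    return np.sum(vjp(V) ** 2) / num_samples


def trace_forward(num_samples, rng):
    """One-sided estimate of tr(K) = tr(J^T J) = E||J u||^2, u in R^p: jvp only."""
    U = rademacher(p, num_samples, rng)
    return np.sum(jvp(U) ** 2) / num_samples


def hutchpp(matvec, n, num_queries, rng):
    """Hutch++ trace estimate of an n x n PSD operator using 3 * (num_queries // 3) matvecs."""
    k = max(num_queries // 3, 1)
    S = rademacher(n, k, rng)
    G = rademacher(n, k, rng)
    Q, _ = np.linalg.qr(matvec(S))  # orthonormal basis for the sketched top subspace
    G_perp = G - Q @ (Q.T @ G)      # project the Hutchinson probes off that subspace
    # Exact trace on span(Q) plus a Hutchinson estimate on the remainder.
    return np.trace(Q.T @ matvec(Q)) + np.trace(G_perp.T @ matvec(G_perp)) / k


def frobenius_sq(num_samples, rng):
    """Hutchinson estimate of ||K||_F^2 = tr(K^2) = E||K v||^2 (needs both modes)."""
    V = rademacher(m, num_samples, rng)
    return np.sum(ntk_matvec(V) ** 2) / num_samples


def effective_rank(trace_k, frob_sq_k):
    """Assumed definition: tr(K)^2 / ||K||_F^2."""
    return trace_k ** 2 / frob_sq_k


def alignment(y, frob_sq_k):
    """Assumed alignment y^T K y / (||K||_F ||y||^2), using y^T K y = ||J^T y||^2."""
    return np.sum(vjp(y) ** 2) / (np.sqrt(frob_sq_k) * (y @ y))


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    K = J @ J.T  # dense NTK, formed here only for comparison
    print("exact tr(K):        ", np.trace(K))
    print("one-sided (reverse):", trace_reverse(100, rng))
    print("one-sided (forward):", trace_forward(100, rng))
    print("Hutch++ (30 queries):", hutchpp(ntk_matvec, m, 30, rng))
    fro_sq = frobenius_sq(100, rng)
    print("effective rank:     ", effective_rank(trace_reverse(100, rng), fro_sq))
    print("alignment (y = 1):  ", alignment(np.ones(m), fro_sq))
```

Each NTK product costs one `vjp` plus one `jvp`, so Hutch++ pays for both AD modes, whereas each one-sided estimator needs only one. This toy comparison illustrates the mechanics only and is not meant to reproduce the paper's speed or accuracy results.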