Weighted Neural Tangent Kernel: A Generalized and Improved Network-Induced Kernel
For machine learning researchers, this work addresses a foundational question in understanding neural network training, though its contribution is incremental, since it builds directly on the NTK framework.
The paper starts from the observation that Neural Tangent Kernel (NTK) regression often performs poorly in practice, which the authors partially attribute to the limitations of gradient descent as an optimizer. It introduces the Weighted Neural Tangent Kernel (WNTK) to capture training dynamics under different optimizers. In the infinite-width limit, the authors prove that the WNTK is stable and that WNTK regression is equivalent to the corresponding neural network estimator; in numerical experiments, WNTKs outperform the corresponding NTKs.
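As a minimal sketch of how a weighted kernel can arise from "different optimizers", assume gradient flow on the squared loss in which parameter θ_p has its own learning rate η·w_p, collected in W = diag(w_1, …, w_P). This per-parameter parametrization is our illustrative assumption and may differ from the paper's exact definition. The chain rule then gives the function-space dynamics:

```latex
% Assumed setup: network f(x;\theta), \theta \in \mathbb{R}^P, W = \mathrm{diag}(w_1,\dots,w_P), w_p > 0
\frac{d\theta_t}{dt} = -\eta\, W \nabla_\theta \mathcal{L}(\theta_t),
\qquad
\mathcal{L}(\theta) = \tfrac{1}{2}\sum_{i=1}^{n} \bigl(f(x_i;\theta) - y_i\bigr)^2

\frac{d f(x;\theta_t)}{dt}
  = \nabla_\theta f(x;\theta_t)^\top \frac{d\theta_t}{dt}
  = -\eta \sum_{i=1}^{n} \Theta_w(x, x_i)\,\bigl(f(x_i;\theta_t) - y_i\bigr)

\Theta_w(x, x') = \nabla_\theta f(x;\theta)^\top W\, \nabla_\theta f(x';\theta)
  = \sum_{p=1}^{P} w_p\, \frac{\partial f(x;\theta)}{\partial \theta_p}\,\frac{\partial f(x';\theta)}{\partial \theta_p}
```

Setting W = I recovers the standard NTK, so plain gradient descent with a single learning rate is the special case w_p = 1 for all p.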
The Neural Tangent Kernel (NTK) has recently attracted intense study, as it describes the evolution of an over-parameterized Neural Network (NN) trained by gradient descent. However, it is now well-known that gradient descent is not always a good optimizer for NNs, which can partially explain the unsatisfactory practical performance of the NTK regression estimator. In this paper, we introduce the Weighted Neural Tangent Kernel (WNTK), a generalized and improved tool, which can capture an over-parameterized NN's training dynamics under different optimizers. Theoretically, in the infinite-width limit, we prove: i) the stability of the WNTK at initialization and during training, and ii) the equivalence between the WNTK regression estimator and the corresponding NN estimator with different learning rates on different parameters. With the proposed weight update algorithm, both empirical and analytical WNTKs outperform the corresponding NTKs in numerical experiments.
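To illustrate claim ii) in the simplest setting, the sketch below replaces the infinite-width network with a model that is linear in its parameters (fixed random ReLU features), mimicking the linearized dynamics that NTK-style analyses rely on. This is our simplification, not the paper's construction: the feature map, the random weights `w`, and the helper `wntk` are hypothetical, and the paper's weight update algorithm is not implemented. The model is trained by gradient descent with learning rate `eta * w[p]` on parameter `p`, and its test predictions are compared with kernel regression using the weighted kernel.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: f(x; theta) = phi(x) @ theta is linear in theta, standing in
# for a linearized (infinite-width) network. Its Jacobian w.r.t. theta is phi(x).
n_train, n_test, d, P = 8, 5, 20, 200
A = rng.normal(size=(d, P)) / np.sqrt(d)

def features(X):
    # Fixed random ReLU features, scaled by 1/sqrt(P) as in NTK parametrizations.
    return np.maximum(X @ A, 0.0) / np.sqrt(P)

X_train = rng.normal(size=(n_train, d))
y_train = np.sin(X_train).sum(axis=1)
X_test = rng.normal(size=(n_test, d))
J_train, J_test = features(X_train), features(X_test)

# Hypothetical per-parameter weights w_p, i.e. relative learning rates.
w = rng.uniform(0.5, 2.0, size=P)
theta0 = rng.normal(size=P)

def wntk(J1, J2, w):
    # Weighted kernel: K[i, j] = sum_p w_p * J1[i, p] * J2[j, p]  (= J1 W J2^T).
    return (J1 * w) @ J2.T

# 1) Kernel regression with the weighted kernel, centered at the initial function f0.
K_tt = wntk(J_train, J_train, w)
K_st = wntk(J_test, J_train, w)
f0_train, f0_test = J_train @ theta0, J_test @ theta0
pred_kernel = f0_test + K_st @ np.linalg.solve(K_tt, y_train - f0_train)

# 2) Gradient descent on 0.5 * ||J theta - y||^2 with learning rate eta * w_p per parameter.
theta = theta0.copy()
eta = 1.0 / np.linalg.eigvalsh(K_tt).max()  # residual update is (I - eta K), so this is stable
for _ in range(10000):
    residual = J_train @ theta - y_train
    theta -= eta * w * (J_train.T @ residual)
pred_gd = J_test @ theta

print("max |GD - WNTK regression| on test points:", np.max(np.abs(pred_gd - pred_kernel)))
```

Because every update lies in the span of `W J^T`, gradient descent converges to the interpolant closest to `theta0` in the `W^{-1}`-weighted norm, which is exactly the weighted-kernel predictor. With `w` set to all ones, the kernel reduces to the ordinary empirical NTK `J J^T`.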