A Fast, Well-Founded Approximation to the Empirical Neural Tangent Kernel
The paper provides a theoretically justified approximation for eNTKs, addressing a computational bottleneck in understanding neural network representations. The contribution is incremental, since it builds on approximations already in use.
The paper addresses the computational cost of empirical neural tangent kernels (eNTKs) for networks with multiple outputs. It proves that the "sum of logits" approximation converges to the true eNTK at initialization for networks with a wide final readout layer, and its experiments demonstrate the quality of the approximation across a range of settings.
Empirical neural tangent kernels (eNTKs) can provide a good understanding of a given network's representation: they are often far less expensive to compute and applicable more broadly than infinite-width NTKs. For networks with $O$ output units (e.g. an $O$-class classifier), however, the eNTK on $N$ inputs is of size $NO \times NO$, taking $O((NO)^2)$ memory and up to $O((NO)^3)$ computation. Most existing applications have therefore used one of a handful of approximations yielding $N \times N$ kernel matrices, saving orders of magnitude of computation, but with limited to no justification. We prove that one such approximation, which we call "sum of logits", converges to the true eNTK at initialization for any network with a wide final "readout" layer. Our experiments demonstrate the quality of this approximation for various uses across a range of settings.
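To make the size reduction concrete, the following is a minimal NumPy sketch rather than the authors' implementation. It builds the full $NO \times NO$ eNTK of a toy two-layer network from an analytically computed Jacobian, and compares it with an $N \times N$ kernel built from the gradient of the summed outputs. Several details are assumptions not stated in the text above: the toy architecture, the $1/\sqrt{\text{width}}$ scaling, the $1/\sqrt{O}$ normalization of the summed logits, and the comparison against the Kronecker product of the small kernel with the $O \times O$ identity. All function names are hypothetical.

```python
import numpy as np

def jacobian(x, W1, W2):
    """Jacobian of f(x) = W2 @ tanh(W1 @ x / sqrt(d)) / sqrt(m) w.r.t. (W1, W2).

    Returns an array of shape (O, P), with P the total number of parameters.
    The toy network and its scaling are assumptions made for illustration.
    """
    m, d = W1.shape
    num_outputs = W2.shape[0]
    h = np.tanh(W1 @ x / np.sqrt(d))  # hidden activations, shape (m,)
    # d f_i / d W1[k, l] = W2[i, k] * (1 - h_k^2) * x_l / sqrt(m * d)
    dW1 = (W2 * (1.0 - h**2))[:, :, None] * x[None, None, :] / np.sqrt(m * d)
    # d f_i / d W2[j, k] = delta_ij * h_k / sqrt(m)
    dW2 = np.eye(num_outputs)[:, :, None] * h[None, None, :] / np.sqrt(m)
    return np.concatenate(
        [dW1.reshape(num_outputs, -1), dW2.reshape(num_outputs, -1)], axis=1
    )

def full_entk(X, W1, W2):
    """Full eNTK on the rows of X, shape (N*O, N*O); row index is n * O + i."""
    J = np.stack([jacobian(x, W1, W2) for x in X])  # (N, O, P)
    n, num_outputs, p = J.shape
    J = J.reshape(n * num_outputs, p)
    return J @ J.T

def sum_of_logits_kernel(X, W1, W2):
    """N x N kernel from the gradient of (1 / sqrt(O)) * sum_i f_i(x).

    The 1 / sqrt(O) normalization is an assumption of this sketch.
    """
    J = np.stack([jacobian(x, W1, W2) for x in X])  # (N, O, P)
    num_outputs = J.shape[1]
    G = J.sum(axis=1) / np.sqrt(num_outputs)        # (N, P)
    return G @ G.T

def relative_error(width, n=6, d=10, num_outputs=5, seed=0):
    """Relative Frobenius error between the eNTK and kron(small kernel, I_O)."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, d))
    W1 = rng.normal(size=(width, d))            # standard normal initialization
    W2 = rng.normal(size=(num_outputs, width))  # the "readout" layer
    theta = full_entk(X, W1, W2)
    approx = np.kron(sum_of_logits_kernel(X, W1, W2), np.eye(num_outputs))
    return np.linalg.norm(theta - approx) / np.linalg.norm(theta)

if __name__ == "__main__":
    for width in (16, 256, 2048):
        print(width, relative_error(width))
```

In this toy setting the printed relative error should shrink as the width grows, which is the qualitative behaviour the convergence claim describes. The sketch still forms the full Jacobian for comparison purposes; in practice the small kernel needs only one gradient of the summed output per input.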