Training Multi-Layer Over-Parametrized Neural Network in Subquadratic Time
This work addresses the computational bottleneck in training large neural networks, particularly when fine-tuning large language models, although it is incremental in that it builds on existing over-parametrization theory.
The paper tackles the training of over-parametrized neural networks, where each iteration typically costs time quadratic in the network width $m$ because of the large $m \times m$ weight matrices. Using a framework that incurs quadratic cost only during initialization, it achieves a truly subquadratic cost of $m^{2-\Omega(1)}$ per iteration.
We consider the problem of training a multi-layer over-parametrized neural network to minimize the empirical risk induced by a loss function. In the typical setting of over-parametrization, the network width $m$ is much larger than the data dimension $d$ and the number of training samples $n$ ($m=\mathrm{poly}(n,d)$), which induces a prohibitively large weight matrix $W\in \mathbb{R}^{m\times m}$ per layer. Naively, one has to pay $O(m^2)$ time to read the weight matrix and evaluate the neural network function in both the forward and backward passes. In this work, we show how to reduce the training cost per iteration. Specifically, we propose a framework that pays $m^2$ cost only in the initialization phase and achieves \emph{a truly subquadratic cost per iteration} in terms of $m$, i.e., $m^{2-\Omega(1)}$ per iteration. Our result has implications beyond standard over-parametrization theory, as it can be viewed as designing an efficient data structure on top of a pre-trained large model to further speed up the fine-tuning process, a core procedure for deploying large language models (LLMs).
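To make the cost accounting concrete, here is a minimal Python sketch of one generic way to pay $O(m^2)$ only once and keep later forward passes cheaper. It rests on two assumptions that the abstract does not state: that the input activation vector is $k$-sparse, and that the accumulated weight change is stored as $r$ rank-one terms. The second assumption is motivated by a standard fact: for a single example, the gradient of a fully connected layer $y = Wh$ with respect to $W$ is the outer product of the back-propagated error and the input $h$, so one SGD step is a rank-one update. The class `LowRankSparseLayer` and its methods are hypothetical names used for illustration. This is a sketch of the general idea, not the paper's algorithm.

```python
from typing import Dict, List, Tuple


class LowRankSparseLayer:
    """Hypothetical sketch: W = W0 + sum_j u_j v_j^T.

    W0 is read once at construction (the O(m^2) initialization cost).
    Each later update is stored as two length-m vectors, and a forward
    pass on a k-sparse input costs O(m*k + m*r), where r is the number
    of stored rank-one updates, instead of O(m^2).
    """

    def __init__(self, w0: List[List[float]]):
        self.m = len(w0)
        # One-time O(m^2) pass: store W0 column-wise so that a sparse
        # input only has to touch the columns of its nonzero entries.
        self.cols = [[w0[i][j] for i in range(self.m)] for j in range(self.m)]
        self.updates: List[Tuple[List[float], List[float]]] = []

    def rank_one_update(self, u: List[float], v: List[float]) -> None:
        # Record W <- W + u v^T in O(m) time without touching W0.
        # For an SGD step, u = -lr * error and v = layer input.
        self.updates.append((list(u), list(v)))

    def forward(self, x_sparse: Dict[int, float]) -> List[float]:
        # x_sparse maps index j -> x_j for the nonzero entries of x.
        y = [0.0] * self.m
        # W0 x: read only the k columns with x_j != 0, O(m * k).
        for j, xj in x_sparse.items():
            col = self.cols[j]
            for i in range(self.m):
                y[i] += col[i] * xj
        # (u v^T) x = u * <v, x>: O(k) for the dot product, O(m) to add.
        for u, v in self.updates:
            dot = sum(v[j] * xj for j, xj in x_sparse.items())
            if dot != 0.0:
                for i in range(self.m):
                    y[i] += u[i] * dot
        return y


if __name__ == "__main__":
    layer = LowRankSparseLayer([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
    x = {0: 1.0, 2: 2.0}  # dense x = [1, 0, 2]
    print(layer.forward(x))  # W0 x
    layer.rank_one_update([1.0, 0.0, 0.0], [0.0, 0.0, 1.0])  # W[0][2] += 1
    print(layer.forward(x))
```

In this sketch the per-call cost is $O(mk + mr)$, which is subquadratic in $m$ whenever both $k$ and $r$ are $m^{1-\Omega(1)}$. Because $r$ grows by one with every update, a practical scheme would need to periodically fold the stored updates back into $W_0$ or otherwise bound $r$; the sketch leaves that out.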