Scaling Recurrent Neural Networks to a Billion Parameters with Zero-Order Optimization
This work addresses a critical bottleneck for AI/ML researchers and practitioners who need to train large RNNs efficiently on long-context tasks such as language modeling. The contribution is incremental, however, in that it builds on existing optimization techniques.
The paper tackles the problem of training large Recurrent Neural Networks (RNNs) on long contexts, which is impractical with standard Backpropagation Through Time (BPTT) because of its high memory usage. It replaces BPTT with Zero-Order Optimization (ZOO) methods such as Random-vector Gradient Estimation (RGE), achieving convergence rates that match those of BPTT or exceed them by up to 19-fold, while using orders of magnitude less memory and cost.
During inference, the FLOPs and GPU memory required by Recurrent Neural Networks (RNNs) stay constant as context length increases, because they compress all prior tokens into a fixed-size memory. In contrast, transformers scale linearly in FLOPs and, at best, linearly in memory during generation, since they must attend to all previous tokens explicitly. Despite this inference-time advantage, training large RNNs on long contexts remains impractical because standard optimization methods depend on Backpropagation Through Time (BPTT). BPTT requires retaining all intermediate activations from the forward pass, so memory usage scales linearly with both context length and model size. In this paper, we show that Zero-Order Optimization (ZOO) methods such as Random-vector Gradient Estimation (RGE) can successfully replace BPTT to train RNNs, with convergence rates that match those of BPTT or exceed them by up to 19-fold, while using orders of magnitude less memory and cost, because the model remains in inference mode throughout training. We further demonstrate that Central-Difference RGE (CD-RGE) corresponds to optimizing a smoothed surrogate loss, which inherently regularizes training and improves generalization. Our method matches or outperforms BPTT across three settings: (1) overfitting, (2) transduction, and (3) language modeling. Across all tasks, given sufficient perturbations, our models generalize as well as or better than those trained with BPTT, often in fewer steps. Although our method needs more forward passes per step, it can outperform BPTT in wall-clock time per step by using recent advancements such as FlashRNN and distributed inference.
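The central-difference estimator described in the abstract can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the paper's implementation. It assumes Gaussian perturbation vectors v_i, the standard estimator g = (1/n) * sum_i [(L(theta + eps*v_i) - L(theta - eps*v_i)) / (2*eps)] * v_i, a toy Elman RNN, and plain gradient descent. The names rnn_loss, cd_rge_grad, and train, and all sizes and hyperparameters, are hypothetical. Under the Gaussian assumption, the expectation of g is the gradient of the smoothed loss E_v[L(theta + eps*v)], which is one way to read the smoothed-surrogate claim.

```python
import math
import random

HIDDEN = 3  # hidden-state size of the toy RNN (illustrative assumption)
D_IN = 2    # input size of the toy RNN (illustrative assumption)
NUM_PARAMS = HIDDEN * D_IN + HIDDEN * HIDDEN + HIDDEN


def rnn_loss(theta, xs, ys):
    """Mean squared error of a tiny Elman RNN, using a forward pass only.

    Only the current hidden state h is kept between time steps, so memory
    does not grow with sequence length. The flat parameter layout is an
    assumption made for this sketch.
    """
    n_wx = HIDDEN * D_IN
    n_wh = HIDDEN * HIDDEN
    w_x = theta[:n_wx]                # row-major HIDDEN x D_IN input weights
    w_h = theta[n_wx:n_wx + n_wh]     # row-major HIDDEN x HIDDEN recurrent weights
    w_o = theta[n_wx + n_wh:]         # readout vector of length HIDDEN
    h = [0.0] * HIDDEN
    total = 0.0
    for x, y in zip(xs, ys):
        h = [
            math.tanh(
                sum(w_x[i * D_IN + j] * x[j] for j in range(D_IN))
                + sum(w_h[i * HIDDEN + k] * h[k] for k in range(HIDDEN))
            )
            for i in range(HIDDEN)
        ]
        pred = sum(w_o[i] * h[i] for i in range(HIDDEN))
        total += (pred - y) ** 2
    return total / len(xs)


def cd_rge_grad(loss_fn, theta, num_perturbations, eps, rng):
    """Central-difference random-vector gradient estimate (Gaussian v assumed)."""
    grad = [0.0] * len(theta)
    for _ in range(num_perturbations):
        v = [rng.gauss(0.0, 1.0) for _ in theta]
        plus = loss_fn([t + eps * vi for t, vi in zip(theta, v)])
        minus = loss_fn([t - eps * vi for t, vi in zip(theta, v)])
        slope = (plus - minus) / (2.0 * eps)  # directional derivative along v
        for i, vi in enumerate(v):
            grad[i] += slope * vi / num_perturbations
    return grad


def train(theta, xs, ys, steps=200, lr=0.01, num_perturbations=16,
          eps=1e-3, seed=0):
    """Plain gradient descent driven by CD-RGE; no backward pass is used."""
    rng = random.Random(seed)
    loss_fn = lambda t: rnn_loss(t, xs, ys)
    for _ in range(steps):
        g = cd_rge_grad(loss_fn, theta, num_perturbations, eps, rng)
        theta = [t - lr * gi for t, gi in zip(theta, g)]
    return theta
```

Calling train(theta0, xs, ys) on a short random sequence, with theta0 a list of NUM_PARAMS small random floats, runs the whole optimization with forward evaluations only. Each step costs 2 * num_perturbations forward passes and stores no activations for a backward pass. That is the trade the abstract describes: more forward passes per step in exchange for inference-level memory.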