LG · SY · ML · Jul 28, 2022

ORFit: One-Pass Learning via Bridging Orthogonal Gradient Descent and Recursive Least-Squares

MIT
arXiv:2207.13853v3 · 24 citations · h-index: 17
Originality: Highly original
AI Analysis

This work addresses the challenge of training large models on streaming data under computational, memory, and privacy constraints, offering a practical solution with theoretical guarantees.

The paper tackles the problem of one-pass learning for streaming data by proposing ORFit, an algorithm that updates parameters orthogonally to past gradients, achieving linear memory and computational complexity in the number of parameters and matching the convergence of multi-pass SGD for overparameterized linear models.

While large machine learning models have shown remarkable performance in various domains, their training typically requires iterating for many passes over the training data. However, due to computational and memory constraints and potential privacy concerns, storing and accessing all the data is impractical in many real-world scenarios where the data arrives in a stream. In this paper, we investigate the problem of one-pass learning, in which a model is trained on sequentially arriving data without retraining on previous datapoints. Motivated by the demonstrated effectiveness of overparameterized models and the phenomenon of benign overfitting, we propose Orthogonal Recursive Fitting (ORFit), an algorithm for one-pass learning which seeks to perfectly fit each new datapoint while minimally altering the predictions on previous datapoints. ORFit updates the parameters in a direction orthogonal to past gradients, similar to orthogonal gradient descent (OGD) in continual learning. We show that, interestingly, ORFit's update leads to an operation similar to the recursive least-squares (RLS) algorithm in adaptive filtering but with significantly improved memory and computational efficiency, i.e., linear, instead of quadratic, in the number of parameters. To further reduce memory usage, we leverage the structure of the streaming data via an incremental principal component analysis (IPCA). We show that using the principal components is minimax optimal, i.e., it minimizes the worst-case forgetting of previous predictions for unknown future updates. Further, we prove that, for overparameterized linear models, the parameter vector obtained by ORFit matches what the standard multi-pass stochastic gradient descent (SGD) would converge to. Finally, we extend our results to the nonlinear setting for highly overparameterized models, relevant for deep learning.
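The core update described above can be illustrated for the linear case. The following is a minimal NumPy sketch, not the authors' implementation: it assumes a linear model `y ≈ x @ w`, for which the gradient of the prediction with respect to the parameters is the input `x` itself. The function name `orfit_linear`, the tolerance `tol`, and the choice to store every past direction (rather than a fixed number of principal components maintained by IPCA) are assumptions made for illustration.

```python
import numpy as np

def orfit_linear(X, y, w0=None, tol=1e-10):
    """One-pass fit of a linear model, visiting each datapoint exactly once.

    Illustrative sketch only: keeps an orthonormal basis U of all past
    gradient directions instead of the IPCA-compressed basis.
    """
    n, d = X.shape
    w = np.zeros(d) if w0 is None else np.asarray(w0, dtype=float).copy()
    U = np.zeros((d, 0))  # orthonormal basis of past gradient directions

    for x_t, y_t in zip(X, y):
        # For a linear model, the gradient of the prediction w.r.t. w is x_t.
        # Remove the component lying in the span of past gradients.
        g = x_t - U @ (U.T @ x_t)
        norm_sq = g @ g
        if norm_sq <= tol:
            continue  # x_t adds no new direction; nothing to update

        # Step along g just far enough to fit the new datapoint exactly.
        # Because g is orthogonal to all past inputs, earlier predictions
        # are left unchanged.
        residual = y_t - x_t @ w
        w = w + g * residual / (x_t @ g)

        # Remember the new direction for future projections.
        U = np.column_stack([U, g / np.sqrt(norm_sq)])
    return w

# Usage: overparameterized setting (more parameters than datapoints).
rng = np.random.default_rng(0)
X = rng.standard_normal((10, 50))
y = rng.standard_normal(10)
w = orfit_linear(X, y)
```

Starting from zero, the returned `w` interpolates every datapoint and stays in the span of the inputs, so it coincides with the minimum-norm interpolant `np.linalg.pinv(X) @ y`, which is the solution SGD initialized at zero converges to in this setting. Memory in this sketch grows with the number of datapoints seen; the IPCA step described in the abstract is what would cap it, and it is omitted here.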

