Layerwise Proximal Replay: A Proximal Point Method for Online Continual Learning
This work addresses a key optimization issue in online continual learning, where AI systems learn from streaming data. The contribution is incremental in that it modifies existing replay-based methods rather than replacing them.
The paper tackles unstable optimization trajectories in online continual learning with experience replay, which impede accuracy even when all past data is stored. It introduces Layerwise Proximal Replay (LPR), a method that balances learning from new and replay data while allowing only gradual changes in the hidden activations of past data. The authors report that LPR consistently improves replay-based methods across multiple problem settings, regardless of the amount of replay memory available.
In online continual learning, a neural network incrementally learns from a non-i.i.d. data stream. Nearly all online continual learning methods employ experience replay to simultaneously prevent catastrophic forgetting and underfitting on past data. Our work demonstrates a limitation of this approach: neural networks trained with experience replay tend to have unstable optimization trajectories, impeding their overall accuracy. Surprisingly, these instabilities persist even when the replay buffer stores all previous training examples, suggesting that this issue is orthogonal to catastrophic forgetting. We minimize these instabilities through a simple modification of the optimization geometry. Our solution, Layerwise Proximal Replay (LPR), balances learning from new and replay data while only allowing for gradual changes in the hidden activations of past data. We demonstrate that LPR consistently improves replay-based online continual learning methods across multiple problem settings, regardless of the amount of available replay memory.
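To make the idea of "only allowing for gradual changes in the hidden activations of past data" concrete, the following is a minimal sketch of a proximal update for a single linear layer. It is an illustration under stated assumptions, not the authors' exact algorithm: the objective, the function name `proximal_layer_step`, and the parameters `omega` and `lam` are all hypothetical choices made here. The sketch assumes the update dW is chosen to minimize <grad, dW> + (lam/2)·||dW||² + (omega/(2n))·||Z dW||², where Z holds the layer's inputs on n replay examples, so that the last term penalizes changes in the layer's outputs on past data.

```python
import numpy as np

def proximal_layer_step(grad, replay_inputs, omega=1.0, lam=1.0):
    """Hypothetical proximal update for one linear layer (illustrative only).

    Returns the dW that minimizes
        <grad, dW> + (lam / 2) * ||dW||_F^2
                   + (omega / (2 * n)) * ||replay_inputs @ dW||_F^2,
    whose closed form is dW = -(lam * I + omega * Z^T Z / n)^{-1} grad.
    """
    n, d_in = replay_inputs.shape
    # Second-moment matrix of the layer's inputs on replay data.
    gram = replay_inputs.T @ replay_inputs / n
    # Preconditioner: identity term plus the activation-change penalty.
    precond = lam * np.eye(d_in) + omega * gram
    # Solve the linear system instead of forming an explicit inverse.
    return -np.linalg.solve(precond, grad)

# Toy data: 32 replay examples, a layer mapping 8 inputs to 4 outputs.
rng = np.random.default_rng(0)
Z = rng.normal(size=(32, 8))   # replay inputs to the layer (assumed)
G = rng.normal(size=(8, 4))    # loss gradient w.r.t. the layer weights (assumed)

dW_prox = proximal_layer_step(G, Z, omega=5.0, lam=1.0)
dW_plain = -G                  # same objective with omega = 0, lam = 1

# How much each update moves the layer's outputs on replay data.
drift_prox = np.linalg.norm(Z @ dW_prox)
drift_plain = np.linalg.norm(Z @ dW_plain)
```

With `omega = 0` the update reduces to a plain gradient step, and increasing `omega` shrinks the update along directions that move replay activations the most. In this sketch `drift_prox` is never larger than `drift_plain`, because the preconditioner and the Gram matrix share eigenvectors and the preconditioner only shrinks each direction.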