LG · AI · ML · Feb 28, 2024

Learning Associative Memories with Gradient Descent

arXiv:2402.18724v1 · 15 citations · h-index: 20 · ICML
Originality: Incremental advance
AI Analysis

This work addresses the problem of understanding and optimizing the training of associative memories, a question relevant to machine learning practitioners. It is an incremental advance that builds on existing theories of gradient descent and memory modules.

The paper studies the training dynamics of associative memory modules under gradient descent. In overparameterized regimes, classification margins grow logarithmically, but imbalanced token frequencies and correlated embeddings cause oscillatory transients. In underparameterized regimes, the cross-entropy loss leads to suboptimal memorization. The findings are validated on small Transformer models.
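To give some intuition for the logarithmic margin growth and the role of token frequencies, here is a short derivation under simplifying assumptions that are ours, not necessarily the paper's exact setting: orthonormal input embeddings e_x and output embeddings u_z, a deterministic target f^*(x) among M output tokens, input frequencies p(x), W initialized at zero, and continuous-time gradient flow with step size η.

```latex
% Memory logits, softmax probabilities p_W(z | x), and cross-entropy loss
s_z(x) = u_z^\top W e_x, \qquad
p_W(z \mid x) = \frac{e^{s_z(x)}}{\sum_{z'} e^{s_{z'}(x)}}, \qquad
\mathcal{L}(W) = -\sum_{x} p(x)\, \log p_W\bigl(f^*(x) \mid x\bigr)

% The gradient is a weighted sum of outer products u_z e_x^\top
\nabla \mathcal{L}(W) = \sum_{x} p(x) \sum_{z} \bigl(p_W(z \mid x) - \mathbf{1}\{z = f^*(x)\}\bigr)\, u_z e_x^\top

% With orthonormal embeddings, each logit evolves independently under gradient flow
\dot s_z(x) = -\eta\, p(x)\,\bigl(p_W(z \mid x) - \mathbf{1}\{z = f^*(x)\}\bigr)

% Margin m_x = s_{f^*(x)}(x) - s_z(x) for any z \neq f^*(x) (all wrong logits stay equal from W = 0)
\dot m_x = \frac{\eta\, p(x)\, M}{e^{m_x} + M - 1}
\;\Longrightarrow\;
e^{m_x(t)} + (M-1)\, m_x(t) = 1 + \eta\, p(x)\, M\, t

% Hence, as t \to \infty
m_x(t) = \log\bigl(\eta\, p(x)\, M\, t\bigr) + o(1)
```

In this idealized case, a rare token (small p(x)) lags behind by a constant log p(x). Correlated embeddings would break the decoupling in the third line, which is where interference between memories enters.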

This work focuses on the training dynamics of one associative memory module storing outer products of token embeddings. We reduce this problem to the study of a system of particles, which interact according to properties of the data distribution and correlations between embeddings. Through theory and experiments, we provide several insights. In overparameterized regimes, we obtain logarithmic growth of the "classification margins." Yet, we show that imbalance in token frequencies and memory interferences due to correlated embeddings lead to oscillatory transitory regimes. The oscillations are more pronounced with large step sizes, which can create benign loss spikes, although these learning rates speed up the dynamics and accelerate the asymptotic convergence. In underparameterized regimes, we illustrate how the cross-entropy loss can lead to suboptimal memorization schemes. Finally, we assess the validity of our findings on small Transformer models.
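A minimal pure-Python sketch of the setting in the abstract, written under our own assumptions rather than taken from the authors' code. A single memory matrix W produces logits u_z^T W e_x and is trained by full-batch gradient descent on the cross-entropy loss, starting from W = 0. We use orthonormal (standard-basis) embeddings and a hypothetical deterministic target map. The names `train_memory`, `scores` and `margins` are illustrative only. Each gradient is a weighted sum of outer products u_z e_x^T, so W stays in their span. This is one way to see why the dynamics can be tracked through a small set of interacting coefficients.

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of floats.
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def scores(W, E, U, x):
    # Logits f_W(x)_z = u_z^T W e_x for every output token z.
    d = len(W)
    We = [sum(W[i][j] * E[x][j] for j in range(d)) for i in range(d)]
    return [sum(U[z][i] * We[i] for i in range(d)) for z in range(len(U))]

def margins(W, E, U, target):
    # Margin of input x: correct logit minus the largest wrong logit.
    out = []
    for x, y in enumerate(target):
        s = scores(W, E, U, x)
        out.append(s[y] - max(v for z, v in enumerate(s) if z != y))
    return out

def train_memory(E, U, target, probs, lr, steps):
    # Full-batch gradient descent on
    # L(W) = sum_x p(x) * -log softmax(u^T W e_x)[target[x]], from W = 0.
    d = len(E[0])
    W = [[0.0] * d for _ in range(d)]
    history = []
    for _ in range(steps):
        grad = [[0.0] * d for _ in range(d)]
        for x, px in enumerate(probs):
            p = softmax(scores(W, E, U, x))
            for z in range(len(U)):
                # Each (x, z) pair contributes a rank-one term coef * u_z e_x^T.
                coef = px * (p[z] - (1.0 if z == target[x] else 0.0))
                for i in range(d):
                    for j in range(d):
                        grad[i][j] += coef * U[z][i] * E[x][j]
        for i in range(d):
            for j in range(d):
                W[i][j] -= lr * grad[i][j]
        history.append(margins(W, E, U, target))
    return W, history

# Orthonormal embeddings: standard basis vectors in R^4 (an assumption).
basis = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
target = [1, 2, 3, 0]  # hypothetical deterministic map x -> f*(x)
_, hist = train_memory(basis, basis, target, [0.25] * 4, lr=1.0, steps=1000)
# Ten times more steps should add roughly log(10) to each margin.
print(hist[999][0] - hist[99][0])
```

With uniform token frequencies the printed gap should be close to log(10), which is consistent with logarithmic margin growth. Rerunning with skewed frequencies such as `[0.7, 0.1, 0.1, 0.1]` leaves the rare tokens with smaller margins. Correlated (non-orthogonal) embeddings and large-step oscillations are not covered by this sketch.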
