Fast-weight Product Key Memory
This work addresses the limited storage capacity and high computational cost of sequence modeling layers in language models. It offers researchers and practitioners a novel method rather than a purely incremental one.
The paper tackles the trade-off between storage capacity and computational efficiency in the sequence modeling layers of language models. It introduces Fast-weight Product Key Memory (FwPKM), a sparse fast-weight memory layer that rapidly memorizes and retrieves key-value associations. FwPKM yields significant perplexity reductions on long-context datasets and, in Needle-in-a-Haystack evaluations, generalizes to 128K-token contexts despite being trained only on 4K-token sequences.
Sequence modeling layers in modern language models typically face a trade-off between storage capacity and computational efficiency. While softmax attention offers unbounded storage at prohibitive quadratic cost, linear variants are more efficient but suffer from limited, fixed-size storage. We introduce Fast-weight Product Key Memory (FwPKM), a sparse fast-weight memory layer that resolves this tension. FwPKM updates sparsely activated parameters at both training and inference time using chunk-level gradient descent on a local memory-rewrite objective. This performs Test-Time Training (TTT)-style gradient updates on activated slots in a sparse memory, enabling rapid memorization and retrieval of many new key-value associations while keeping per-token compute low and fixed. Experiments show that FwPKM functions as an effective episodic memory that complements the semantic memory of standard modules, yielding significant perplexity reductions on long-context datasets. Notably, in Needle-in-a-Haystack evaluations, FwPKM generalizes to 128K-token contexts despite being trained on only 4K-token sequences.
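To make the mechanism concrete, here is a minimal NumPy sketch of a sparse product-key memory whose value table is updated by one chunk-level gradient step, so only the top-k activated slots change. It is not the paper's implementation. The abstract does not specify the exact rewrite objective, parameterization, or optimizer, so this sketch assumes all of the following: fixed random sub-keys (only values are fast weights), a squared-error objective between the retrieved vector and a target value, and a plain chunk-averaged gradient step. The class and method names (`ProductKeyMemory`, `lookup`, `write_chunk`, `chunk_loss`) are hypothetical.

```python
import numpy as np


class ProductKeyMemory:
    """Hypothetical sparse fast-weight memory with product keys.

    Assumptions (not from the paper): sub-keys are fixed random matrices,
    only the value table is a fast weight, and the write objective is
    0.5 * ||read(q) - v||^2 averaged over a chunk.
    """

    def __init__(self, d_key, d_value, n_sub, top_k, seed=0):
        assert d_key % 2 == 0, "the query is split into two halves"
        rng = np.random.default_rng(seed)
        self.n_sub, self.top_k = n_sub, top_k
        half = d_key // 2
        # Two sub-key tables; their Cartesian product indexes n_sub**2 slots.
        self.sub_keys1 = rng.standard_normal((n_sub, half))
        self.sub_keys2 = rng.standard_normal((n_sub, half))
        # Fast-weight value table, one row per memory slot, initially empty.
        self.values = np.zeros((n_sub * n_sub, d_value))

    def lookup(self, q):
        """Return (slot indices, softmax weights) of the top-k slots for q."""
        half = q.shape[0] // 2
        s1 = self.sub_keys1 @ q[:half]
        s2 = self.sub_keys2 @ q[half:]
        k = self.top_k
        i1 = np.argsort(s1)[-k:]  # top-k sub-keys for each half
        i2 = np.argsort(s2)[-k:]
        # Full key score is additive, so the global top-k lies in these k*k pairs.
        cand = s1[i1][:, None] + s2[i2][None, :]
        flat = np.argsort(cand, axis=None)[-k:]
        a, b = np.unravel_index(flat, cand.shape)
        slots = i1[a] * self.n_sub + i2[b]
        scores = cand[a, b]
        w = np.exp(scores - scores.max())
        return slots, w / w.sum()

    def read(self, q):
        slots, w = self.lookup(q)
        return w @ self.values[slots]

    def write_chunk(self, queries, targets, lr=0.5):
        """One gradient step on the chunk-mean squared error.

        Only activated slots get a nonzero gradient, so per-token work
        depends on top_k, not on the total number of slots.
        """
        grad = np.zeros_like(self.values)
        for q, v in zip(queries, targets):
            slots, w = self.lookup(q)
            err = w @ self.values[slots] - v
            # d/dV[s] of 0.5*||sum_s w_s V[s] - v||^2 is w_s * err
            np.add.at(grad, slots, w[:, None] * err[None, :])
        self.values -= lr * grad / len(queries)

    def chunk_loss(self, queries, targets):
        errs = [self.read(q) - v for q, v in zip(queries, targets)]
        return float(np.mean([0.5 * (e @ e) for e in errs]))


# Usage: memorize one chunk of random key-value pairs with a few steps.
rng = np.random.default_rng(1)
mem = ProductKeyMemory(d_key=16, d_value=8, n_sub=32, top_k=4)
queries = rng.standard_normal((64, 16))
targets = rng.standard_normal((64, 8))
before = mem.chunk_loss(queries, targets)
for _ in range(20):
    mem.write_chunk(queries, targets)
after = mem.chunk_loss(queries, targets)
```

The step size is a deliberate choice. Because the softmax weights of each query sum to 1, the curvature of this quadratic objective is at most 1, so `lr=0.5` lowers the chunk loss on every step. The value table holds 1,024 slots, yet each write touches at most 4 rows per query.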