Differentiable Random Access Memory using Lattices
This work addresses the computational bottleneck of scaling model capacity for large language modeling, though it appears incremental, since it augments existing architectures rather than replacing them.
The paper tackles the problem of scaling neural network capacity efficiently. It introduces a differentiable random access memory module with $O(1)$ performance regardless of size that scales to billions of entries. On large language modeling tasks, models using it achieve better accuracy than transformer baselines, and accuracy keeps improving with memory size up to the largest sizes tested.
We introduce a differentiable random access memory module with $O(1)$ performance regardless of size, scaling to billions of entries. The design stores entries on the points of a chosen lattice, which lets it compute the nearest neighbours of arbitrary points efficiently by exploiting the lattice's symmetries. By augmenting a standard neural network architecture with a single memory layer of this kind, we can scale the parameter count up to the limits of available memory with negligible computational overhead, giving better accuracy at similar cost. On large language modelling tasks, these higher-capacity models significantly outperform the unmodified transformer baseline. We found that performance continued to improve with memory size up to the limits tested.
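The abstract does not say which lattice is used, how lattice points are mapped to storage, or how a read is weighted. As an illustration only, the sketch below assumes the $D_n$ "checkerboard" lattice (integer vectors with an even coordinate sum), whose closest point can be found by Conway and Sloane's rounding method in time that depends on the dimension $n$ but not on the number of stored entries. Everything else is a hypothetical choice made for this sketch, not the paper's method: the names `nearest_dn`, `dn_neighbours` and `LatticeMemory`, the parameters `num_slots` and `temperature`, hashing lattice points into a fixed-size value table, and a softmax read over the nearest point plus its first shell of neighbours.

```python
import math
import random
from itertools import combinations


def nearest_dn(x):
    """Closest point of the D_n lattice (integer vectors with even coordinate sum) to x.

    Conway and Sloane's rounding method: round every coordinate; if the sum is
    odd, re-round the worst-rounded coordinate in the other direction.
    """
    f = [round(v) for v in x]               # nearest point of Z^n
    if sum(f) % 2 == 0:
        return tuple(f)                     # already a D_n point
    k = max(range(len(x)), key=lambda i: abs(x[i] - f[i]))
    f[k] += 1 if x[k] > f[k] else -1        # fix parity at the cheapest coordinate
    return tuple(f)


def dn_neighbours(p):
    """p plus its 2n(n-1) nearest D_n neighbours, at offsets +-e_i +-e_j."""
    points = [tuple(p)]
    for i, j in combinations(range(len(p)), 2):
        for si in (1, -1):
            for sj in (1, -1):
                q = list(p)
                q[i] += si
                q[j] += sj
                points.append(tuple(q))
    return points


class LatticeMemory:
    """Hypothetical memory layer: one value vector per hashed lattice point."""

    def __init__(self, num_slots, dim_out, temperature=0.5, seed=0):
        rng = random.Random(seed)
        self.num_slots = num_slots
        self.dim_out = dim_out
        self.temperature = temperature
        self.values = [[rng.gauss(0.0, 0.02) for _ in range(dim_out)]
                       for _ in range(num_slots)]

    def slot(self, p):
        # Hypothetical addressing: hash the integer coordinates into the table.
        # Hashes of int tuples are deterministic in CPython.
        return hash(p) % self.num_slots

    def read(self, x):
        """Softmax-weighted sum of the values stored at lattice points near x."""
        cands = dn_neighbours(nearest_dn(x))
        # Negative squared distances; the candidate set only changes when x
        # crosses into another Voronoi cell.
        logits = [-sum((a - b) ** 2 for a, b in zip(x, p)) / self.temperature
                  for p in cands]
        m = max(logits)
        exps = [math.exp(logit - m) for logit in logits]
        total = sum(exps)
        weights = [e / total for e in exps]
        out = [0.0] * self.dim_out
        for w, p in zip(weights, cands):
            value = self.values[self.slot(p)]
            for d in range(self.dim_out):
                out[d] += w * value[d]
        return out, weights, cands


mem = LatticeMemory(num_slots=4096, dim_out=8)
out, weights, cands = mem.read([0.3, -1.2, 0.7, 0.1])
print(len(cands), round(sum(weights), 6))  # prints: 25 1.0
```

In this sketch, raising `num_slots` from thousands to billions only enlarges the value table: each read still touches $1 + 2n(n-1)$ slots (25 for $n = 4$). The read is linear in the stored values, so gradients reach only those slots. With respect to the query it is smooth within a Voronoi cell, but slightly discontinuous across cell boundaries because the candidate set is truncated to the first shell. Hashing keeps the table size fixed at the cost of occasional collisions between distinct lattice points.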