Masked Mixers for Language Generation and Retrieval
This work addresses a potential bottleneck in language modeling and retrieval, offering an alternative to attention that shows gains in specific settings: training on small context windows and embedding-based retrieval.
The paper tackles the problem of information loss in attention-based language models by introducing masked mixers, which replace self-attention with masked convolutions to improve input representation accuracy. It shows that masked mixers learn causal language modeling more efficiently than transformers when trained on small context windows, and that a small masked mixer outperforms a large transformer-based retrieval model on retrieval tasks despite using far less data and compute.
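The summary describes masked mixers only as replacing self-attention with masked convolutions. The sketch below illustrates the idea under stated assumptions and is not the authors' implementation: it assumes the masked convolution acts as a learned n_ctx × n_ctx token-mixing matrix (as in a kernel-size-1 convolution whose channels are sequence positions) with its upper triangle masked out, so position i mixes only positions j ≤ i. The names `masked_token_mix` and `weight` are hypothetical.

```python
import random


def masked_token_mix(x, weight):
    """Causal (masked) token mixing over the sequence axis.

    x:      n_ctx token vectors, each a list of d_model floats.
    weight: n_ctx x n_ctx mixing matrix (hypothetical learned parameters).
    Returns out with out[i][d] = sum over j <= i of weight[i][j] * x[j][d];
    entries weight[i][j] with j > i are never read, which is the mask.
    """
    n_ctx, d_model = len(x), len(x[0])
    out = []
    for i in range(n_ctx):
        row = [0.0] * d_model
        for j in range(i + 1):  # lower triangle only: current and earlier tokens
            for d in range(d_model):
                row[d] += weight[i][j] * x[j][d]
        out.append(row)
    return out


if __name__ == "__main__":
    random.seed(0)
    n_ctx, d_model = 4, 3
    x = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(n_ctx)]
    w = [[random.gauss(0, 1) for _ in range(n_ctx)] for _ in range(n_ctx)]
    y = masked_token_mix(x, w)

    # Perturb only the final token; earlier outputs must not change.
    x_perturbed = [row[:] for row in x]
    x_perturbed[-1] = [100.0] * d_model
    y_perturbed = masked_token_mix(x_perturbed, w)
    print(all(y[i] == y_perturbed[i] for i in range(n_ctx - 1)))  # True
```

Because entries above the diagonal are never read, perturbing the last token leaves every earlier output unchanged, which is the causal property next-token prediction requires. Unlike attention, the mixing weights in this sketch are fixed parameters rather than input-dependent scores; whether this matches the paper's exact layer is an assumption.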
Attention mechanisms that confer selective focus on a strict subset of input elements are nearly ubiquitous in language models today. We posit that attention has a downside: most input information is lost. In support of this idea, we observe poor input representation accuracy in transformers and more accurate representation in what we term masked mixers, which replace self-attention with masked convolutions. The masked mixer learns causal language modeling more efficiently than early transformer implementations and even outperforms optimized, current transformers when training on small ($n_{ctx}<512$) but not larger context windows. Evidence is presented for the hypothesis that differences in transformer and masked mixer training efficiencies for various tasks are best predicted by input representation accuracy, or equivalently global invertibility. We hypothesize that the information loss exhibited by transformers would be more detrimental to retrieval than to generation, as retrieval is more closely approximated by a bijective and thus invertible function. We find that masked mixers are more effective retrieval models both when the pretrained embedding model is left unchanged and when the embedding model is modified via cosine similarity-based InfoNCE loss minimization. A small masked mixer is shown to outperform a large and near state-of-the-art transformer-based retrieval model, despite the latter being trained with many orders of magnitude more data and compute.
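The abstract mentions modifying the embedding model via cosine similarity-based InfoNCE loss minimization. Below is a minimal sketch of that loss for a single query, assuming one positive candidate among several negatives and a temperature hyperparameter; the default of 0.05 is a hypothetical value, not taken from the paper, and `cosine` and `info_nce` are illustrative names.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def info_nce(query, candidates, positive_index, temperature=0.05):
    """InfoNCE loss for one query: the negative log of the softmax
    probability assigned to the positive candidate, using cosine
    similarity divided by the (assumed) temperature as the logits."""
    logits = [cosine(query, c) / temperature for c in candidates]
    m = max(logits)  # subtract the max before exponentiating for stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[positive_index]


if __name__ == "__main__":
    query = [1.0, 0.2]
    candidates = [[0.9, 0.1], [-0.3, 1.0], [0.0, -1.0]]  # index 0 is the positive
    print(info_nce(query, candidates, positive_index=0))
```

Minimizing this loss pulls a query embedding toward its positive and away from the negatives in cosine-similarity space. In practice the negatives are often the other items in a training batch, though how the paper constructs its negatives is not stated here.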