Why Are Linear RNNs More Parallelizable?
This work examines what makes recurrent neural networks (RNNs) parallelizable, providing a theoretical foundation for designing large language model (LLM) architectures that balance expressivity and parallelism.
The authors show that linear RNNs (LRNNs) can be viewed as log-depth arithmetic circuits, which allows them to be parallelized efficiently. Nonlinear RNNs, in contrast, can solve L-complete problems (and even P-complete problems under polynomial precision), which reveals a fundamental barrier to parallelizing them as efficiently as transformers. The result is a tradeoff between expressivity and parallelism, with LRNNs striking a balance between the two.
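The intuition behind the log-depth claim can be illustrated with a standard technique that is not specific to this paper: a diagonal linear recurrence h_t = a_t * h_{t-1} + b_t applies a sequence of affine maps, and composing affine maps is associative, so all T states can be computed with a parallel prefix scan in about log2(T) rounds. The sketch below assumes a diagonal (elementwise) transition, h_0 = 0, and NumPy; the function names are hypothetical, and it is meant as an illustration rather than the authors' circuit construction.

```python
import numpy as np

def sequential_lrnn(a, b):
    """Reference loop: h_t = a_t * h_{t-1} + b_t with h_0 = 0 (diagonal case).
    a, b have shape (T, d); returns all states, shape (T, d)."""
    h = np.zeros_like(b[0])
    out = []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return np.stack(out)

def combine(left, right):
    """Associative operator on affine maps h -> a*h + b.
    Applying `left` first and then `right` gives h -> (a_r*a_l)*h + (a_r*b_l + b_r)."""
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r

def parallel_scan_lrnn(a, b):
    """Hillis-Steele inclusive scan over the affine maps.
    Uses ceil(log2(T)) rounds; within a round, every time step is updated
    independently, so each round could run in parallel across positions."""
    A, B = a.copy(), b.copy()
    T = len(a)
    shift = 1
    while shift < T:
        A_new, B_new = A.copy(), B.copy()
        # Position t absorbs the accumulated map ending at position t - shift.
        A_new[shift:], B_new[shift:] = combine((A[:-shift], B[:-shift]),
                                               (A[shift:], B[shift:]))
        A, B = A_new, B_new
        shift *= 2
    # With h_0 = 0, the offset of the composed map from step 1 to t is exactly h_t.
    return B
```

A nonlinear update such as h_t = tanh(a_t * h_{t-1} + b_t) has no comparable fixed-size combine step, because composing two such steps does not yield another map of the same form. This is only the intuition; the paper's barrier for nonlinear RNNs is stated in complexity-theoretic terms.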
The community is increasingly exploring linear RNNs (LRNNs) as language models, motivated by their expressive power and parallelizability. While prior work establishes the expressivity benefits of LRNNs over transformers, it is unclear what makes LRNNs, but not traditional nonlinear RNNs, as easy to parallelize in practice as transformers. We answer this question by providing a tight connection between types of RNNs and standard complexity classes. We show that LRNNs can be viewed as log-depth (bounded fan-in) arithmetic circuits, which represents only a slight depth overhead relative to the log-depth Boolean circuits that transformers admit. Furthermore, we show that nonlinear RNNs can solve $\mathsf{L}$-complete problems (and even $\mathsf{P}$-complete ones, under polynomial precision), revealing a fundamental barrier to parallelizing them as efficiently as transformers. Our theory also identifies fine-grained expressivity differences between recent popular LRNN variants: permutation-diagonal LRNNs are $\mathsf{NC}^1$-complete whereas diagonal-plus-low-rank LRNNs are more expressive ($\mathsf{PNC}^1$-complete). We provide further insight by associating each type of RNN with a corresponding automata-theoretic model that it can simulate. Together, our results reveal fundamental tradeoffs between nonlinear RNNs and different variants of LRNNs, providing a foundation for designing LLM architectures that achieve an optimal balance between expressivity and parallelism.
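As an illustration of the state-tracking side, the sketch below uses transition matrices that are pure permutation matrices (a special case of permutation-diagonal matrices) to solve the word problem for the symmetric group S5: given a sequence of permutations of five elements, decide whether their composition is the identity. By Barrington's theorem this problem is $\mathsf{NC}^1$-complete, which suggests why permutation-diagonal LRNNs can reach $\mathsf{NC}^1$-hard problems. The state is a 5×5 matrix updated linearly, so the same matrix products could also be evaluated with an associative scan. The encoding and function names are hypothetical assumptions for illustration, not the paper's construction.

```python
import numpy as np

def perm_matrix(perm):
    """Permutation matrix P with P[perm[i], i] = 1, so P @ e_i = e_{perm[i]}."""
    n = len(perm)
    P = np.zeros((n, n), dtype=int)
    P[list(perm), list(range(n))] = 1
    return P

def lrnn_word_problem(word, n=5):
    """Linear recurrence H_t = P_{x_t} @ H_{t-1} with H_0 = I.
    Column i of H_T is the one-hot image of i under the composed permutation.
    Returns True iff the composition of all permutations in `word` is the identity."""
    H = np.eye(n, dtype=int)
    for perm in word:
        H = perm_matrix(perm) @ H
    return np.array_equal(H, np.eye(n, dtype=int))

def compose_direct(word, n=5):
    """Reference check without matrices: state[i] is the image of i so far."""
    state = list(range(n))
    for perm in word:
        state = [perm[s] for s in state]
    return state == list(range(n))
```

For example, applying the 5-cycle (1, 2, 3, 4, 0) five times returns to the identity, while applying it four times does not.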