When can transformers reason with abstract symbols?
This work addresses a fundamental challenge in AI: enabling models to reason robustly. It is somewhat incremental, however, since it builds on existing transformer architectures.
The paper asks whether transformer models can perform relational reasoning with abstract symbols on out-of-distribution data. It proves that, given sufficient training data, transformers trained by gradient descent learn the abstract relations and generalize, whereas classical fully-connected networks provably fail.
We investigate the capabilities of transformer models on relational reasoning tasks. In these tasks, models are trained on a set of strings encoding abstract relations, and are then tested out-of-distribution on data that contains symbols that did not appear in the training dataset. We prove that for any relational reasoning task in a large family of tasks, transformers learn the abstract relations and generalize to the test set when trained by gradient descent on sufficiently large quantities of training data. This is in contrast to classical fully-connected networks, which we prove fail to learn to reason. Our results inspire modifications of the transformer architecture that add only two trainable parameters per head, and that we empirically demonstrate improve data efficiency for learning to reason.
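To make the out-of-distribution setup concrete, here is a minimal sketch of a hypothetical "same/different" relational task. Each string is a pair of symbols labeled 1 if the two symbols match. Training and test strings are drawn from disjoint vocabularies, so every test symbol is unseen. The abstract does not say what the two added per-head parameters are. This sketch assumes one of them is a scalar `alpha` multiplying an identity matrix added to a head's query-key product. It then shows why such a term could help: with the learned query and key weights set to zero, the attention logit reduces to `alpha * <x, y>`, which compares embeddings directly and does not depend on which symbols appeared in training. The task, the random embeddings, the threshold rule, and every function name (`make_same_different_split`, `identity_logit`, and so on) are illustrative assumptions, not the paper's construction.

```python
import math
import random


def make_same_different_split(n_train_symbols=50, n_test_symbols=50,
                              n_examples=200, seed=0):
    """Hypothetical 'same/different' task: string (a, b) has label 1 iff a == b.
    Train and test use disjoint vocabularies, so every test symbol is unseen."""
    rng = random.Random(seed)
    train_vocab = [f"s{i}" for i in range(n_train_symbols)]
    test_vocab = [f"s{i}" for i in range(n_train_symbols,
                                          n_train_symbols + n_test_symbols)]

    def sample(vocab):
        data = []
        for _ in range(n_examples):
            a = rng.choice(vocab)
            if rng.random() < 0.5:
                b = a  # positive example: same symbol twice
            else:
                b = rng.choice([s for s in vocab if s != a])  # negative example
            data.append(((a, b), int(a == b)))
        return data

    return train_vocab, test_vocab, sample(train_vocab), sample(test_vocab)


def random_embeddings(symbols, dim=256, seed=1):
    """Random Gaussian embeddings scaled so that <x, x> is close to 1."""
    rng = random.Random(seed)
    return {s: [rng.gauss(0.0, 1.0) / math.sqrt(dim) for _ in range(dim)]
            for s in symbols}


def identity_logit(x, y, alpha):
    """Attention logit keeping only an ASSUMED alpha * I query-key term
    (learned W_Q and W_K set to zero), which gives alpha * <x, y>."""
    return alpha * sum(xi * yi for xi, yi in zip(x, y))


def accuracy(data, emb, alpha=1.0):
    """Predict 'same' when the logit exceeds alpha / 2, then score predictions."""
    correct = 0
    for (a, b), label in data:
        pred = int(identity_logit(emb[a], emb[b], alpha) > alpha / 2)
        correct += int(pred == label)
    return correct / len(data)


train_vocab, test_vocab, train_data, test_data = make_same_different_split()
emb = random_embeddings(train_vocab + test_vocab)
print(set(train_vocab) & set(test_vocab))  # set(): the vocabularies are disjoint
print(accuracy(test_data, emb))  # accuracy on pairs of never-seen symbols
```

The comparison works for unseen symbols because independent high-dimensional random embeddings are nearly orthogonal, while each embedding has a squared norm close to 1. This is only an intuition pump for why an identity term might improve data efficiency. It is not a trained transformer and makes no claim about the paper's proofs or experiments.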