Iterative Decoding for Compositional Generalization in Transformers
This work addresses a key limitation of deep learning on sequence-to-sequence tasks and offers a method for improving compositional generalization, though the contribution is incremental and the authors note limitations on the CFQ dataset.
The paper tackles compositional generalization in transformers, which struggle to combine learned primitives to solve more complex tasks. It introduces iterative decoding and reports improved performance on the PCFG and Cartesian product datasets.
Deep learning models generalize well to in-distribution data but struggle to generalize compositionally, i.e., to combine a set of learned primitives to solve more complex tasks. In sequence-to-sequence (seq2seq) learning, transformers are often unable to predict correct outputs for examples longer than those seen during training. This paper introduces iterative decoding, an alternative to seq2seq that (i) improves transformer compositional generalization on the PCFG and Cartesian product datasets and (ii) provides evidence that, on these datasets, seq2seq transformers do not learn iterations that are not unrolled. In iterative decoding, training examples are broken down into a sequence of intermediate steps that the transformer learns iteratively. At inference time, the intermediate outputs are fed back to the transformer as intermediate inputs until an end-of-iteration token is predicted. We conclude by illustrating some limitations of iterative decoding on the CFQ dataset.
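The training-data decomposition and inference loop described in the abstract can be sketched in Python as follows. Everything here is an illustrative assumption, not the paper's implementation. That includes the toy list-reversal task, the state format (unprocessed tokens, then a `|` separator, then the partial output), the `<eoi>` token spelling, and the `max_iters` guard. The function `oracle_step` is hypothetical and stands in for a trained transformer that has perfectly learned a single step.

```python
from typing import Callable, List, Tuple

SEP = "|"      # hypothetical separator: unprocessed input | partial output
EOI = "<eoi>"  # hypothetical end-of-iteration token

Tokens = List[str]


def reverse_trace(tokens: Tokens) -> List[Tokens]:
    """Hypothetical decomposition of a reversal example into intermediate
    states: each step moves the last unprocessed token to the output."""
    rest, out = list(tokens), []
    trace = [rest + [SEP] + out]
    while rest:
        out = out + [rest.pop()]
        trace.append(rest + [SEP] + out)
    return trace


def make_training_pairs(trace: List[Tokens]) -> List[Tuple[Tokens, Tokens]]:
    """Each intermediate state is an input whose target is the next state;
    the last state maps to the final answer followed by EOI."""
    pairs = [(trace[i], trace[i + 1]) for i in range(len(trace) - 1)]
    final = trace[-1]
    answer = final[final.index(SEP) + 1:]
    pairs.append((final, answer + [EOI]))
    return pairs


def oracle_step(state: Tokens) -> Tokens:
    """Stand-in for a trained transformer: applies one step of the rule
    that the training pairs above encode."""
    i = state.index(SEP)
    rest, out = state[:i], state[i + 1:]
    if not rest:
        return out + [EOI]
    return rest[:-1] + [SEP] + out + [rest[-1]]


def iterative_decode(step: Callable[[Tokens], Tokens],
                     source: Tokens, max_iters: int = 100) -> Tokens:
    """Feed each predicted output back in as the next input until the
    model emits EOI; return the final output without the EOI token."""
    state = source + [SEP]
    for _ in range(max_iters):
        state = step(state)
        if state and state[-1] == EOI:
            return state[:-1]
    raise RuntimeError("no end-of-iteration token within max_iters")


print(iterative_decode(oracle_step, ["a", "b", "c"]))  # ['c', 'b', 'a']
```

In this sketch, each training pair only demonstrates one local transformation. The number of iterations is not fixed in advance; it is decided at inference time by when the end-of-iteration token appears. This is the unrolling that, according to the abstract, plain seq2seq transformers fail to learn on their own.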