Understanding the Staged Dynamics of Transformers in Learning Latent Structure
This work provides incremental insights into the training dynamics of transformers for researchers in machine learning and AI, focusing on latent structure learning.
The paper uses the Alchemy benchmark to investigate how small decoder-only transformers learn latent structure from context. It finds that acquisition occurs in discrete stages, with coarse-grained rules learned before the full structure, and identifies an asymmetry: composition is robust, whereas decomposition is difficult.
While transformers can discover latent structure from context, the dynamics of how they acquire different components of that structure remain poorly understood. In this work, we use the Alchemy benchmark to investigate the dynamics of latent structure learning. We train a small decoder-only transformer on three task variants: 1) inferring missing rules from partial contextual information, 2) composing simple rules to solve multi-step sequences, and 3) decomposing complex multi-step examples to infer intermediate steps. By factorizing each task into interpretable events, we show that the model acquires capabilities in discrete stages, learning the coarse-grained rules first and the complete latent structure afterwards. We also identify a crucial asymmetry: the model can compose fundamental rules robustly, but struggles to decompose complex examples to discover the fundamental rules. These findings offer new insights into how a transformer model learns latent structures, providing a granular view of how these capabilities evolve during training.
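To make the composition and decomposition variants concrete, the following is a minimal toy sketch rather than the paper's actual setup. The 3-bit state, the potion names `A`, `B`, `C`, and the functions `apply_potion`, `compose`, and `decompose` are all assumptions introduced for illustration. Composition chains known single-step rules forward, while decomposition must search for step sequences consistent with an observed start and end state.

```python
from itertools import permutations

# Toy latent structure (an assumption for illustration, not the paper's data):
# a state is a tuple of 3 bits, and each hypothetical "potion" flips one bit.
POTIONS = {"A": 0, "B": 1, "C": 2}

def apply_potion(state, potion):
    """Apply a single rule: flip the bit that this potion controls."""
    i = POTIONS[potion]
    return state[:i] + (1 - state[i],) + state[i + 1:]

def compose(state, potions):
    """Composition: chain single-step rules to solve a multi-step sequence."""
    for p in potions:
        state = apply_potion(state, p)
    return state

def decompose(start, end, n_steps):
    """Decomposition: list every ordered sequence of distinct potions
    of length n_steps that maps start to end."""
    return [seq for seq in permutations(POTIONS, n_steps)
            if compose(start, seq) == end]

if __name__ == "__main__":
    print(compose((0, 0, 0), ["A", "B"]))
    print(decompose((0, 0, 0), (1, 1, 0), 2))
```

In this toy, composition has a single deterministic answer, whereas decomposition can return several candidate sequences. The sketch only illustrates the two task shapes; it does not establish why the trained model finds decomposition harder.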