Disentangling Feature Structure: A Mathematically Provable Two-Stage Training Dynamics in Transformers
This work offers theoretical insight into transformer training dynamics. The contribution is incremental in that it builds on existing empirical observations, but it provides the first rigorous analysis of feature-level training stages.
The paper explains the two-stage training dynamics observed in transformers, such as GPT-2 trained on the Counterfact dataset, where outputs progress from syntactically incorrect to syntactically correct and then to semantically correct. It shows theoretically that this phenomenon arises from disentangled feature structures such as syntax and semantics.
Transformers may exhibit two-stage training dynamics during real-world training. For instance, when GPT-2 is trained on the Counterfact dataset, its answers progress from syntactically incorrect to syntactically correct to semantically correct. However, existing theoretical analyses hardly account for this feature-level two-stage phenomenon, which originates from two disentangled types of features, such as syntax and semantics. In this paper, we theoretically demonstrate how two-stage training dynamics can arise in transformers. Specifically, we analyze the feature learning dynamics induced by this disentangled two-type feature structure, grounding our analysis in a simplified yet illustrative setting that comprises a normalized ReLU self-attention layer and structured data. Such disentangled feature structure is common in practice: natural languages contain syntax and semantics, and proteins contain primary and secondary structures. To the best of our knowledge, this is the first rigorous result on a feature-level two-stage optimization process in transformers. Additionally, a corollary indicates that this two-stage process is closely related to the spectral properties of the attention weights, which accords well with our empirical findings.
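The abstract names the setting (a normalized ReLU self-attention layer over structured data with two disentangled feature types) but does not define it. The following minimal NumPy sketch assumes attention scores of the form ReLU(X W_kq X^T) / N, with a merged key-query matrix W_kq and N tokens. The normalization, the toy "syntax" and "semantics" subspaces, the weight values, and all function names are hypothetical illustrations, not the paper's construction. The spectral helper only shows how one might inspect the singular values of W_kq in the spirit of the corollary; it does not reproduce the paper's result.

```python
import numpy as np


def relu_attention_scores(X: np.ndarray, W_kq: np.ndarray) -> np.ndarray:
    """ReLU attention scores normalized by sequence length (assumed form).

    Computes ReLU(X W_kq X^T) / N, where N is the number of tokens.
    This normalization is an assumption; the paper's may differ.
    """
    n_tokens = X.shape[0]
    return np.maximum(X @ W_kq @ X.T, 0.0) / n_tokens


def normalized_relu_attention(X: np.ndarray, W_kq: np.ndarray, W_v: np.ndarray) -> np.ndarray:
    """One-head self-attention using normalized ReLU scores instead of softmax."""
    return relu_attention_scores(X, W_kq) @ X @ W_v


def attention_spectrum(W_kq: np.ndarray) -> np.ndarray:
    """Singular values of the merged key-query matrix, largest first."""
    return np.linalg.svd(W_kq, compute_uv=False)


# Hypothetical disentangled toy data: "syntax" occupies coordinates 0-3 and
# "semantics" occupies coordinates 4-7, so the two feature types are orthogonal.
rng = np.random.default_rng(0)
n_tokens, d = 5, 8
syntax = np.zeros((n_tokens, d))
syntax[:, :4] = rng.normal(size=(n_tokens, 4))
semantics = np.zeros((n_tokens, d))
semantics[:, 4:] = rng.normal(size=(n_tokens, 4))
X = syntax + semantics

# Hypothetical weights: the key-query matrix weights the syntax subspace
# more heavily than the semantics subspace; the value matrix is the identity.
W_kq = np.diag([2.0] * 4 + [0.5] * 4)
W_v = np.eye(d)

out = normalized_relu_attention(X, W_kq, W_v)
spectrum = attention_spectrum(W_kq)
print(out.shape)  # (5, 8)
print(spectrum)
```

With orthogonal feature subspaces and a block-diagonal W_kq, the pre-ReLU scores split exactly into a syntax term plus a semantics term (the cross terms vanish), which is one simple way to see what "disentangled" means in this toy setting.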