Quadratic Gating Mixture of Experts: Statistical Insights into Self-Attention
This work provides theoretical insights into MoE and self-attention, leading to an improved attention mechanism for various AI tasks, though it is incremental in nature.
The paper establishes a connection between Mixture of Experts (MoE) and self-attention, showing that self-attention can be expressed as a quadratic gating MoE. Building on this, it proposes an active-attention mechanism that applies a non-linear activation to the value matrix; this mechanism outperforms standard self-attention on tasks such as image classification, language modeling, and time series forecasting.
Mixture of Experts (MoE) models are well known for effectively scaling model capacity without a proportional increase in computational cost. In this paper, we establish a rigorous relation between MoE and the self-attention mechanism, showing that each row of a self-attention matrix can be written as a quadratic gating mixture of linear experts. Motivated by this connection, we conduct a comprehensive convergence analysis of MoE models with two different quadratic gating functions, namely the quadratic polynomial gate and the quadratic monomial gate, offering useful insights into the design of gating and experts for the MoE framework. First, our analysis indicates that the quadratic monomial gate yields better sample efficiency for estimating parameters and experts than the quadratic polynomial gate. Second, parameter and expert estimation rates become significantly faster when non-linear experts are used in place of linear experts. Combining these theoretical insights with the above link between MoE and self-attention, we propose a novel \emph{active-attention} mechanism in which we apply a non-linear activation function to the value matrix in the formula of self-attention. Finally, we demonstrate through extensive experiments that the proposed active-attention outperforms standard self-attention on various tasks, including image classification, language modeling, and multivariate time series forecasting.
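The correspondence between a self-attention row and a quadratic gating MoE can be checked numerically. The following is a minimal pure-Python sketch under our own assumptions, not the authors' code. Tokens are stacked as the rows of X and the flattened input xbar concatenates all tokens. For output row i, the gate score of expert j is the quadratic form xbar^T A_ij xbar, where A_ij is zero except for block (i, j), which equals W_Q^T W_K / sqrt(d_k). Expert j is the linear map that applies W_V to token j. Passing an activation to the experts gives one possible reading of active-attention, in which the activation is applied elementwise to each value vector; tanh is an illustrative choice, and the function names and toy weights are hypothetical.

```python
import math

def matvec(M, v):
    # Multiply matrix M (list of rows) by vector v.
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def softmax(scores):
    # Numerically stable softmax.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_row(X, WQ, WK, WV, i, act=None):
    """Row i of scaled dot-product self-attention.
    If act is given, it is applied elementwise to every value vector
    (an assumed reading of active-attention)."""
    d_k = len(WQ)
    q = matvec(WQ, X[i])
    values = [matvec(WV, x) for x in X]
    if act is not None:
        values = [[act(v) for v in val] for val in values]
    scores = [sum(a * b for a, b in zip(q, matvec(WK, x))) / math.sqrt(d_k) for x in X]
    w = softmax(scores)
    return [sum(w[j] * values[j][c] for j in range(len(X))) for c in range(len(WV))]

def moe_row(X, WQ, WK, WV, i, act=None):
    """The same row written as a quadratic gating mixture of N experts."""
    N, d, d_k = len(X), len(X[0]), len(WQ)
    D = N * d
    xbar = [v for x in X for v in x]  # flattened input, length N*d
    # B = WQ^T WK (d x d), so x_i^T B x_j equals q_i . k_j
    B = [[sum(WQ[k][r] * WK[k][c] for k in range(d_k)) for c in range(d)] for r in range(d)]
    scores, experts = [], []
    for j in range(N):
        # Gate score xbar^T A_ij xbar; A_ij is zero except block (i, j) = B / sqrt(d_k)
        A = [[0.0] * D for _ in range(D)]
        for r in range(d):
            for c in range(d):
                A[i * d + r][j * d + c] = B[r][c] / math.sqrt(d_k)
        scores.append(sum(xbar[r] * sum(A[r][c] * xbar[c] for c in range(D)) for r in range(D)))
        # Linear expert E_j xbar, where E_j applies WV to block j only
        E = [[0.0] * D for _ in range(len(WV))]
        for r in range(len(WV)):
            for c in range(d):
                E[r][j * d + c] = WV[r][c]
        out = matvec(E, xbar)
        # Optional activation turns the linear expert into a non-linear one
        experts.append([act(v) for v in out] if act is not None else out)
    g = softmax(scores)
    return [sum(g[j] * experts[j][c] for j in range(N)) for c in range(len(WV))]

# Toy example: N = 3 tokens of dimension d = 2, with d_k = 2 (arbitrary values)
X = [[1.0, 0.5], [-0.3, 2.0], [0.7, -1.2]]
WQ = [[0.2, -0.4], [0.9, 0.1]]
WK = [[0.5, 0.3], [-0.6, 0.8]]
WV = [[1.0, -0.5], [0.25, 0.75]]

for i in range(len(X)):
    print(i, attention_row(X, WQ, WK, WV, i), moe_row(X, WQ, WK, WV, i))
```

Both functions agree for every row i, with and without the activation. The softmax over quadratic scores plays the role of the gate, and supplying `act` replaces the linear experts W_V x_j with non-linear experts act(W_V x_j), which is the kind of expert the convergence analysis above favours.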