Beyond Softmax: A Natural Parameterization for Categorical Random Variables
This work addresses a fundamental bottleneck in deep learning for researchers and practitioners who use categorical variables, offering a simple, compatible alternative to softmax. The contribution is incremental, however, because it builds on existing gradient estimation techniques.
The paper tackles the challenge of gradient-descent learning with latent categorical variables by replacing the softmax function with a hierarchical binary-split function called catnat, which improves learning efficiency and yields consistently higher test performance across tasks like graph structure learning, variational autoencoders, and reinforcement learning.
Latent categorical variables are frequently found in deep learning architectures. They can model actions in discrete reinforcement-learning environments, represent categories in latent-variable models, or express relations in graph neural networks. Despite their widespread use, their discrete nature poses significant challenges to gradient-descent learning algorithms. While a substantial body of work has offered improved gradient estimation techniques, we take a complementary approach. Specifically, we: 1) revisit the ubiquitous $\textit{softmax}$ function and demonstrate its limitations from an information-geometric perspective; 2) replace the $\textit{softmax}$ with the $\textit{catnat}$ function, a function composed of a sequence of hierarchical binary splits; we prove that this choice offers significant advantages to gradient descent due to the resulting diagonal Fisher Information Matrix. A rich set of experiments, including graph structure learning, variational autoencoders, and reinforcement learning, shows empirically that the proposed function improves learning efficiency and yields models with consistently higher test performance. $\textit{Catnat}$ is simple to implement and integrates seamlessly into existing codebases. Moreover, it remains compatible with standard training stabilization techniques and, as such, offers a better alternative to the $\textit{softmax}$ function.
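To make the idea of hierarchical binary splits concrete, the following is a minimal sketch, not the authors' implementation. The abstract does not specify the construction, so this sketch assumes a balanced binary tree over K classes (K a power of two), one sigmoid gate per internal node, and parameters stored in heap order. The names `binary_split_probs` and `fisher_matrix` are hypothetical and chosen for illustration only. The second function numerically checks the property the abstract highlights: under these assumptions, the Fisher Information Matrix with respect to the K-1 parameters is diagonal.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def binary_split_probs(theta):
    """Map K-1 real parameters to K class probabilities.

    Assumption (not from the source): a balanced binary tree with K a
    power of two, parameters in heap order (root first, then level by level).
    Each internal node sends a fraction sigmoid(theta_i) of its mass left.
    """
    theta = np.asarray(theta, dtype=float)
    k = theta.size + 1
    assert k >= 2 and (k & (k - 1)) == 0, "K must be a power of two"

    probs = np.ones(1)  # all probability mass starts at the root
    start = 0
    while probs.size < k:
        width = probs.size
        go_left = sigmoid(theta[start:start + width])
        # Split every node's mass between its left and right child.
        probs = np.stack([probs * go_left, probs * (1.0 - go_left)], axis=1).reshape(-1)
        start += width
    return probs


def fisher_matrix(theta, eps=1e-6):
    """Fisher matrix J^T diag(1/p) J, with J estimated by central differences."""
    theta = np.asarray(theta, dtype=float)
    p = binary_split_probs(theta)
    jac = np.empty((p.size, theta.size))
    for i in range(theta.size):
        step = np.zeros_like(theta)
        step[i] = eps
        jac[:, i] = (binary_split_probs(theta + step)
                     - binary_split_probs(theta - step)) / (2.0 * eps)
    return jac.T @ (jac / p[:, None])


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    theta = rng.normal(size=7)  # 7 gates -> 8 classes
    p = binary_split_probs(theta)
    F = fisher_matrix(theta)
    print("probabilities sum to:", p.sum())
    print("largest off-diagonal Fisher entry:", np.abs(F - np.diag(np.diag(F))).max())
```

For comparison, the Fisher matrix of the softmax with respect to its logits is $\mathrm{diag}(p) - pp^\top$, which has non-zero off-diagonal entries whenever two classes both have positive probability. In the tree sketch above, each gate's score has zero conditional mean given that its node is reached, which is why the cross terms cancel.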