ML · LG · CO · ME · May 4, 2019

ARSM: Augment-REINFORCE-Swap-Merge Estimator for Gradient Backpropagation Through Categorical Variables

arXiv:1905.01413v2 · 24 citations
AI Analysis

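This work addresses a key bottleneck in machine learning: backpropagating gradients through categorical variables, which arise in tasks such as reinforcement learning with discrete actions and generative models with categorical latent variables. It offers an unbiased, low-variance gradient estimator that improves on existing methods in practice.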

The paper tackles gradient backpropagation through categorical variables by proposing the ARSM estimator, which is unbiased and has low variance. ARSM closely matches the optimization performance of the true gradient in univariate settings and outperforms existing estimators in categorical variational auto-encoders.

To address the challenge of backpropagating the gradient through categorical variables, we propose the augment-REINFORCE-swap-merge (ARSM) gradient estimator that is unbiased and has low variance. ARSM first uses variable augmentation, REINFORCE, and Rao-Blackwellization to re-express the gradient as an expectation under the Dirichlet distribution, then uses variable swapping to construct differently expressed but equivalent expectations, and finally shares common random numbers between these expectations to achieve significant variance reduction. Experimental results show ARSM closely resembles the performance of the true gradient for optimization in univariate settings; outperforms existing estimators by a large margin when applied to categorical variational auto-encoders; and provides a "try-and-see self-critic" variance reduction method for discrete-action policy gradient, which removes the need to estimate baselines by generating a random number of pseudo actions and estimating their action-value functions.
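The swap-and-merge construction described above can be sketched in NumPy for a single categorical variable z ~ Cat(softmax(phi)) with C categories. This is a minimal illustration written from our reading of the ARSM construction, not the authors' released code. The function names `swap_actions`, `arsm_grad`, and `exact_grad`, the toy reward table `f_vals`, and the sample count are all hypothetical. The sketch draws one pi ~ Dir(1_C) shared by every swapped expectation and forms the pseudo actions z^(c<->j) = argmin_i (ln pi_i^(c<->j) - phi_i). It then returns g_c = sum_j [f(z^(c<->j)) - (1/C) sum_m f(z^(m<->j))] (1/C - pi_j). f is queried once per distinct pseudo action, loosely mirroring how the policy-gradient variant only needs action-value estimates for a random number of pseudo actions.

```python
import numpy as np


def softmax(phi):
    """Numerically stable softmax of a 1-D array of logits."""
    e = np.exp(phi - phi.max())
    return e / e.sum()


def swap_actions(phi, pi):
    """Return the C x C integer matrix Z with Z[c, j] = z^(c<->j).

    z^(c<->j) = argmin_i (ln pi_i^(c<->j) - phi_i), where pi^(c<->j) is pi
    with entries c and j swapped. Every diagonal entry Z[j, j] equals the
    actual sample z = argmin_i (ln pi_i - phi_i) ~ Cat(softmax(phi)).
    """
    C = phi.shape[0]
    log_pi = np.log(pi)
    Z = np.empty((C, C), dtype=int)
    for c in range(C):
        for j in range(C):
            swapped = log_pi.copy()
            swapped[c], swapped[j] = log_pi[j], log_pi[c]
            Z[c, j] = int(np.argmin(swapped - phi))
    return Z


def arsm_grad(phi, f, rng):
    """One-sample ARSM-style estimate of d/dphi E_{z ~ Cat(softmax(phi))}[f(z)]."""
    C = phi.shape[0]
    pi = rng.dirichlet(np.ones(C))  # common random numbers: one pi ~ Dir(1_C)
    Z = swap_actions(phi, pi)
    # Query f once per distinct pseudo action (between 1 and C of them).
    values = {int(a): float(f(int(a))) for a in np.unique(Z)}
    F = np.array([[values[int(a)] for a in row] for row in Z])  # F[c, j] = f(z^(c<->j))
    # Merge: subtract (1/C) * sum_m f(z^(m<->j)) from every entry of column j.
    F_delta = F - F.mean(axis=0, keepdims=True)
    return F_delta @ (1.0 / C - pi)


def exact_grad(phi, f_vals):
    """True gradient p_c * (f(c) - E[f]) for a tabulated reward f_vals."""
    p = softmax(phi)
    return p * (f_vals - p @ f_vals)


# Hypothetical toy problem: C = 3 categories with a fixed reward table.
rng = np.random.default_rng(0)
phi = np.array([0.2, -0.5, 1.0])
f_vals = np.array([1.0, 3.0, -2.0])


def f(z):
    return f_vals[z]


est = np.mean([arsm_grad(phi, f, rng) for _ in range(20000)], axis=0)
print("ARSM mean:", est)
print("exact:    ", exact_grad(phi, f_vals))
```

Averaging many single-sample estimates should approach the exact softmax gradient, and each individual estimate sums to zero over c, as the true softmax gradient does. When every pseudo action coincides with the sampled action, the estimate is exactly zero, so f is queried only for that one action.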

Code Implementations: 1 repo