ML · LG · May 28, 2020

Joint Stochastic Approximation and Its Application to Learning Discrete Latent Variable Models

arXiv:2005.14001v1 · 11 citations
Originality: Highly original
AI Analysis

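This paper addresses two problems that machine learning practitioners face when training discrete latent variable models: unreliable stochastic gradients for the inference model and indirect optimization of the target log-likelihood. It offers a novel approach, although the reported improvements are incremental.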

The paper tackles the challenge of learning discrete latent variable models by proposing Joint Stochastic Approximation (JSA), a method that directly maximizes the log-likelihood while minimizing the inclusive divergence between the model posterior and the inference model. On benchmark tasks this yields faster convergence, better final likelihoods, and lower gradient variance.
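The two objectives can be written out as follows, in notation chosen for this sketch: $p_\theta(x,h)$ is the generative model, $q_\phi(h \mid x)$ the inference model, and $\tilde p(x)$ the data distribution. JSA maximizes $\mathbb{E}_{\tilde p(x)}[\log p_\theta(x)]$ over $\theta$ and minimizes $\mathbb{E}_{\tilde p(x)}[\mathrm{KL}(p_\theta(h \mid x)\,\|\,q_\phi(h \mid x))]$ over $\phi$. By Fisher's identity and the definition of the inclusive KL, both gradients are expectations under the same posterior $p_\theta(h \mid x)$:

$$
\begin{aligned}
\nabla_\theta \log p_\theta(x) &= \mathbb{E}_{p_\theta(h \mid x)}\big[\nabla_\theta \log p_\theta(x, h)\big], \\
\nabla_\phi\, \mathrm{KL}\big(p_\theta(h \mid x)\,\|\,q_\phi(h \mid x)\big) &= -\,\mathbb{E}_{p_\theta(h \mid x)}\big[\nabla_\phi \log q_\phi(h \mid x)\big].
\end{aligned}
$$

So a single stream of posterior samples can drive both Robbins-Monro updates. One way to draw those samples (assumed here; the abstract says only "an adaptive MCMC procedure") is a Metropolis independence sampler that proposes $h' \sim q_\phi(h \mid x)$ and accepts with probability $\min\{1, w(h')/w(h)\}$, where $w(h) = p_\theta(x,h)/q_\phi(h \mid x)$. The intractable $p_\theta(x)$ cancels in that ratio.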

Despite progress in introducing auxiliary amortized inference models, learning discrete latent variable models remains challenging. In this paper, we show that the persistent difficulty of obtaining reliable stochastic gradients for the inference model, and the drawback of indirectly optimizing the target log-likelihood, can both be gracefully addressed by a new method based on stochastic approximation (SA) theory of the Robbins-Monro type. Specifically, we propose to directly maximize the target log-likelihood while simultaneously minimizing the inclusive divergence between the posterior and the inference model. The resulting learning algorithm is called joint SA (JSA). To the best of our knowledge, JSA is the first method that couples an SA version of the EM (expectation-maximization) algorithm (SAEM) with an adaptive MCMC procedure. Experiments on several benchmark generative modeling and structured prediction tasks show that JSA consistently outperforms recent competitive algorithms, with faster convergence, better final likelihoods, and lower variance of gradient estimates.
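A minimal NumPy sketch of a JSA-style training loop, under loudly stated assumptions: the toy model (a two-state categorical latent h with Gaussian emissions x | h ~ N(mu[h], 1)), the linear-softmax inference model, the helper names (`jsa_fit`, `log_joint`, `log_q`), the single Metropolis independence step per visit, the one-cached-state-per-data-point chain, and the `lr0 / (1 + t/1000)` Robbins-Monro step size are all hypothetical choices for illustration, not the authors' code. Each step samples h approximately from p_theta(h | x), then takes one gradient-ascent step on log p_theta(x, h) and one on log q_phi(h | x).

```python
# Hypothetical toy sketch of JSA; NOT the authors' implementation.
import numpy as np

K = 2  # number of discrete latent states (toy assumption)


def log_softmax(z):
    """Numerically stable log-softmax of a 1-D array."""
    z = z - np.max(z)
    return z - np.log(np.sum(np.exp(z)))


def log_joint(x, h, a, mu):
    """log p_theta(x, h) for the toy model: h ~ Cat(softmax(a)), x | h ~ N(mu[h], 1)."""
    return log_softmax(a)[h] - 0.5 * (x - mu[h]) ** 2 - 0.5 * np.log(2 * np.pi)


def log_q(x, h, W, b):
    """log q_phi(h | x) for a hypothetical linear-softmax inference model."""
    return log_softmax(W * x + b)[h]


def jsa_fit(xs, epochs=20, lr0=0.05, seed=0):
    rng = np.random.default_rng(seed)
    a, mu = np.zeros(K), np.array([-0.5, 0.5])  # generative parameters theta
    W, b = np.zeros(K), np.zeros(K)             # inference parameters phi
    cache = rng.integers(K, size=len(xs))       # assumed: one persistent chain state per data point
    t = 0
    for _ in range(epochs):
        for i in rng.permutation(len(xs)):
            x, h = xs[i], cache[i]
            # Metropolis independence sampler: propose from q_phi, target p_theta(h | x).
            h_new = rng.choice(K, p=np.exp(log_softmax(W * x + b)))
            log_w_new = log_joint(x, h_new, a, mu) - log_q(x, h_new, W, b)
            log_w_old = log_joint(x, h, a, mu) - log_q(x, h, W, b)
            if np.log(rng.uniform()) < log_w_new - log_w_old:
                h = h_new
            cache[i] = h
            lr = lr0 / (1.0 + t / 1000.0)  # decaying Robbins-Monro step size (assumed schedule)
            onehot = np.eye(K)[h]
            # theta: ascend log p_theta(x, h), an estimate of grad log p_theta(x).
            a += lr * (onehot - np.exp(log_softmax(a)))
            mu[h] += lr * (x - mu[h])
            # phi: ascend log q_phi(h | x), i.e. descend the inclusive KL.
            g = onehot - np.exp(log_softmax(W * x + b))
            W += lr * g * x
            b += lr * g
            t += 1
    return a, mu, W, b


if __name__ == "__main__":
    data_rng = np.random.default_rng(1)
    xs = np.concatenate([data_rng.normal(-4.0, 1.0, 100), data_rng.normal(4.0, 1.0, 100)])
    a, mu, W, b = jsa_fit(xs)
    print("learned means:", np.sort(mu))
```

Because both updates are expectations under the same posterior p_theta(h | x), one accepted sample per step feeds the generative and the inference update alike. On well-separated toy clusters like the demo data, the learned means should settle near the cluster centers.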

Code implementations: 1 repo
