Variational Inference in Nonconjugate Models
This work addresses a bottleneck for practitioners in machine learning and statistics who need efficient approximate posterior inference in complex, nonconjugate models, offering a more unified and accessible approach.
The paper tackles the problem of applying mean-field variational inference to nonconjugate probabilistic models, which previously required case-by-case algorithm development. It introduces two generic methods, Laplace variational inference and delta method variational inference, that yield easily derived algorithms and perform well on real-world datasets for models such as the correlated topic model and Bayesian logistic regression.
Mean-field variational methods are widely used for approximate posterior inference in many probabilistic models. In a typical application, mean-field methods approximately compute the posterior with a coordinate-ascent optimization algorithm. When the model is conditionally conjugate, the coordinate updates are easily derived and in closed form. However, many models of interest, like the correlated topic model and Bayesian logistic regression, are nonconjugate. In these models, mean-field methods cannot be directly applied and practitioners have had to develop variational algorithms on a case-by-case basis. In this paper, we develop two generic methods for nonconjugate models, Laplace variational inference and delta method variational inference. Our methods have several advantages: they allow for easily derived variational algorithms for a wide class of nonconjugate models; they extend and unify some of the existing algorithms that have been derived for specific models; and they work well on real-world datasets. We studied our methods on the correlated topic model, Bayesian logistic regression, and hierarchical Bayesian logistic regression.
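To make the nonconjugacy issue concrete, the following minimal sketch fits a Gaussian approximation to the posterior of a scalar-weight Bayesian logistic regression, where a Gaussian prior and a logistic likelihood admit no closed-form update. It uses the classical Laplace approximation (Newton's method to find the posterior mode, then the inverse curvature as the variance), which is related in spirit to, but is not, the Laplace variational inference algorithm developed in the paper. The model, the function names `sigmoid` and `laplace_logistic_1d`, and the parameter `prior_var` are assumptions introduced here for illustration only.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def laplace_logistic_1d(xs, ys, prior_var=1.0, iters=50, tol=1e-10):
    """Return (mode, var) of a Gaussian approximation to p(w | xs, ys).

    Assumed illustrative model (not taken from the paper):
        w ~ N(0, prior_var),  y_i ~ Bernoulli(sigmoid(w * x_i)).
    """
    w = 0.0
    for _ in range(iters):
        ps = [sigmoid(w * x) for x in xs]
        # Gradient of the log joint log p(ys | w) + log p(w) with respect to w.
        grad = sum((y - p) * x for x, y, p in zip(xs, ys, ps)) - w / prior_var
        # Negative second derivative; strictly positive, so the log joint is concave.
        neg_hess = sum(p * (1.0 - p) * x * x for x, p in zip(xs, ps)) + 1.0 / prior_var
        step = grad / neg_hess  # Newton step toward the posterior mode
        w += step
        if abs(step) < tol:
            break
    # Curvature at the mode gives the precision of the Gaussian approximation.
    ps = [sigmoid(w * x) for x in xs]
    neg_hess = sum(p * (1.0 - p) * x * x for x, p in zip(xs, ps)) + 1.0 / prior_var
    return w, 1.0 / neg_hess


if __name__ == "__main__":
    xs = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
    ys = [0, 0, 1, 0, 1, 1]
    mode, var = laplace_logistic_1d(xs, ys)
    print("approximate posterior: mean", mode, "variance", var)
```

The toy data above are made up. In a conditionally conjugate model the same posterior factor would be available in closed form, whereas here an inner numerical optimization is needed, which is the kind of step that a generic method for nonconjugate models has to handle.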