Refined $α$-Divergence Variational Inference via Rejection Sampling
This work addresses the challenge of refining variational approximations in machine learning, offering an incremental improvement for probabilistic modeling tasks.
The paper tackles approximate inference by combining Rényi $α$-divergence variational inference (RDVI) with rejection sampling. It shows, by proof and through experiments, that the combined method learns considerably more accurate approximations of the target distribution than RDVI alone.
We present an approximate inference method based on a synergistic combination of Rényi $α$-divergence variational inference (RDVI) and rejection sampling (RS). RDVI minimizes the Rényi $α$-divergence $D_α(p\|q)$ between the true distribution $p(x)$ and a variational approximation $q(x)$. RS draws samples from a distribution $p(x) = \tilde{p}(x)/Z_{p}$ using a proposal $q(x)$ and a constant $M$ such that $Mq(x) \geq \tilde{p}(x)$ for all $x$. Our method rests on the observation that $D_\infty(p\|q_θ) = \log M(θ) - \log Z_p$, where $M(θ) = \sup_x \tilde{p}(x)/q_θ(x)$ is the optimal RS constant for a given proposal $q_θ(x)$; the divergence and the log RS constant therefore differ only by a term that does not depend on $θ$. This enables a \emph{two-stage} hybrid inference algorithm. Stage 1 performs RDVI to learn $q_θ$ by minimizing an estimator of $D_α(p\|q)$, then uses the learned $q_θ$ to find an (approximately) optimal constant $\tilde{M}(θ)$. Stage 2 performs RS with the constant $\tilde{M}(θ)$ to refine $q_θ$ into a sample-based approximation. We prove that this two-stage method learns considerably more accurate approximations of the target distribution than RDVI alone, and we demonstrate its efficacy through several experiments on synthetic and real datasets.
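To make the two stages concrete, here is a minimal Python sketch on a hypothetical 1-D target: a two-component Gaussian mixture with $Z_p = 3$, chosen for illustration and not taken from the paper. Because $p = \tilde{p}/Z_p$, we have $D_\infty(p\|q_θ) = \log \sup_x p(x)/q_θ(x) = \log M(θ) - \log Z_p$, and the script compares both sides numerically. Several simplifications are assumptions of this sketch, not the authors' method. Stage 1 replaces RDVI's stochastic optimization with a brute-force grid search over Gaussian proposals $q_θ = \mathcal{N}(\mu, \sigma^2)$, scored by a quadrature value of $D_α(p\|q_θ)$ (up to an additive constant) with $α = 2$, which is feasible only in one dimension. $\tilde{M}(θ)$ is taken to be the largest ratio $\tilde{p}(x)/q_θ(x)$ observed over 5000 draws from $q_θ$. All function names (`log_p_tilde`, `stage1_grid_search`, `estimate_log_m`, `stage2_rejection_sampling`) are hypothetical.

```python
import math
import random

# Hypothetical 1-D target (not from the paper):
#   p~(x) = Z_P * (0.6 N(x; -1, 0.5^2) + 0.4 N(x; 1.5, 0.7^2))
# Z_P is known only because the example is synthetic; it is used in the final checks.
Z_P = 3.0
ALPHA = 2.0  # Renyi order for stage 1 (assumed)
DX = 0.02
XS = [-15.0 + DX * k for k in range(1501)]  # quadrature grid on [-15, 15]


def log_normal(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))


def log_sum_exp(vals):
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))


def log_p_tilde(x):
    return math.log(Z_P) + log_sum_exp([math.log(0.6) + log_normal(x, -1.0, 0.5),
                                        math.log(0.4) + log_normal(x, 1.5, 0.7)])


LOG_PT = [log_p_tilde(x) for x in XS]


def stage1_grid_search():
    """Stand-in for RDVI: choose q = N(mu, sigma^2) minimising D_alpha(p||q),
    computed by quadrature up to the constant alpha/(alpha-1) * log Z_P."""
    best = (math.inf, 0.0, 1.0)
    for i in range(17):
        mu = -2.0 + 0.25 * i
        for j in range(21):
            sigma = 0.5 + 0.1 * j
            # log of p~(x)^alpha * q(x)^(1-alpha) * dx at every grid point
            terms = [ALPHA * lp + (1.0 - ALPHA) * log_normal(x, mu, sigma) + math.log(DX)
                     for x, lp in zip(XS, LOG_PT)]
            obj = log_sum_exp(terms) / (ALPHA - 1.0)
            if obj < best[0]:
                best = (obj, mu, sigma)
    return best[1], best[2]


def estimate_log_m(mu, sigma, n, rng):
    """Approximate log M(theta) = log sup_x p~(x)/q(x) by the largest ratio seen in n draws from q."""
    draws = (rng.gauss(mu, sigma) for _ in range(n))
    return max(log_p_tilde(x) - log_normal(x, mu, sigma) for x in draws)


def stage2_rejection_sampling(mu, sigma, log_m, n, rng):
    """Propose from q and accept x with probability min(1, p~(x) / (M~ q(x)))."""
    accepted = []
    for _ in range(n):
        x = rng.gauss(mu, sigma)
        log_u = math.log(1.0 - rng.random())  # u lies in (0, 1], so the log is finite
        if log_u <= log_p_tilde(x) - log_m - log_normal(x, mu, sigma):
            accepted.append(x)
    return accepted


mu, sigma = stage1_grid_search()
rng = random.Random(0)
log_m = estimate_log_m(mu, sigma, 5000, rng)
samples = stage2_rejection_sampling(mu, sigma, log_m, 5000, rng)
acc_rate = len(samples) / 5000

# Grid checks: accepted samples have density r(x) proportional to min(p~(x), M~ q(x)).
log_p = [lp - math.log(Z_P) for lp in LOG_PT]
log_q = [log_normal(x, mu, sigma) for x in XS]
log_min = [min(lp, log_m + lq) for lp, lq in zip(LOG_PT, log_q)]
log_s = log_sum_exp(log_min) + math.log(DX)  # log of the normaliser S of r
d_inf_q = max(a - b for a, b in zip(log_p, log_q))
d_inf_r = max(a - (b - log_s) for a, b in zip(log_p, log_min))

print(f"stage 1: q = N({mu:.2f}, {sigma:.2f}^2); stage 2 acceptance rate = {acc_rate:.2f}")
print(f"log M~ - log Z_p = {log_m - math.log(Z_P):.3f} (grid D_inf(p||q) = {d_inf_q:.3f})")
print(f"grid D_inf(p||r) = {d_inf_r:.3f}")
```

Accepted samples follow $r(x) \propto \min(\tilde{p}(x), \tilde{M} q_θ(x))$ even when $\tilde{M}$ underestimates $M(θ)$. Writing $S = \int \min(\tilde{p}, \tilde{M} q_θ)\,dx \le \tilde{M}$, one gets $D_\infty(p\|r) = D_\infty(p\|q_θ) + \log(S/\tilde{M}) \le D_\infty(p\|q_θ)$ whenever $\tilde{M} < M(θ)$, and $r = p$ otherwise; the script's last line checks this on its grid. This is only a sanity check in one divergence on one toy problem, not the paper's proof.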