Discretely Relaxing Continuous Variables for tractable Variational Inference
This work addresses scalable and hardware-efficient variational inference for researchers and practitioners in machine learning. Its contribution is incremental: it builds on existing discrete latent variable priors and adds a novel computational approach.
The paper tackles efficient Bayesian variational inference with discrete latent variables by introducing the DIRECT method, which uses Kronecker matrix algebra to compute exact ELBO gradients and to reduce training complexity. This enables training on datasets with over two million points in seconds and fast inference with latent variables stored as 4-bit quantized integers.
We explore a new research direction in Bayesian variational inference with discrete latent variable priors where we exploit Kronecker matrix algebra for efficient and exact computations of the evidence lower bound (ELBO). The proposed "DIRECT" approach has several advantages over its predecessors: (i) it can exactly compute ELBO gradients (i.e., unbiased, zero-variance gradient estimates), eliminating the need for high-variance stochastic gradient estimators and enabling the use of quasi-Newton optimization methods; (ii) its training complexity is independent of the number of training points, permitting inference on large datasets; and (iii) its posterior samples consist of sparse and low-precision quantized integers, which permit fast inference on hardware-limited devices. In addition, our DIRECT models can exactly compute statistical moments of the parameterized predictive posterior without relying on Monte Carlo sampling. The DIRECT approach is not practical for all likelihoods; however, we identify a popular model structure which is practical, and demonstrate accurate inference using latent variables discretized as extremely low-precision 4-bit quantized integers. While the ELBO computations considered in the numerical studies require over $10^{2352}$ log-likelihood evaluations, we train on datasets with over two million points in just seconds.
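To illustrate why a factorized discrete posterior can make an exponentially large sum tractable, the following is a minimal sketch and not the authors' implementation. It assumes a linear-Gaussian-style model with a squared-error term, a mean-field posterior whose joint probability vector is a Kronecker product of per-variable factors, and toy sizes small enough to enumerate. All names (`Phi`, `y`, `grid`, `Q`, `A`, `c`) are hypothetical. The brute-force path sums over all `m**b` joint states; the factorized path uses only per-variable moments and the precomputed quantities `Phi.T @ Phi` and `Phi.T @ y`, so its cost after that one-off precomputation does not depend on the number of data points.

```python
import itertools
from functools import reduce

import numpy as np

rng = np.random.default_rng(0)

# Toy sizes (assumed for illustration): n data points, b latent weights,
# m support points per weight (m = 16 would correspond to 4-bit values).
n, b, m = 50, 3, 4
Phi = rng.normal(size=(n, b))            # hypothetical design matrix
y = rng.normal(size=n)                   # hypothetical targets
grid = np.linspace(-1.5, 1.5, m)         # shared discrete support of each weight
Q = rng.dirichlet(np.ones(m), size=b)    # Q[i, j] = q_i(w_i = grid[j]); rows sum to 1

# Brute force: enumerate all m**b joint states.
states = np.array(list(itertools.product(grid, repeat=b)))   # shape (m**b, b)
q_joint = reduce(np.kron, Q)                                 # Kronecker-structured joint, shape (m**b,)
resid = y[None, :] - states @ Phi.T                          # shape (m**b, n)
brute = q_joint @ np.sum(resid**2, axis=1)                   # E_q[ ||y - Phi w||^2 ]

# Factorized: precompute data-dependent terms once, then use per-variable moments.
A = Phi.T @ Phi                          # (b, b), computed once from the data
c = Phi.T @ y                            # (b,),  computed once from the data
yty = y @ y

mu = Q @ grid                            # E[w_i] for each weight
var = Q @ grid**2 - mu**2                # Var[w_i] for each weight
fast = yty - 2.0 * c @ mu + mu @ A @ mu + np.diag(A) @ var

print(f"joint states enumerated: {m**b}")
print(f"brute force: {brute:.10f}")
print(f"factorized : {fast:.10f}")
```

The two values agree to floating-point precision because independence across the factors gives E[w wᵀ] = μμᵀ + diag(var). The same expression is differentiable in the entries of `Q`, which is the sense in which exact gradients become available in this simplified setting; the paper's full ELBO also includes further terms not shown here.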