Scalable Computations of Wasserstein Barycenter via Input Convex Neural Networks
This work provides a scalable solution for machine learning practitioners who need to compute weighted means of probability distributions, though the contribution is incremental because it builds on existing neural network architectures.
The paper tackles the problem of computing Wasserstein barycenters for high-dimensional machine learning applications by proposing a scalable algorithm based on input convex neural networks, achieving competitive performance compared to state-of-the-art methods in experiments.
The Wasserstein barycenter is a principled way to represent the weighted mean of a given set of probability distributions, using the geometry induced by optimal transport. In this work, we present a novel scalable algorithm for approximating Wasserstein barycenters, aimed at high-dimensional applications in machine learning. Our proposed algorithm is based on the Kantorovich dual formulation of the Wasserstein-2 distance and on a recent neural network architecture, the input convex neural network, which is known to parametrize convex functions. The distinguishing features of our method are: i) it requires only samples from the marginal distributions; ii) unlike existing approaches, it represents the barycenter with a generative model and can therefore generate an unlimited number of samples from the barycenter without querying the marginal distributions; iii) in the case of a single marginal, it works similarly to a generative adversarial model. We demonstrate the efficacy of our algorithm by comparing it with state-of-the-art methods in multiple experiments.
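To make the role of the input convex neural network more concrete, the following is a minimal NumPy sketch of a forward pass whose output is convex in its input. It is an illustration under stated assumptions, not the architecture or training procedure used in the paper: the class name `ICNN`, the layer sizes, the softplus activation, and the use of absolute values to keep the hidden-to-hidden weights non-negative are all choices made here for illustration. Convexity follows because each layer applies a convex, non-decreasing activation to a non-negative combination of convex functions plus an affine function of the input.

```python
import numpy as np


def softplus(z):
    # Convex, non-decreasing activation: log(1 + exp(z)), computed stably.
    return np.logaddexp(0.0, z)


class ICNN:
    """Hypothetical minimal input convex network: x -> scalar, convex in x."""

    def __init__(self, dim, hidden=(16, 16), seed=0):
        rng = np.random.default_rng(seed)
        sizes = list(hidden) + [1]  # hidden widths followed by a scalar output
        # Direct input-to-layer weights and biases: unconstrained (affine in x).
        self.A = [rng.normal(size=(dim, h)) / np.sqrt(dim) for h in sizes]
        self.b = [np.zeros(h) for h in sizes]
        # Raw layer-to-layer weights; made non-negative in the forward pass.
        self.W_raw = [
            rng.normal(size=(sizes[i], sizes[i + 1])) / np.sqrt(sizes[i])
            for i in range(len(sizes) - 1)
        ]

    def __call__(self, x):
        x = np.atleast_2d(x)
        # First layer: convex activation of an affine map is convex in x.
        z = softplus(x @ self.A[0] + self.b[0])
        last = len(self.W_raw) - 1
        for k, W_raw in enumerate(self.W_raw):
            W = np.abs(W_raw)  # non-negative weights preserve convexity
            pre = z @ W + x @ self.A[k + 1] + self.b[k + 1]
            # The final layer is left linear; the result is still convex in x.
            z = pre if k == last else softplus(pre)
        return z[:, 0]


if __name__ == "__main__":
    f = ICNN(dim=3)
    rng = np.random.default_rng(1)
    x, y = rng.normal(size=(5, 3)), rng.normal(size=(5, 3))
    # Midpoint convexity: f((x + y) / 2) <= (f(x) + f(y)) / 2 for each pair.
    print(np.all(f(0.5 * (x + y)) <= 0.5 * (f(x) + f(y)) + 1e-9))
```

In a training setting one would typically use an automatic differentiation framework rather than plain NumPy, so that the gradient of the convex potential with respect to its input is available; that part is omitted from this sketch.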