On The Chain Rule Optimal Transport Distance
This work addresses the challenge of efficiently comparing and learning complex statistical distributions; the contribution is incremental but offers practical gains for machine learning tasks such as density estimation.
The authors tackle the problem of measuring distances between multivariate statistical distributions by introducing a new class of optimal-transport-based distances. This class generalizes existing distances (the Wasserstein distances between discrete measures and a recent metric distance between statistical mixtures) and upper-bounds jointly convex distances between mixtures. They also develop a fast, differentiable Sinkhorn-type variant and use it to learn Gaussian mixture models, significantly improving over the EM implementation of sklearn on the MNIST and Fashion MNIST datasets.
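For finite mixtures, the distance and the upper bound summarized above can be written as follows. This is a sketch in our own notation: the symbols m, m', α, β, p_i, q_j, P and D are chosen here for illustration. The discrete marginals are the weight vectors α and β, the conditionals are the components p_i and q_j, and D is the ground distance between components.

```latex
% Mixtures m = \sum_{i=1}^{k} \alpha_i p_i and m' = \sum_{j=1}^{l} \beta_j q_j
\mathrm{CROT}_D(m, m') = \min_{P \in U(\alpha, \beta)} \sum_{i=1}^{k} \sum_{j=1}^{l} P_{ij}\, D(p_i, q_j),
\qquad
U(\alpha, \beta) = \Big\{ P \in \mathbb{R}_{\ge 0}^{k \times l} : \sum_{j} P_{ij} = \alpha_i,\ \sum_{i} P_{ij} = \beta_j \Big\}.

% For any P in U(\alpha, \beta): m = \sum_{i,j} P_{ij}\, p_i and m' = \sum_{i,j} P_{ij}\, q_j,
% with \sum_{i,j} P_{ij} = 1, so joint convexity of D gives
D(m, m') = D\Big( \sum_{i,j} P_{ij}\, p_i,\ \sum_{i,j} P_{ij}\, q_j \Big) \le \sum_{i,j} P_{ij}\, D(p_i, q_j).

% Minimizing the right-hand side over P yields the upper bound
D(m, m') \le \mathrm{CROT}_D(m, m').
```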
We define a novel class of distances between statistical multivariate distributions by modeling an optimal transport problem on their marginals with respect to a ground distance defined on their conditionals. These new distances are metrics whenever the ground distance between the conditionals is a metric, generalize both the Wasserstein distances between discrete measures and a recently introduced metric distance between statistical mixtures, and provide an upper bound for jointly convex distances between statistical mixtures. By entropic regularization of the optimal transport, we obtain a fast differentiable Sinkhorn-type distance. We experimentally evaluate our new family of distances by quantifying the upper bounds of several jointly convex distances between statistical mixtures, and by proposing a novel efficient method to learn Gaussian mixture models (GMMs) by simplifying kernel density estimators with respect to our distance. In experiments, our GMM learning technique improves significantly over the EM implementation of sklearn on the MNIST and Fashion MNIST datasets.
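The entropic regularization step can be illustrated with a minimal sketch. It assumes finite mixtures of univariate Gaussians and uses the closed-form KL divergence between components as the ground distance, which is an illustrative choice (KL is jointly convex, so the bound above applies). It then runs a log-domain Sinkhorn iteration on the mixture weights. The function names, the regularization strength `reg` and the iteration count `n_iter` are hypothetical. The returned value, the transport cost of the entropic coupling without the entropy term, is one common Sinkhorn-type variant and not necessarily the exact one used in the paper.

```python
import math


def kl_gauss_1d(mu1, s1, mu2, s2):
    """Closed-form KL(N(mu1, s1^2) || N(mu2, s2^2)) between univariate Gaussians."""
    return math.log(s2 / s1) + (s1 ** 2 + (mu1 - mu2) ** 2) / (2.0 * s2 ** 2) - 0.5


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    mx = max(xs)
    return mx + math.log(sum(math.exp(x - mx) for x in xs))


def sinkhorn_coupling(a, b, cost, reg=0.1, n_iter=1000):
    """Log-domain Sinkhorn: entropic OT coupling between weight vectors a and b.

    Alternately updates dual potentials f and g so that
    P[i][j] = exp((f[i] + g[j] - cost[i][j]) / reg) matches the row sums a,
    then the column sums b.
    """
    n, m = len(a), len(b)
    f = [0.0] * n
    g = [0.0] * m
    for _ in range(n_iter):
        f = [reg * (math.log(a[i]) - logsumexp([(g[j] - cost[i][j]) / reg for j in range(m)]))
             for i in range(n)]
        g = [reg * (math.log(b[j]) - logsumexp([(f[i] - cost[i][j]) / reg for i in range(n)]))
             for j in range(m)]
    return [[math.exp((f[i] + g[j] - cost[i][j]) / reg) for j in range(m)] for i in range(n)]


def sinkhorn_crot_kl(w1, comps1, w2, comps2, reg=0.1, n_iter=1000):
    """Hypothetical Sinkhorn-type chain rule OT value between two 1-D Gaussian mixtures.

    w1, w2: mixture weights (the discrete marginals).
    comps1, comps2: lists of (mean, std) pairs (the conditionals).
    Ground distance: KL between components (an illustrative assumption).
    Returns sum_ij P_ij * KL(p_i || q_j) for the entropic coupling P.
    """
    cost = [[kl_gauss_1d(mu1, s1, mu2, s2) for (mu2, s2) in comps2] for (mu1, s1) in comps1]
    P = sinkhorn_coupling(w1, w2, cost, reg, n_iter)
    return sum(P[i][j] * cost[i][j] for i in range(len(w1)) for j in range(len(w2)))


if __name__ == "__main__":
    # Two toy mixtures (weights, [(mean, std), ...]), chosen for illustration only.
    w1, c1 = [0.5, 0.5], [(0.0, 1.0), (3.0, 1.0)]
    w2, c2 = [0.3, 0.7], [(0.5, 1.0), (4.0, 1.5)]
    print("Sinkhorn-type CROT (KL ground distance):", sinkhorn_crot_kl(w1, c1, w2, c2))
```

Because the Sinkhorn coupling is, up to convergence error, a feasible coupling of the weights, its transport cost is at least the unregularized chain rule OT value, which in turn upper-bounds KL(m || m') when KL is the ground distance. Decreasing `reg` tightens the approximation but slows convergence; the log-domain updates avoid the underflow that a plain `exp(-cost / reg)` kernel would suffer for small `reg`.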