2-Wasserstein Approximation via Restricted Convex Potentials with Application to Improved Training for GANs
This work targets efficient and interpretable approximation of optimal transport for machine learning practitioners, particularly in generative modeling. The contribution appears incremental, since it builds on existing convex-potential frameworks.
The paper approximates the 2-Wasserstein distance and the optimal transport map by restricting the optimization to parametrized classes of convex functions, such as input-convex neural networks, and applies the resulting approximation to GAN training through a modular design.
We provide a framework to approximate the 2-Wasserstein distance and the optimal transport map, amenable to efficient training as well as statistical and geometric analysis. For the quadratic cost, and considering the Kantorovich dual form of the optimal transportation problem, Brenier's theorem states that the optimal potential function is convex and that the optimal transport map is the gradient of the optimal potential function. Using this geometric structure, we restrict the optimization problem to different parametrized classes of convex functions and pay special attention to the class of input-convex neural networks. We analyze the statistical generalization and the discriminative power of the resulting approximate metric, and we prove a restricted moment-matching property for the approximate optimal map. Finally, we discuss a numerical algorithm to solve the restricted optimization problem and provide numerical experiments to illustrate and compare the proposed approach with the established regularization-based approaches. We further discuss practical implications of our proposal for a modular and interpretable GAN design, which connects the generator training with discriminator computations to allow for learning an overall composite generator.
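To make the role of the convex potential concrete, the following is a minimal sketch of a small input-convex network and the map obtained as its gradient. It is an illustration under stated assumptions, not the paper's implementation: the two-layer architecture, the softplus activation, the random (untrained) weights, and the names `init_params`, `potential`, and `transport_map` are all hypothetical. Convexity in the input follows from keeping the weights that act on hidden activations nonnegative and using a convex, nondecreasing activation.

```python
import math
import random

def softplus(t):
    # Numerically stable softplus; convex and nondecreasing in t.
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))

def sigmoid(t):
    # Derivative of softplus.
    return 0.5 * (1.0 + math.tanh(0.5 * t))

def matvec(M, v):
    # Matrix-vector product M v.
    return [sum(m * vi for m, vi in zip(row, v)) for row in M]

def mat_t_vec(M, v):
    # Transposed product M^T v.
    return [sum(M[i][j] * v[i] for i in range(len(M))) for j in range(len(M[0]))]

def init_params(dim, hidden, seed=0):
    # Hypothetical random initialisation (no training is performed here).
    rng = random.Random(seed)
    g = lambda: rng.gauss(0.0, 1.0)
    pos = lambda: abs(rng.gauss(0.0, 1.0))  # nonnegative weights keep convexity
    return {
        "A0": [[g() for _ in range(dim)] for _ in range(hidden)],
        "b0": [g() for _ in range(hidden)],
        "W1": [[pos() for _ in range(hidden)] for _ in range(hidden)],
        "A1": [[g() for _ in range(dim)] for _ in range(hidden)],
        "b1": [g() for _ in range(hidden)],
        "w2": [pos() for _ in range(hidden)],
        "a2": [g() for _ in range(dim)],
    }

def _pre_activations(p, x):
    pre1 = [u + b for u, b in zip(matvec(p["A0"], x), p["b0"])]
    z1 = [softplus(t) for t in pre1]
    pre2 = [u + v + b for u, v, b in
            zip(matvec(p["W1"], z1), matvec(p["A1"], x), p["b1"])]
    return pre1, pre2

def potential(p, x):
    # f(x) = w2 . softplus(W1 softplus(A0 x + b0) + A1 x + b1) + a2 . x
    _, pre2 = _pre_activations(p, x)
    hidden_term = sum(w * softplus(t) for w, t in zip(p["w2"], pre2))
    linear_term = sum(a * xi for a, xi in zip(p["a2"], x))
    return hidden_term + linear_term

def transport_map(p, x):
    # Gradient of the potential with respect to x, by the chain rule.
    pre1, pre2 = _pre_activations(p, x)
    g2 = [w * sigmoid(t) for w, t in zip(p["w2"], pre2)]
    g1 = [u * sigmoid(t) for u, t in zip(mat_t_vec(p["W1"], g2), pre1)]
    return [u + v + a for u, v, a in
            zip(mat_t_vec(p["A0"], g1), mat_t_vec(p["A1"], g2), p["a2"])]
```

For example, `transport_map(init_params(2, 4), [0.3, -1.0])` returns a point in the plane. With trained parameters, this gradient would serve as the approximate transport map; with the random parameters above it only demonstrates the structure, namely that the map is the gradient of a function that is convex by construction.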