Constrained Sliced Wasserstein Embedding
This work targets a computational bottleneck for researchers and practitioners who use Sliced Wasserstein distances in machine learning, offering an incremental improvement in efficiency and performance.
The paper tackles the challenge of identifying informative slicing directions for Sliced Wasserstein distances, which are used to compare high-dimensional probability measures. It introduces a constrained learning approach that optimizes these directions and demonstrates its efficacy on foundation models for images, point clouds, and protein sequences.
Sliced Wasserstein (SW) distances offer an efficient method for comparing high-dimensional probability measures by projecting them onto multiple one-dimensional directions (slices) and comparing the resulting 1D probability distributions. However, identifying informative slicing directions has proven challenging, often necessitating a large number of slices to achieve desirable performance and thereby increasing computational complexity. We introduce a constrained learning approach to optimize the slicing directions for SW distances. Specifically, we constrain the 1D transport plans to approximate the optimal plan in the original space, ensuring meaningful slicing directions. By leveraging continuous relaxations of these transport plans, we enable a gradient-based primal-dual approach to train the slicer parameters alongside the remaining model parameters. We demonstrate how this constrained slicing approach can be applied to pool high-dimensional embeddings into fixed-length permutation-invariant representations. Numerical results on foundation models trained on images, point clouds, and protein sequences showcase the efficacy of the proposed constrained learning approach in learning more informative slicing directions. Our implementation is available at https://github.com/Stranja572/constrainedswe.
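The SW distance described above can be estimated by Monte Carlo over slices. The minimal NumPy sketch below assumes two uniform point clouds with the same number of samples, in which case each 1D optimal transport problem reduces to sorting the projections. The function names and the random-direction sampling are illustrative assumptions, not the authors' code. With random directions, the estimate only becomes reliable as the number of slices `L` grows, which is the cost that learning informative directions aims to reduce.

```python
import numpy as np


def random_directions(num_slices, dim, seed=0):
    """Slicing directions drawn uniformly on the unit sphere S^{d-1}."""
    rng = np.random.default_rng(seed)
    theta = rng.standard_normal((num_slices, dim))
    return theta / np.linalg.norm(theta, axis=1, keepdims=True)


def sliced_wasserstein(x, y, theta, p=2):
    """Monte Carlo estimate of SW_p between uniform clouds x, y of shape (n, d).

    Each row of theta (shape (L, d)) defines one slice. In 1D the optimal plan
    between two equal-size uniform measures pairs the sorted samples, so every
    slice costs O(n log n).
    """
    x_proj = np.sort(x @ theta.T, axis=0)  # (n, L) sorted projections
    y_proj = np.sort(y @ theta.T, axis=0)
    # Mean over samples and slices = mean over slices of the per-slice W_p^p.
    return float(np.mean(np.abs(x_proj - y_proj) ** p) ** (1.0 / p))


rng = np.random.default_rng(1)
x = rng.standard_normal((128, 16))
shift = np.full(16, 0.5)
theta = random_directions(num_slices=64, dim=16)
print(sliced_wasserstein(x, x + shift, theta))
```

For a pure translation, every slice's sorted projections differ by the constant `theta_l @ shift`, so the squared estimate equals the mean of `(theta_l @ shift) ** 2` over the sampled slices; this makes a convenient sanity check.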
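One way to make the constraint on 1D transport plans concrete is sketched below. It rests on several assumptions that are not taken from the paper: equal-size uniform point clouds, an exact OT plan from SciPy's `linear_sum_assignment`, NeuralSort (Grover et al., 2019) as the continuous relaxation of each slice's sorting plan, a squared Frobenius gap as a hypothetical constraint function, and a single projected dual-ascent step on the Lagrange multiplier. The paper's relaxation, constraint, and primal-dual schedule may differ, and the primal gradient step on the slicer parameters is omitted here.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def neural_sort(s, tau):
    """Row-stochastic relaxation of the permutation sorting s in descending order.

    NeuralSort: row i is softmax(((n + 1 - 2i) * s - A_s @ 1) / tau), where
    A_s[j, k] = |s_j - s_k|. As tau -> 0 it approaches a hard permutation matrix.
    """
    n = s.shape[0]
    row_sums = np.abs(s[:, None] - s[None, :]).sum(axis=1)  # A_s @ 1
    scale = n + 1 - 2 * np.arange(1, n + 1)                 # (n + 1 - 2i), i = 1..n
    logits = (scale[:, None] * s[None, :] - row_sums[None, :]) / tau
    logits -= logits.max(axis=1, keepdims=True)             # numerical stability
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)


def relaxed_slice_plan(x, y, direction, tau=0.01):
    """Soft version of the 1D sorting plan between uniform clouds of equal size n."""
    n = x.shape[0]
    px = neural_sort(x @ direction, tau)
    py = neural_sort(y @ direction, tau)
    # Hard limit: entry (i, j) is 1/n iff x_i and y_j share a rank along `direction`.
    return px.T @ py / n


def ot_plan(x, y):
    """Exact OT plan (squared Euclidean cost) between equal-size uniform clouds."""
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    rows, cols = linear_sum_assignment(cost)
    plan = np.zeros_like(cost)
    plan[rows, cols] = 1.0 / x.shape[0]
    return plan


def plan_violation(x, y, direction, tau=0.01):
    """Hypothetical constraint function: squared Frobenius gap to the OT plan."""
    gap = relaxed_slice_plan(x, y, direction, tau) - ot_plan(x, y)
    return float(np.sum(gap ** 2))


def dual_step(lam, violation, eps, lr):
    """Projected dual ascent on the multiplier of the constraint violation <= eps."""
    return max(0.0, lam + lr * (violation - eps))


x = np.array([[0.0, 0.0], [1.0, 3.0], [2.0, 1.0], [4.0, 2.0]])
y = x @ np.diag([2.0, 0.5])  # SPD linear map, so the OT plan is the identity matching
aligned = np.array([1.0, 0.0])
skewed = np.array([-1.0, 1.0]) / np.sqrt(2.0)
print(plan_violation(x, y, aligned))  # ~0.0
print(plan_violation(x, y, skewed))   # ~0.25
```

Here `y` is the image of `x` under a symmetric positive-definite linear map, so the optimal matching pairs each point with its own image. A slice along a principal axis preserves that matching, while the skewed slice swaps two points and incurs a violation of 0.25. In a full primal-dual loop, the multiplier returned by `dual_step` would weight this violation in the Lagrangian used to update the slicer.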
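The pooling step can be illustrated with a simplified quantile-based embedding. Each slice maps a set of `n` embeddings to a 1D distribution, and sampling its quantile function at `M` fixed levels yields a vector whose length is independent of `n` and whose value does not depend on the order of the set. This is a hypothetical sketch, not the paper's embedding layer: `swe_pool`, the midpoint quantile grid, and the random (rather than learned, constrained) slicer `theta` are all assumptions.

```python
import numpy as np


def swe_pool(x, theta, num_quantiles=16):
    """Pool a set x of shape (n, d) into a fixed-length vector of size L * M.

    Each of the L slices (rows of theta) turns the set into a 1D empirical
    distribution, summarized by its quantile function on a fixed grid of M
    levels. The result depends neither on n nor on the row order of x.
    """
    proj = x @ theta.T                                          # (n, L)
    levels = (np.arange(num_quantiles) + 0.5) / num_quantiles   # midpoints in (0, 1)
    quantiles = np.quantile(proj, levels, axis=0)               # (M, L)
    # Scale so the squared Euclidean distance between two embeddings
    # is a grid approximation of the squared SW_2 distance on these slices.
    return quantiles.T.reshape(-1) / np.sqrt(quantiles.size)


rng = np.random.default_rng(0)
theta = rng.standard_normal((8, 32))
theta /= np.linalg.norm(theta, axis=1, keepdims=True)
small_set = rng.standard_normal((50, 32))
large_set = rng.standard_normal((200, 32))
print(swe_pool(small_set, theta).shape, swe_pool(large_set, theta).shape)  # (128,) (128,)
```

Because of the `1 / sqrt(L * M)` scaling, the squared Euclidean distance between two pooled vectors approximates the average squared 1D Wasserstein distance across the slices, which is what allows a fixed-length vector to stand in for a variable-size set.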