Yikun Bai

LG
h-index: 31
21 papers
189 citations
Novelty: 52%
AI Score: 56

21 Papers

LG · Aug 24, 2022
Wasserstein Task Embedding for Measuring Task Similarities

Xinran Liu, Yikun Bai, Yuzhe Lu et al.

Measuring similarities between different tasks is critical in a broad spectrum of machine learning problems, including transfer, multi-task, continual, and meta-learning. Most current approaches to measuring task similarities are architecture-dependent: 1) relying on pre-trained models, or 2) training networks on tasks and using forward transfer as a proxy for task similarity. In this paper, we leverage the optimal transport theory and define a novel task embedding for supervised classification that is model-agnostic, training-free, and capable of handling (partially) disjoint label sets. In short, given a dataset with ground-truth labels, we perform a label embedding through multi-dimensional scaling and concatenate dataset samples with their corresponding label embeddings. Then, we define the distance between two datasets as the 2-Wasserstein distance between their updated samples. Lastly, we leverage the 2-Wasserstein embedding framework to embed tasks into a vector space in which the Euclidean distance between the embedded points approximates the proposed 2-Wasserstein distance between tasks. We show that the proposed embedding leads to a significantly faster comparison of tasks compared to related approaches like the Optimal Transport Dataset Distance (OTDD). Furthermore, we demonstrate the effectiveness of our proposed embedding through various numerical experiments and show statistically significant correlations between our proposed distance and the forward and backward transfer between tasks.

LG · Dec 15, 2022
Sliced Optimal Partial Transport

Yikun Bai, Bernhard Schmitzer, Mathew Thorpe et al.

Optimal transport (OT) has become exceedingly popular in machine learning, data science, and computer vision. The core assumption in the OT problem is the equal total amount of mass in source and target measures, which limits its application. Optimal Partial Transport (OPT) is a recently proposed solution to this limitation. Similar to the OT problem, the computation of OPT relies on solving a linear programming problem (often in high dimensions), which can become computationally prohibitive. In this paper, we propose an efficient algorithm for calculating the OPT problem between two non-negative measures in one dimension. Next, following the idea of sliced OT distances, we utilize slicing to define the sliced OPT distance. Finally, we demonstrate the computational and accuracy benefits of the sliced OPT-based method in various numerical experiments. In particular, we show an application of our proposed Sliced-OPT in noisy point cloud registration.

LG · Feb 7, 2023
Linear Optimal Partial Transport Embedding

Yikun Bai, Ivan Medri, Rocio Diaz Martin et al.

Optimal transport (OT) has gained popularity due to its various applications in fields such as machine learning, statistics, and signal processing. However, the balanced mass requirement limits its performance in practical problems. To address these limitations, variants of the OT problem, including unbalanced OT, Optimal partial transport (OPT), and Hellinger Kantorovich (HK), have been proposed. In this paper, we propose the Linear optimal partial transport (LOPT) embedding, which extends the (local) linearization technique on OT and HK to the OPT problem. The proposed embedding allows for faster computation of OPT distance between pairs of positive measures. Besides our theoretical contributions, we demonstrate the LOPT embedding technique in point-cloud interpolation and PCA analysis.

CV · Sep 27, 2023
Partial Transport for Point-Cloud Registration

Yikun Bai, Huy Tran, Steven B. Damelin et al.

Point cloud registration plays a crucial role in various fields, including robotics, computer graphics, and medical imaging. This process involves determining spatial relationships between different sets of points, typically within a 3D space. In real-world scenarios, complexities arise from non-rigid movements and partial visibility, such as occlusions or sensor noise, making non-rigid registration a challenging problem. Classic non-rigid registration methods are often computationally demanding, suffer from unstable performance, and, importantly, have limited theoretical guarantees. The optimal transport problem and its unbalanced variations (e.g., the optimal partial transport problem) have emerged as powerful tools for point-cloud registration, establishing a strong benchmark in this field. These methods view point clouds as empirical measures and provide a mathematically rigorous way to quantify the "correspondence" between (the transformed) source and target points. In this paper, we approach the point-cloud registration problem through the lens of optimal transport theory and first propose a comprehensive set of non-rigid registration methods based on the optimal partial transportation problem. Subsequently, leveraging the emerging work on efficient solutions to the one-dimensional optimal partial transport problem, we extend our proposed algorithms via slicing to gain significant computational efficiency, resulting in fast and robust non-rigid registration algorithms. We demonstrate the effectiveness of our proposed methods and compare them against baselines on various 3D and 2D non-rigid registration problems where the source and target point clouds are corrupted by random noise.

LG · Jul 25, 2023
PT$\mathrm{L}^{p}$: Partial Transport $\mathrm{L}^{p}$ Distances

Xinran Liu, Yikun Bai, Huy Tran et al.

Optimal transport and its related problems, including optimal partial transport, have proven to be valuable tools in machine learning for computing meaningful distances between probability or positive measures. This success has led to a growing interest in defining transport-based distances that allow for comparing signed measures and, more generally, multi-channeled signals. Transport $\mathrm{L}^{p}$ distances are notable extensions of the optimal transport framework to signed and possibly multi-channeled signals. In this paper, we introduce partial transport $\mathrm{L}^{p}$ distances as a new family of metrics for comparing generic signals, benefiting from the robustness of partial transport distances. We provide theoretical background such as the existence of optimal plans and the behavior of the distance in various limits. Furthermore, we introduce the sliced variation of these distances, which allows for rapid comparison of generic signals. Finally, we demonstrate the application of the proposed distances in signal class separability and nearest neighbor classification.

LG · Oct 9, 2023
LCOT: Linear circular optimal transport

Rocio Diaz Martin, Ivan Medri, Yikun Bai et al.

The optimal transport problem for measures supported on non-Euclidean spaces has recently gained ample interest in diverse applications involving representation learning. In this paper, we focus on circular probability measures, i.e., probability measures supported on the unit circle, and introduce a new computationally efficient metric for these measures, denoted as Linear Circular Optimal Transport (LCOT). The proposed metric comes with an explicit linear embedding that allows one to apply Machine Learning (ML) algorithms to the embedded measures and seamlessly modify the underlying metric for the ML algorithm to LCOT. We show that the proposed metric is rooted in the Circular Optimal Transport (COT) and can be considered the linearization of the COT metric with respect to a fixed reference measure. We provide a theoretical analysis of the proposed metric and derive the computational complexities for pairwise comparison of circular probability measures. Lastly, through a set of numerical experiments, we demonstrate the benefits of LCOT in learning representations of circular measures.

LG · Dec 8, 2025
LUNA: Linear Universal Neural Attention with Generalization Guarantees

Ashkan Shahbazi, Ping He, Ali Abbasi et al.

Scaling attention faces a critical bottleneck: the $\mathcal{O}(n^2)$ quadratic computational cost of softmax attention, which limits its application in long-sequence domains. While linear attention mechanisms reduce this cost to $\mathcal{O}(n)$, they typically rely on fixed random feature maps, such as random Fourier features or hand-crafted functions. This reliance on static, data-agnostic kernels creates a fundamental trade-off, forcing practitioners to sacrifice significant model accuracy for computational efficiency. We introduce LUNA, a kernelized linear attention mechanism that eliminates this trade-off, retaining linear cost while matching and surpassing the accuracy of quadratic attention. LUNA is built on the key insight that the kernel feature map itself should be learned rather than fixed a priori. By parameterizing the kernel, LUNA learns a feature basis tailored to the specific data and task, overcoming the expressive limitations of fixed-feature methods. LUNA implements this with a learnable feature map that induces a positive-definite kernel and admits a streaming form, yielding linear time and memory scaling in the sequence length. Empirical evaluations validate our approach across diverse settings. On the Long Range Arena (LRA), LUNA achieves state-of-the-art average accuracy among efficient Transformers under compute parity, using the same parameter count, training steps, and approximate FLOPs. LUNA also excels at post-hoc conversion: replacing softmax in fine-tuned BERT and ViT-B/16 checkpoints and briefly fine-tuning recovers most of the original performance, substantially outperforming fixed linearizations.
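
The linear-cost claim for kernelized attention rests on reordering a matrix product, and a small sketch can make that concrete. The code below is an illustration only, not LUNA's method: it uses a fixed, hand-crafted feature map (`elu(x) + 1`), whereas the abstract describes a learned one. The function names `phi`, `quadratic_attention`, and `linear_attention` are hypothetical, and the attention is non-causal.

```python
import math

def phi(x):
    """elu(x) + 1: a fixed, positive, hand-crafted feature map (an assumption,
    chosen for illustration; it is not the learned map from the abstract)."""
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]

def quadratic_attention(Q, K, V):
    """Kernel attention computed the O(n^2) way: every query scores every key."""
    out = []
    for q in Q:
        fq = phi(q)
        weights = [sum(a * b for a, b in zip(fq, phi(k))) for k in K]
        z = sum(weights)  # row normalizer
        out.append([sum(w * v[d] for w, v in zip(weights, V)) / z
                    for d in range(len(V[0]))])
    return out

def linear_attention(Q, K, V):
    """Same output, but keys and values are summarized once, then reused.

    S = sum_j phi(k_j) v_j^T and z = sum_j phi(k_j) are built in one pass,
    so the cost grows linearly in the number of tokens.
    """
    r, dv = len(phi(K[0])), len(V[0])
    S = [[0.0] * dv for _ in range(r)]
    z = [0.0] * r
    for k, v in zip(K, V):
        fk = phi(k)
        for i in range(r):
            z[i] += fk[i]
            for d in range(dv):
                S[i][d] += fk[i] * v[d]
    out = []
    for q in Q:
        fq = phi(q)
        denom = sum(a * b for a, b in zip(fq, z))
        out.append([sum(fq[i] * S[i][d] for i in range(r)) / denom
                    for d in range(dv)])
    return out
```

Both functions return the same values up to floating-point rounding; only the order of the multiplications differs, which is where the saving comes from.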

LG · Feb 4, 2024 · Code
Stereographic Spherical Sliced Wasserstein Distances

Huy Tran, Yikun Bai, Abihith Kothapalli et al.

Comparing spherical probability distributions is of great interest in various fields, including geology, medical domains, computer vision, and deep representation learning. The utility of optimal transport-based distances, such as the Wasserstein distance, for comparing probability measures has spurred active research in developing computationally efficient variations of these distances for spherical probability measures. This paper introduces a high-speed and highly parallelizable distance for comparing spherical measures using the stereographic projection and the generalized Radon transform, which we refer to as the Stereographic Spherical Sliced Wasserstein (S3W) distance. We carefully address the distance distortion caused by the stereographic projection and provide an extensive theoretical analysis of our proposed metric and its rotationally invariant variation. Finally, we evaluate the performance of the proposed metrics and compare them with recent baselines in terms of both speed and accuracy through a wide range of numerical studies, including gradient flows and self-supervised learning. Our code is available at https://github.com/mint-vu/s3wd.

LG · Feb 6, 2024 · Code
Partial Gromov-Wasserstein Metric

Yikun Bai, Rocio Diaz Martin, Abihith Kothapalli et al.

The Gromov-Wasserstein (GW) distance has gained increasing interest in the machine learning community in recent years, as it allows for the comparison of measures in different metric spaces. To overcome the limitations imposed by the equal mass requirements of the classical GW problem, researchers have begun exploring its application in unbalanced settings. However, Unbalanced GW (UGW) can only be regarded as a discrepancy rather than a rigorous metric/distance between two metric measure spaces (mm-spaces). In this paper, we propose a particular case of the UGW problem, termed Partial Gromov-Wasserstein (PGW). We establish that PGW is a well-defined metric between mm-spaces and discuss its theoretical properties, including the existence of a minimizer for the PGW problem and the relationship between PGW and GW, among others. We then propose two variants of the Frank-Wolfe algorithm for solving the PGW problem and show that they are mathematically and computationally equivalent. Moreover, based on our PGW metric, we introduce the analogous concept of barycenters for mm-spaces. Finally, we validate the effectiveness of our PGW metric and related solvers in applications such as shape matching, shape retrieval, and shape interpolation, comparing them against existing baselines. Our code is available at https://github.com/mint-vu/PGW_Metric.

LG · Oct 22, 2024 · Code
Linear Partial Gromov-Wasserstein Embedding

Yikun Bai, Abihith Kothapalli, Hengrong Du et al.

The Gromov-Wasserstein (GW) problem, a variant of the classical optimal transport (OT) problem, has attracted growing interest in the machine learning and data science communities due to its ability to quantify similarity between measures in different metric spaces. However, like the classical OT problem, GW imposes an equal mass constraint between measures, which restricts its application in many machine learning tasks. To address this limitation, the partial Gromov-Wasserstein (PGW) problem has been introduced. It relaxes the equal mass constraint, allowing the comparison of general positive Radon measures. Despite this, both GW and PGW face significant computational challenges due to their non-convex nature. To overcome these challenges, we propose the linear partial Gromov-Wasserstein (LPGW) embedding, a linearized embedding technique for the PGW problem. For $K$ different metric measure spaces, the pairwise computation of the PGW distance requires solving the PGW problem ${O}(K^2)$ times. In contrast, the proposed linearization technique reduces this to ${O}(K)$ times. Similar to the linearization technique for the classical OT problem, we prove that LPGW defines a valid metric for metric measure spaces. Finally, we demonstrate the effectiveness of LPGW in practical applications such as shape retrieval and learning with transport-based embeddings, showing that LPGW preserves the advantages of PGW in partial matching while significantly enhancing computational efficiency. The code is available at https://github.com/mint-vu/Linearized_Partial_Gromov_Wasserstein.

LG · Feb 13
Generalized Discrete Diffusion with Self-Correction

Linxuan Wang, Ziyi Wang, Yikun Bai et al.

Self-correction is an effective technique for maintaining parallel sampling in discrete diffusion models with minimal performance degradation. Prior work has explored self-correction at inference time or during post-training; however, such approaches often suffer from limited generalization and may impair reasoning performance. GIDD pioneers pretraining-based self-correction via a multi-step BERT-style uniform-absorbing objective. However, GIDD relies on a continuous interpolation-based pipeline with opaque interactions between uniform transitions and absorbing masks, which complicates hyperparameter tuning and hinders practical performance. In this work, we propose a Self-Correcting Discrete Diffusion (SCDD) model to reformulate pretrained self-correction with explicit state transitions and learn directly in discrete time. Our framework also simplifies the training noise schedule, eliminates a redundant remasking step, and relies exclusively on uniform transitions to learn self-correction. Experiments at the GPT-2 scale demonstrate that our method enables more efficient parallel decoding while preserving generation quality.

OC · Jul 9, 2024
Sinkhorn algorithms and linear programming solvers for optimal partial transport problems

Yikun Bai

In this note, we generalize the classical optimal partial transport (OPT) problem by modifying the mass destruction/creation term to function-based terms, introducing what we term "generalized optimal partial transport" problems. We then discuss the dual formulation of these problems and the associated Sinkhorn solver. Finally, we explore how these new OPT problems relate to classical optimal transport (OT) problems and introduce a linear programming solver tailored for these generalized scenarios.

LG · Mar 12
Sinkhorn-Drifting Generative Models

Ping He, Om Khangaonkar, Hamed Pirsiavash et al.

We establish a theoretical link between the recently proposed "drifting" generative dynamics and gradient flows induced by the Sinkhorn divergence. In a particle discretization, the drift field admits a cross-minus-self decomposition: an attractive term toward the target distribution and a repulsive/self-correction term toward the current model, both expressed via one-sided normalized Gibbs kernels. We show that Sinkhorn divergence yields an analogous cross-minus-self structure, but with each term defined by entropic optimal-transport couplings obtained through two-sided Sinkhorn scaling (i.e., enforcing both marginals). This provides a precise sense in which drifting acts as a surrogate for a Sinkhorn-divergence gradient flow, interpolating between one-sided normalization and full two-sided Sinkhorn scaling. Crucially, this connection resolves an identifiability gap in prior drifting formulations: leveraging the definiteness of the Sinkhorn divergence, we show that zero drift (equilibrium of the dynamics) implies that the model and target measures match. Experiments show that Sinkhorn drifting reduces sensitivity to kernel temperature and improves one-step generative quality, trading off additional training time for a more stable optimization, without altering the inference procedure used by drift methods. These theoretical gains translate to strong low-temperature improvements in practice: on FFHQ-ALAE at the lowest temperature setting we evaluate, Sinkhorn drifting reduces mean FID from 187.7 to 37.1 and mean latent EMD from 453.3 to 144.4, while on MNIST it preserves full class coverage across the temperature sweep. Project page: https://mint-vu.github.io/SinkhornDrifting/

LG · Oct 16, 2024
Expected Sliced Transport Plans

Xinran Liu, Rocío Díaz Martín, Yikun Bai et al.

The optimal transport (OT) problem has gained significant traction in modern machine learning for its ability to: (1) provide versatile metrics, such as Wasserstein distances and their variants, and (2) determine optimal couplings between probability measures. To reduce the computational complexity of OT solvers, methods like entropic regularization and sliced optimal transport have been proposed. The sliced OT framework improves efficiency by comparing one-dimensional projections (slices) of high-dimensional distributions. However, despite their computational efficiency, sliced-Wasserstein approaches lack a transportation plan between the input measures, limiting their use in scenarios requiring explicit coupling. In this paper, we address two key questions: Can a transportation plan be constructed between two probability measures using the sliced transport framework? If so, can this plan be used to define a metric between the measures? We propose a "lifting" operation to extend one-dimensional optimal transport plans back to the original space of the measures. By computing the expectation of these lifted plans, we derive a new transportation plan, termed expected sliced transport (EST) plans. We prove that using the EST plan to weight the sum of the individual Euclidean costs for moving from one point to another results in a valid metric between the input discrete probability measures. We demonstrate the connection between our approach and the recently proposed min-SWGG, along with illustrative numerical examples that support our theoretical findings.

LG · Feb 14, 2025
Fused Partial Gromov-Wasserstein for Structured Objects

Yikun Bai, Shuang Wang, Huy Tran et al.

Structured data, such as graphs, is vital in machine learning due to its capacity to capture complex relationships and interactions. In recent years, the Fused Gromov-Wasserstein (FGW) distance has attracted growing interest because it enables the comparison of structured data by jointly accounting for feature similarity and geometric structure. However, as a variant of optimal transport (OT), classical FGW assumes an equal mass constraint on the compared data. In this work, we relax this mass constraint and propose the Fused Partial Gromov-Wasserstein (FPGW) framework, which extends FGW to accommodate unbalanced data. Theoretically, we establish the relationship between FPGW and FGW and prove the metric properties of FPGW. Numerically, we introduce Frank-Wolfe solvers and Sinkhorn solvers for the proposed FPGW framework. Finally, we evaluate the FPGW distance through graph matching, graph classification and graph clustering experiments, demonstrating its robust performance.

LG · Nov 9, 2024
Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data

Xinran Liu, Yikun Bai, Rocío Díaz Martín et al.

Efficient comparison of spherical probability distributions has become important in fields such as computer vision, geosciences, and medicine. Sliced optimal transport distances, such as spherical and stereographic spherical sliced Wasserstein distances, have recently been developed to address this need. These methods reduce the computational burden of optimal transport by slicing hyperspheres into one-dimensional projections, i.e., lines or circles. Concurrently, linear optimal transport has been proposed to embed distributions into $L^2$ spaces, where the $L^2$ distance approximates the optimal transport distance, thereby simplifying comparisons across multiple distributions. In this work, we introduce the Linear Spherical Sliced Optimal Transport (LSSOT) framework, which utilizes slicing to embed spherical distributions into $L^2$ spaces while preserving their intrinsic geometry, offering a computationally efficient metric for spherical probability measures. We establish the metricity of LSSOT and demonstrate its superior computational efficiency in applications such as cortical surface registration, 3D point cloud interpolation via gradient flow, and shape embedding. Our results demonstrate the significant computational benefits and high accuracy of LSSOT in these applications.

LG · Sep 27, 2025
LOTFormer: Doubly-Stochastic Linear Attention via Low-Rank Optimal Transport

Ashkan Shahbazi, Chayne Thrash, Yikun Bai et al.

Transformers have proven highly effective across a wide range of modalities. However, the quadratic complexity of the standard softmax attention mechanism poses a fundamental barrier to scaling them to long context windows. A large body of work addresses this with linear attention, which reformulates attention as a kernel function and approximates it with finite feature maps to achieve linear-time computation. Orthogonal to computational scaling, most attention mechanisms -- both quadratic and linear -- produce row-normalized maps that can over-focus on a few tokens, degrading robustness and information flow. Enforcing doubly-stochastic attention alleviates this by balancing token participation across rows and columns, but existing doubly-stochastic attention mechanisms typically introduce substantial overhead, undermining scalability. We propose LOTFormer, a principled attention mechanism that is simultaneously linear-time and doubly-stochastic. Our approach exploits the connection between attention maps and transportation plans between query and key measures. The central idea is to constrain the transport plan to be low-rank by conditioning it on a learnable pivot measure with small support. Concretely, we solve two entropic optimal transport problems (queries $\to$ pivot and pivot $\to$ keys) and compose them into a conditional (glued) coupling. This yields an attention matrix that is provably doubly-stochastic, has rank at most $r \ll n$, and applies to values in $O(nr)$ time without forming the full $n \times n$ map. The pivot locations and masses are learned end-to-end. Empirically, LOTFormer achieves state-of-the-art results on the Long Range Arena benchmark, surpassing prior linear and transport-based attention methods in both accuracy and efficiency.

LG · Nov 16, 2024
Understanding Learning with Sliced-Wasserstein Requires Rethinking Informative Slices

Huy Tran, Yikun Bai, Ashkan Shahbazi et al.

The practical applications of Wasserstein distances (WDs) are constrained by their sample and computational complexities. Sliced-Wasserstein distances (SWDs) provide a workaround by projecting distributions onto one-dimensional subspaces, leveraging the more efficient, closed-form WDs for one-dimensional distributions. However, in high dimensions, most random projections become uninformative due to the concentration of measure phenomenon. Although several SWD variants have been proposed to focus on informative slices, they often introduce additional complexity and numerical instability, and compromise desirable theoretical (metric) properties of SWD. Amidst the growing literature that focuses on directly modifying the slicing distribution, which often faces challenges, we revisit the classical Sliced-Wasserstein and propose instead to rescale the 1D Wasserstein to make all slices equally informative. Importantly, we show that with an appropriate data assumption and notion of slice informativeness, rescaling for all individual slices simplifies to a single global scaling factor on the SWD. This, in turn, translates to the standard learning rate search for gradient-based learning in common machine learning workflows. We perform extensive experiments across various machine learning tasks showing that the classical SWD, when properly configured, can often match or surpass the performance of more complex variants. We then answer the following question: "Is Sliced-Wasserstein all you need for common learning tasks?"
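
The classical sliced-Wasserstein construction that this abstract revisits can be sketched in a few lines. The code below is a minimal Monte Carlo illustration under simplifying assumptions (two point clouds of equal size with uniform weights, directions drawn uniformly on the sphere). It does not include the paper's rescaling, and the function names are hypothetical.

```python
import math
import random

def wasserstein_1d(xs, ys, p=2):
    """p-th power of the p-Wasserstein distance between two equal-size 1D samples.

    In one dimension the optimal coupling matches sorted values, which is the
    closed form that slicing relies on.
    """
    assert len(xs) == len(ys)
    return sum(abs(a - b) ** p for a, b in zip(sorted(xs), sorted(ys))) / len(xs)

def random_direction(dim, rng):
    """Draw a unit vector uniformly on the sphere by normalizing a Gaussian vector."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(c * c for c in v))
    return [c / norm for c in v]

def sliced_wasserstein(X, Y, n_slices=200, p=2, seed=0):
    """Monte Carlo estimate of the sliced p-Wasserstein distance.

    Each slice projects both point clouds onto a random direction and solves
    the resulting 1D problem; the slice costs are averaged before the p-th root.
    """
    rng = random.Random(seed)
    dim = len(X[0])
    total = 0.0
    for _ in range(n_slices):
        theta = random_direction(dim, rng)
        px = [sum(c * t for c, t in zip(x, theta)) for x in X]
        py = [sum(c * t for c, t in zip(y, theta)) for y in Y]
        total += wasserstein_1d(px, py, p)
    return (total / n_slices) ** (1.0 / p)
```

In this sketch every slice carries equal weight, so multiplying the result by a constant has the same effect on a gradient-based loss as changing the learning rate, which is the kind of global scaling the abstract refers to.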

LG · Sep 26, 2025
Transport Based Mean Flows for Generative Modeling

Elaheh Akbari, Ping He, Ahmadreza Moradipari et al.

Flow-matching generative models have emerged as a powerful paradigm for continuous data generation, achieving state-of-the-art results across domains such as images, 3D shapes, and point clouds. Despite their success, these models suffer from slow inference due to the requirement of numerous sequential sampling steps. Recent work has sought to accelerate inference by reducing the number of sampling steps. In particular, Mean Flows offer a one-step generation approach that delivers substantial speedups while retaining strong generative performance. Yet, in many continuous domains, Mean Flows fail to faithfully approximate the behavior of the original multi-step flow-matching process. In this work, we address this limitation by incorporating optimal transport-based sampling strategies into the Mean Flow framework, enabling one-step generators that better preserve the fidelity and diversity of the original multi-step flow process. Experiments on controlled low-dimensional settings and on high-dimensional tasks such as image generation, image-to-image translation, and point cloud generation demonstrate that our approach achieves superior inference accuracy in one-step generative modeling.

LG · Nov 2, 2021
Understanding Entropic Regularization in GANs

Daria Reshetova, Yikun Bai, Xiugang Wu et al.

Generative Adversarial Networks are a popular method for learning distributions from data by modeling the target distribution as a function of a known distribution. The function, often referred to as the generator, is optimized to minimize a chosen distance measure between the generated and target distributions. One commonly used measure for this purpose is the Wasserstein distance. However, the Wasserstein distance is hard to compute and optimize, and in practice entropic regularization techniques are used to improve numerical convergence. The influence of regularization on the learned solution, however, is still not well understood. In this paper, we study how several popular entropic regularizations of the Wasserstein distance impact the solution in a simple benchmark setting where the generator is linear and the target distribution is a high-dimensional Gaussian. We show that entropy regularization promotes sparsification of the solution, while replacing the Wasserstein distance with the Sinkhorn divergence recovers the unregularized solution. Both regularization techniques remove the curse of dimensionality suffered by the Wasserstein distance. We show that the optimal generator can be learned to accuracy $ε$ with $O(1/ε^2)$ samples from the target distribution. We thus conclude that these regularization techniques can improve the quality of the generator learned from empirical data for a large class of distributions.
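
Entropic regularization of optimal transport is commonly computed with Sinkhorn scaling, and a small sketch may help show what that involves. The code below is the textbook balanced Sinkhorn iteration for discrete measures, not the analysis or experimental setup of this paper. The function name `sinkhorn` and the defaults `eps=0.5` and `n_iters=200` are assumptions chosen for illustration.

```python
import math

def sinkhorn(a, b, C, eps=0.5, n_iters=200):
    """Entropic OT plan between discrete measures a and b with cost matrix C.

    The plan has the form diag(u) K diag(v) with Gibbs kernel K = exp(-C / eps).
    Each iteration rescales u to match the row marginals, then v to match the
    column marginals. a and b must be positive and carry the same total mass.
    """
    n, m = len(a), len(b)
    K = [[math.exp(-C[i][j] / eps) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(n_iters):
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
```

Every entry of the returned plan is strictly positive, so the regularized plan is dense; smaller values of `eps` move it closer to an unregularized plan but slow convergence and can underflow in this naive form.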

IT · Aug 24, 2020
Information Constrained Optimal Transport: From Talagrand, to Marton, to Cover

Yikun Bai, Xiugang Wu, Ayfer Ozgur

The optimal transport problem studies how to transport one measure to another in the most cost-effective way and has a wide range of applications from economics to machine learning. In this paper, we introduce and study an information constrained variation of this problem. Our study yields a strengthening and generalization of Talagrand's celebrated transportation cost inequality. Following Marton's approach, we show that the new transportation cost inequality can be used to recover old and new concentration of measure results. Finally, we provide an application of this new inequality to network information theory. We show that it can be used to recover almost immediately a recent solution to a long-standing open problem posed by Cover regarding the capacity of the relay channel.