Supervised Training of Conditional Monge Maps
This work addresses the need for context-aware optimal transport (OT) in applications such as predicting cell responses to treatments. It is an incremental advance that extends OT estimation to conditioned settings.
The paper tackles the problem of estimating optimal transport maps conditioned on context variables, for example when predicting cell responses to treatments. It introduces CondOT, a multi-task approach that learns a single global map that fits labeled pairs of measures and generalizes to unseen contexts. The authors demonstrate that CondOT can infer the effect of arbitrary combinations of perturbations on single cells using only observations of each perturbation applied separately.
Optimal transport (OT) theory describes general principles to define and select, among many possible choices, the most efficient way to map one probability measure onto another. That theory has mostly been used to estimate, given a pair of source and target probability measures $(μ, ν)$, a parameterized map $T_θ$ that can efficiently map $μ$ onto $ν$. In many applications, such as predicting cell responses to treatments, the pairs of input/output measures $(μ, ν)$ that define optimal transport problems do not arise in isolation but are associated with a context $c$, for instance the treatment applied when comparing populations of untreated and treated cells. To account for that context in OT estimation, we introduce CondOT, a multi-task approach to estimate a family of OT maps conditioned on a context variable, using several pairs of measures $(μ_i, ν_i)$ tagged with a context label $c_i$. CondOT learns a global map $\mathcal{T}_θ$ conditioned on context that is not only expected to fit all labeled pairs in the dataset $\{(c_i, (μ_i, ν_i))\}$, i.e., $\mathcal{T}_θ(c_i) \sharp μ_i \approx ν_i$, but should also generalize to produce meaningful maps $\mathcal{T}_θ(c_{\text{new}})$ when conditioned on unseen contexts $c_{\text{new}}$. Our approach harnesses and provides a novel use of partially input convex neural networks, for which we introduce a robust and efficient initialization strategy inspired by Gaussian approximations. We demonstrate the ability of CondOT to infer the effect of an arbitrary combination of genetic or therapeutic perturbations on single cells, using only observations of the effects of said perturbations separately.
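The abstract mentions an initialization inspired by Gaussian approximations but does not spell it out. As background only, and not as the authors' procedure, the sketch below computes the well-known closed-form Monge map between two Gaussians under the squared Euclidean cost, $T(x) = m_ν + A(x - m_μ)$ with $A = Σ_μ^{-1/2}(Σ_μ^{1/2} Σ_ν Σ_μ^{1/2})^{1/2} Σ_μ^{-1/2}$, using means and covariances fitted to samples. The names `sqrtm_psd` and `gaussian_monge_map` and the `eps` regularizer are hypothetical choices made for illustration, and the code assumes samples with at least two features.

```python
import numpy as np


def sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # guard against tiny negative eigenvalues
    return (vecs * np.sqrt(vals)) @ vecs.T


def gaussian_monge_map(source, target, eps=1e-8):
    """Closed-form OT map between Gaussians fitted to two sample sets.

    source, target: arrays of shape (n_samples, n_features), n_features >= 2.
    eps: small ridge added to the covariances (an illustrative assumption).
    Returns a function mapping source-like points toward the target.
    """
    dim = source.shape[1]
    m_s, m_t = source.mean(axis=0), target.mean(axis=0)
    cov_s = np.cov(source, rowvar=False) + eps * np.eye(dim)
    cov_t = np.cov(target, rowvar=False) + eps * np.eye(dim)

    root_s = sqrtm_psd(cov_s)
    inv_root_s = np.linalg.inv(root_s)
    # A is symmetric positive definite, so T is the gradient of a convex
    # quadratic potential, as required of a Monge map for this cost.
    A = inv_root_s @ sqrtm_psd(root_s @ cov_t @ root_s) @ inv_root_s

    def transport(x):
        # Rows of x are points; A is symmetric, so right-multiplying is valid.
        return m_t + (x - m_s) @ A

    return transport


# Usage on synthetic data: untreated vs. treated populations (made up here).
rng = np.random.default_rng(0)
untreated = rng.normal(size=(500, 3))
treated = 2.0 * rng.normal(size=(400, 3)) + np.array([1.0, -2.0, 0.5])
T = gaussian_monge_map(untreated, treated)
pushed = T(untreated)  # first two moments now match those of `treated`
```

By construction, the pushed samples reproduce the empirical mean and covariance of the target. A conditional model in the spirit of CondOT would need one such map per context $c$, which is where a learned, context-dependent parameterization comes in.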