Michal Klein

LG
h-index: 39
19 papers
421 citations
Novelty: 53%
AI Score: 58

19 Papers

ML · Feb 8, 2023
Monge, Bregman and Occam: Interpretable Optimal Transport in High-Dimensions with Feature-Sparse Maps

Marco Cuturi, Michal Klein, Pierre Ablin · apple-ml

Optimal transport (OT) theory focuses, among all maps $T:\mathbb{R}^d\rightarrow \mathbb{R}^d$ that can morph a probability measure onto another, on those that are the ``thriftiest'', i.e. such that the averaged cost $c(x, T(x))$ between $x$ and its image $T(x)$ be as small as possible. Many computational approaches have been proposed to estimate such Monge maps when $c$ is the $\ell_2^2$ distance, e.g., using entropic maps [Pooladian'22], or neural networks [Makkuva'20, Korotin'20]. We propose a new model for transport maps, built on a family of translation invariant costs $c(x, y):=h(x-y)$, where $h:=\tfrac{1}{2}\|\cdot\|_2^2+τ$ and $τ$ is a regularizer. We propose a generalization of the entropic map suitable for $h$, and highlight a surprising link tying it with the Bregman centroids of the divergence $D_h$ generated by $h$, and the proximal operator of $τ$. We show that choosing a sparsity-inducing norm for $τ$ results in maps that apply Occam's razor to transport, in the sense that the displacement vectors $Δ(x):= T(x)-x$ they induce are sparse, with a sparsity pattern that varies depending on $x$. We showcase the ability of our method to estimate meaningful OT maps for high-dimensional single-cell transcription data, in the $34000$-$d$ space of gene counts for cells, without using dimensionality reduction, thus retaining the ability to interpret all displacements at the gene level.
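
As a minimal illustration of the sparsity mechanism described above, assume $τ=γ\|\cdot\|_1$, which is only one possible sparsity-inducing choice (the abstract does not fix one). Its proximal operator is coordinate-wise soft-thresholding, which sets small coordinates exactly to zero. The function name, the value of `gamma`, and the sample vector are illustrative and do not come from the paper.

```python
def soft_threshold(z, gamma):
    """Proximal operator of gamma * ||.||_1, applied coordinate by coordinate."""
    out = []
    for v in z:
        shrunk = abs(v) - gamma
        if shrunk <= 0:
            out.append(0.0)          # small coordinates are zeroed out
        else:
            out.append(shrunk if v > 0 else -shrunk)
    return out

# Hypothetical dense vector in R^5 (assumption, for illustration only).
z = [0.05, -1.2, 0.3, 0.0, 2.0]
sparse_z = soft_threshold(z, gamma=0.5)
nonzero_coords = [i for i, v in enumerate(sparse_z) if v != 0.0]
```

Only the coordinates whose magnitude exceeds `gamma` survive, and which ones survive depends on the input vector, which is the kind of input-dependent sparsity pattern the abstract attributes to the displacements.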

ML · Jun 20, 2023
Learning Elastic Costs to Shape Monge Displacements

Michal Klein, Aram-Alexandre Pooladian, Pierre Ablin et al. · apple-ml

Given a source and a target probability measure supported on $\mathbb{R}^d$, the Monge problem asks to find the most efficient way to map one distribution to the other. This efficiency is quantified by defining a \textit{cost} function between source and target data. Such a cost is often set by default in the machine learning literature to the squared-Euclidean distance, $\ell^2_2(\mathbf{x},\mathbf{y})=\tfrac12\|\mathbf{x}-\mathbf{y}\|_2^2$. Recently, Cuturi et al. '23 highlighted the benefits of using elastic costs, defined through a regularizer $τ$ as $c(\mathbf{x},\mathbf{y})=\ell^2_2(\mathbf{x},\mathbf{y})+τ(\mathbf{x}-\mathbf{y})$. Such costs shape the \textit{displacements} of Monge maps $T$, i.e., the difference between a source point and its image $T(\mathbf{x})-\mathbf{x}$, by giving them a structure that matches that of the proximal operator of $τ$. In this work, we make two important contributions to the study of elastic costs: (i) For any elastic cost, we propose a numerical method to compute Monge maps that are provably optimal. This provides a much-needed routine to create synthetic problems where the ground truth OT map is known, by analogy to the Brenier theorem, which states that the gradient of any convex potential is always a valid Monge map for the $\ell_2^2$ cost; (ii) We propose a loss to \textit{learn} the parameter $θ$ of a parameterized regularizer $τ_θ$, and apply it in the case where $τ_{A}(\mathbf{z})=\|A^\perp \mathbf{z}\|^2_2$. This regularizer promotes displacements that lie on a low dimensional subspace of $\mathbb{R}^d$, spanned by the $p$ rows of $A\in\mathbb{R}^{p\times d}$.
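
The subspace regularizer can be made concrete with a small sketch. It assumes that $A$ has orthonormal rows and that $A^\perp$ denotes the projection onto the orthogonal complement of their span, so that $\|A^\perp\mathbf{z}\|_2^2=\|\mathbf{z}\|_2^2-\|A\mathbf{z}\|_2^2$; both are assumptions made here for illustration, and the function name is hypothetical.

```python
def tau_A(A, z):
    """||A_perp z||^2 computed as ||z||^2 - ||A z||^2 (valid for orthonormal rows)."""
    sq_norm = sum(v * v for v in z)
    projected = [sum(a * v for a, v in zip(row, z)) for row in A]
    return sq_norm - sum(p * p for p in projected)

# One orthonormal row (p = 1) spanning the first axis of R^3.
A = [[1.0, 0.0, 0.0]]
in_span = tau_A(A, [2.0, 0.0, 0.0])     # displacement inside the subspace
off_span = tau_A(A, [2.0, 1.0, -1.0])   # displacement with a component outside it
```

A displacement inside the span of the rows of $A$ incurs no penalty, while any component outside that span is penalized quadratically, which is how the regularizer favors low-dimensional displacements.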

LG · Jul 24, 2023
Predicting Ordinary Differential Equations with Transformers

Sören Becker, Michal Klein, Alexander Neitz et al.

We develop a transformer-based sequence-to-sequence model that recovers scalar ordinary differential equations (ODEs) in symbolic form from irregularly sampled and noisy observations of a single solution trajectory. We demonstrate in extensive empirical evaluations that our model performs better than or on par with existing methods in terms of accurate recovery across various settings. Moreover, our method is efficiently scalable: after one-time pretraining on a large set of ODEs, we can infer the governing law of a new observed solution in a few forward passes of the model.

LG · Nov 5, 2022
Discovering ordinary differential equations that govern time-series

Sören Becker, Michal Klein, Alexander Neitz et al.

Natural laws are often described through differential equations, yet finding a differential equation that describes the governing law underlying observed data is a challenging and still mostly manual task. In this paper, we take a step towards the automation of this process: we propose a transformer-based sequence-to-sequence model that recovers scalar autonomous ordinary differential equations (ODEs) in symbolic form from time-series data of a single observed solution of the ODE. Our method is efficiently scalable: after one-time pretraining on a large set of ODEs, we can infer the governing laws of a new observed solution in a few forward passes of the model. We then show that our model performs better than or on par with existing methods in various test cases in terms of accurate symbolic recovery of the ODE, especially for more complex expressions.

LG · Dec 26, 2025
Completed Hyperparameter Transfer across Modules, Width, Depth, Batch and Duration

Bruno Mlodozeniec, Pierre Ablin, Louis Béthune et al.

Hyperparameter tuning can dramatically impact training stability and final performance of large-scale models. Recent works on neural network parameterisations, such as $μ$P, have enabled transfer of optimal global hyperparameters across model sizes. These works propose an empirical practice of searching for optimal global base hyperparameters at a small model size, and transferring them to a large size. We extend these works in two key ways. First, to handle scaling along the most important scaling axes, we propose the Complete$^{(d)}$ Parameterisation that unifies scaling in width and depth -- using an adaptation of CompleteP -- as well as in batch-size and training duration. Second, with our parameterisation, we investigate per-module hyperparameter optimisation and transfer. We characterise the empirical challenges of navigating the high-dimensional hyperparameter landscape, and propose practical guidelines for tackling this optimisation problem. We demonstrate that, with the right parameterisation, hyperparameter transfer holds even in the per-module hyperparameter regime. Our study covers an extensive range of optimisation hyperparameters of modern models: learning rates, AdamW parameters, weight decay, initialisation scales, and residual block multipliers. Our experiments demonstrate significant training speed improvements in Large Language Models with the transferred per-module hyperparameters.

LG · May 10
Nectar: Neural Estimation of Cached-Token Attention via Regression

João Monteiro, Michal Klein, Pierre Ablin et al.

Evaluating softmax attention over a fixed long context requires reading every cached key-value pair for each new query token. For a given context (a book, a manual, a legal corpus) the attention output is a deterministic function of the query. We propose Nectar, which fits a compact neural network to this function for queries drawn from a task-relevant distribution. Nectar fits two networks per layer and KV-head: a target network that predicts the attention output and a score network that predicts the log-normalizer. The pair plugs into the standard masked self-attention at inference time, replacing the $O(n)$ attention over the cache with a forward pass whose cost does not depend on $n$. Each module carries on the order of $|θ|$ parameters per layer and KV-head, typically much smaller than the $2nd$ KV-cache footprint at the same granularity. We report experiments on models from 1.7B to 8B parameters across five long-context datasets. The approximation error tracks the next-token accuracy gap to full attention, and allocating capacity non-uniformly across layers reduces that gap in our ablation. Beyond this analysis of metrics, we check that the text generations (following a question prompt) of a model equipped with a Nectar module match in semantic content those obtained by giving the same model access to the full cache.
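
To make the two regression targets concrete, the sketch below computes, for a fixed cache and a given query, the exact attention output and log-normalizer that a target network and a score network would be trained to predict. It then shows why the log-normalizer is useful: with it, the cached block's output can be merged exactly with attention over additional tokens, which is a standard softmax identity. Whether Nectar combines the two quantities in exactly this way is an assumption here; the function names and toy vectors are hypothetical, and the usual $1/\sqrt{d}$ scaling is omitted for brevity.

```python
import math

def attend(q, keys, values):
    """Return (attention output, log-normalizer) for one query over a key/value block."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)                                   # subtract the max for stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]
    return out, m + math.log(total)

def merge(out_a, log_z_a, out_b, log_z_b):
    """Combine two blocks' outputs, weighting each by its share of the normalizer."""
    m = max(log_z_a, log_z_b)
    wa, wb = math.exp(log_z_a - m), math.exp(log_z_b - m)
    return [(wa * a + wb * b) / (wa + wb) for a, b in zip(out_a, out_b)]

# Toy fixed context (the "cache") and one extra token that arrives later.
cache_k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
cache_v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
new_k, new_v = [[0.5, -0.5]], [[7.0, 8.0]]
q = [0.3, -0.2]

out_cache, log_z_cache = attend(q, cache_k, cache_v)   # the two regression targets
out_new, log_z_new = attend(q, new_k, new_v)
merged = merge(out_cache, log_z_cache, out_new, log_z_new)
full, _ = attend(q, cache_k + new_k, cache_v + new_v)  # reference: attention over everything
```

In this sketch `merged` agrees with `full` up to floating-point error, so a learned approximation of `(out_cache, log_z_cache)` would be enough to stand in for the whole cached block.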

CV · Nov 21, 2024
Multimodal Autoregressive Pre-training of Large Vision Encoders

Enrico Fini, Mustafa Shukor, Xiujun Li et al.

We introduce a novel method for pre-training of large-scale vision encoders. Building on recent advancements in autoregressive pre-training of vision models, we extend this framework to a multimodal setting, i.e., images and text. In this paper, we present AIMV2, a family of generalist vision encoders characterized by a straightforward pre-training process, scalability, and remarkable performance across a range of downstream tasks. This is achieved by pairing the vision encoder with a multimodal decoder that autoregressively generates raw image patches and text tokens. Our encoders excel not only in multimodal evaluations but also in vision benchmarks such as localization, grounding, and classification. Notably, our AIMV2-3B encoder achieves 89.5% accuracy on ImageNet-1k with a frozen trunk. Furthermore, AIMV2 consistently outperforms state-of-the-art contrastive models (e.g., CLIP, SigLIP) in multimodal image understanding across diverse settings.

LG · Oct 30, 2024
Controlling Language and Diffusion Models by Transporting Activations

Pau Rodriguez, Arno Blaas, Michal Klein et al.

The increasing capabilities of large generative models and their ever more widespread deployment have raised concerns about their reliability, safety, and potential misuse. To address these issues, recent works have proposed to control model generation by steering model activations in order to effectively induce or prevent the emergence of concepts or behaviors in the generated output. In this paper we introduce Activation Transport (AcT), a general framework to steer activations guided by optimal transport theory that generalizes many previous activation-steering works. AcT is modality-agnostic and provides fine-grained control over the model behavior with negligible computational overhead, while minimally impacting model abilities. We experimentally show the effectiveness and versatility of our approach by addressing key challenges in large language models (LLMs) and text-to-image diffusion models (T2Is). For LLMs, we show that AcT can effectively mitigate toxicity, induce arbitrary concepts, and increase their truthfulness. In T2Is, we show how AcT enables fine-grained style control and concept negation.

ML · Feb 5, 2025
Multivariate Conformal Prediction using Optimal Transport

Michal Klein, Louis Béthune, Eugene Ndiaye et al.

Conformal prediction (CP) quantifies the uncertainty of machine learning models by constructing sets of plausible outputs. These sets are constructed by leveraging a so-called conformity score, a quantity computed using the input point of interest, a prediction model, and past observations. CP sets are then obtained by evaluating the conformity score of all possible outputs, and selecting them according to the rank of their scores. Due to this ranking step, most CP approaches rely on score functions that are univariate. The challenge in extending these scores to multivariate spaces lies in the fact that no canonical order for vectors exists. To address this, we leverage a natural extension of multivariate score ranking based on optimal transport (OT). Our method, OTCP, offers a principled framework for constructing conformal prediction sets in multidimensional settings, preserving distribution-free coverage guarantees with finite data samples. We demonstrate tangible gains in a benchmark dataset of multivariate regression problems and address computational \& statistical trade-offs that arise when estimating conformity scores through OT maps.
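
The ranking step mentioned above is easiest to see in the univariate case. The sketch below is standard split conformal prediction with absolute-residual scores, not the OTCP method itself; the calibration scores and function names are made up for illustration.

```python
import math

def conformal_quantile(scores, alpha):
    """Return the ceil((n + 1) * (1 - alpha))-th smallest calibration score."""
    n = len(scores)
    rank = math.ceil((n + 1) * (1 - alpha))
    if rank > n:
        return math.inf   # too few calibration points for this coverage level
    return sorted(scores)[rank - 1]

# Hypothetical calibration residuals |y - prediction| (univariate scores).
calibration_scores = [0.1, 0.4, 0.2, 0.9, 0.3, 0.5, 0.7]
q_hat = conformal_quantile(calibration_scores, alpha=0.25)

def prediction_interval(prediction):
    """All outputs whose score ranks below the calibrated threshold."""
    return (prediction - q_hat, prediction + q_hat)

low, high = prediction_interval(2.0)
```

The whole construction relies on sorting scalar scores, which is exactly what has no canonical counterpart for vector-valued scores and motivates the OT-based ranking in the abstract.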

LG · Jun 5, 2025
On Fitting Flow Models with Large Sinkhorn Couplings

Stephen Zhang, Alireza Mousavi-Hosseini, Michal Klein et al.

Flow models transform data gradually from one modality (e.g. noise) onto another (e.g. images). Such models are parameterized by a time-dependent velocity field, trained to fit segments connecting pairs of source and target points. When the pairing between source and target points is given, training flow models boils down to a supervised regression problem. When no such pairing exists, as is the case when generating data from noise, training flows is much harder. A popular approach lies in picking source and target points independently. This can, however, lead to velocity fields that are not only slow to train, but also costly to integrate at inference time. In theory, one would greatly benefit from training flow models by sampling pairs from an optimal transport (OT) measure coupling source and target, since this would lead to a highly efficient flow solving the Benamou and Brenier dynamical OT problem. In practice, recent works have proposed to sample mini-batches of $n$ source and $n$ target points and reorder them using an OT solver to form better pairs. These works have advocated using batches of size $n\approx 256$, and considered OT solvers that return couplings that are either sharp (using e.g. the Hungarian algorithm) or blurred (using e.g. entropic regularization, a.k.a. Sinkhorn). We follow in the footsteps of these works by exploring the benefits of increasing $n$ by three to four orders of magnitude, and look more carefully at the effect of the entropic regularization $\varepsilon$ used in the Sinkhorn algorithm. Our analysis is facilitated by new scale invariant quantities to report the sharpness of a coupling, while our sharded computations across multiple GPUs or GPU nodes allow scaling up $n$. We show that in both synthetic and image generation tasks, flow models greatly benefit when fitted with large Sinkhorn couplings, with a low entropic regularization $\varepsilon$.

LG · Mar 9
Amortizing Maximum Inner Product Search with Learned Support Functions

Theo X. Olausson, João Monteiro, Michal Klein et al.

Maximum inner product search (MIPS) is a crucial subroutine in machine learning, requiring the identification of key vectors that align best with a given query. We propose amortized MIPS: a learning-based approach that trains neural networks to directly predict MIPS solutions, amortizing the computational cost of matching queries (drawn from a fixed distribution) to a fixed set of keys. Our key insight is that the MIPS value function, the maximal inner product between a query and keys, is also known as the support function of the set of keys. Support functions are convex, 1-homogeneous and their gradient w.r.t. the query is exactly the optimal key in the database. We approximate the support function using two complementary approaches: (1) we train an input-convex neural network (SupportNet) to model the support function directly; the optimal key can be recovered via (autodiff) gradient computation, and (2) we regress directly the optimal key from the query using a vector valued network (KeyNet), bypassing gradient computation entirely at inference time. To learn a SupportNet, we combine score regression with gradient matching losses, and propose homogenization wrappers that enforce the positive 1-homogeneity of a neural network, theoretically linking function values to gradients. To train a KeyNet, we introduce a score consistency loss derived from the Euler theorem for homogeneous functions. Our experiments show that learned SupportNet or KeyNet achieve high match rates and open up new directions to compress databases with a specific query distribution in mind.
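
The properties of the support function listed above can be checked numerically on a toy key set. The sketch below uses brute-force search rather than any learned network, with made-up keys and a hypothetical finite-difference helper, to illustrate that the gradient of the support function is the optimal key and that the function is positively 1-homogeneous.

```python
def support(q, keys):
    """Return (h(q), best key), where h(q) = max over keys of <q, key>."""
    best = max(keys, key=lambda k: sum(a * b for a, b in zip(q, k)))
    return sum(a * b for a, b in zip(q, best)), best

def finite_diff_grad(q, keys, step=1e-6):
    """Forward-difference estimate of the gradient of h at q."""
    base, _ = support(q, keys)
    grad = []
    for i in range(len(q)):
        shifted = list(q)
        shifted[i] += step
        grad.append((support(shifted, keys)[0] - base) / step)
    return grad

keys = [[1.0, 0.0], [0.0, 2.0], [-1.0, -1.0]]   # toy database
q = [0.5, 1.0]

value, key = support(q, keys)
grad = finite_diff_grad(q, keys)                         # should match `key`
scaled_value, scaled_key = support([3.0 * c for c in q], keys)
euler = sum(a * b for a, b in zip(q, grad))              # Euler: <q, grad h(q)> = h(q)
```

Here the numerical gradient coincides with the retrieved key, scaling the query by 3 scales the value by 3 without changing the key, and the inner product of the query with the gradient reproduces the value, which is the relation the score consistency loss builds on.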

LG · Sep 29, 2025
Flow Matching with Semidiscrete Couplings

Alireza Mousavi-Hosseini, Stephen Y. Zhang, Michal Klein et al.

Flow models parameterized as time-dependent velocity fields can generate data from noise by integrating an ODE. These models are often trained using flow matching, i.e. by sampling random pairs of noise and target points $(\mathbf{x}_0,\mathbf{x}_1)$ and ensuring that the velocity field is aligned, on average, with $\mathbf{x}_1-\mathbf{x}_0$ when evaluated along a segment linking $\mathbf{x}_0$ to $\mathbf{x}_1$. While these pairs are sampled independently by default, they can also be selected more carefully by matching batches of $n$ noise to $n$ target points using an optimal transport (OT) solver. Although promising in theory, the OT flow matching (OT-FM) approach is not widely used in practice. Zhang et al. (2025) pointed out recently that OT-FM truly starts paying off when the batch size $n$ grows significantly, which only a multi-GPU implementation of the Sinkhorn algorithm can handle. Unfortunately, the costs of running Sinkhorn can quickly balloon, requiring $O(n^2/\varepsilon^2)$ operations for every $n$ pairs used to fit the velocity field, where $\varepsilon$ is a regularization parameter that should be typically small to yield better results. To fulfill the theoretical promises of OT-FM, we propose to move away from batch-OT and rely instead on a semidiscrete formulation that leverages the fact that the target dataset distribution is usually of finite size $N$. The SD-OT problem is solved by estimating a dual potential vector using SGD; using that vector, freshly sampled noise vectors at train time can then be matched with data points at the cost of a maximum inner product search (MIPS). Semidiscrete FM (SD-FM) removes the quadratic dependency on $n/\varepsilon$ that bottlenecks OT-FM. SD-FM beats both FM and OT-FM on all training metrics and inference budget constraints, across multiple datasets, on unconditional/conditional generation, or when using mean-flow models.
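
The reduction of the matching step to MIPS can be sketched as follows, assuming the cost $\tfrac12\|\mathbf{x}-\mathbf{y}\|_2^2$ and the convention that a noise point $\mathbf{x}$ is assigned to $\arg\min_j \tfrac12\|\mathbf{x}-\mathbf{y}_j\|_2^2-g_j$ for a dual potential vector $g$. The cost, the sign convention, the function names, and the hand-picked potential below are assumptions for illustration; in the paper the potential is estimated with SGD.

```python
def assign_by_cost(x, data, g):
    """argmin over j of 0.5 * ||x - y_j||^2 - g_j."""
    costs = [0.5 * sum((a - b) ** 2 for a, b in zip(x, y)) - gj
             for y, gj in zip(data, g)]
    return costs.index(min(costs))

def assign_by_mips(x, data, g):
    """Same assignment, written as an inner-product search over augmented vectors."""
    query = list(x) + [1.0]
    keys = [list(y) + [gj - 0.5 * sum(b * b for b in y)] for y, gj in zip(data, g)]
    scores = [sum(a * b for a, b in zip(query, k)) for k in keys]
    return scores.index(max(scores))

data = [[0.0, 0.0], [2.0, 0.0], [0.0, 2.0]]    # toy target dataset (N = 3)
g = [0.0, 0.5, -0.5]                            # placeholder potential, not trained
noise = [[0.1, 0.1], [1.5, 0.2], [0.3, 1.9], [0.9, 0.0]]

matches_cost = [assign_by_cost(x, data, g) for x in noise]
matches_mips = [assign_by_mips(x, data, g) for x in noise]
```

Both forms agree because $\tfrac12\|\mathbf{x}\|_2^2$ does not depend on $j$. The last noise point is closest to the first data point in plain Euclidean distance, yet the potential shifts it to the second one, which is how the dual vector rebalances the matching.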

LG · Jun 10, 2025
The Geometries of Truth Are Orthogonal Across Tasks

Waiss Azizian, Michael Kirchhof, Eugene Ndiaye et al.

Large Language Models (LLMs) have demonstrated impressive generalization capabilities across various tasks, but their claim to practical relevance is still marred by concerns about their reliability. Recent works have proposed examining the activations produced by an LLM at inference time to assess whether its answer to a question is correct. Some works claim that a "geometry of truth" can be learned from examples, in the sense that the activations that generate correct answers can be distinguished from those leading to mistakes with a linear classifier. In this work, we underline a limitation of these approaches: we observe that these "geometries of truth" are intrinsically task-dependent and fail to transfer across tasks. More precisely, we show that linear classifiers trained across distinct tasks share little similarity and, when trained with sparsity-enforcing regularizers, have almost disjoint supports. We show that more sophisticated approaches (e.g., using mixtures of probes and tasks) fail to overcome this limitation, likely because activation vectors commonly used to classify answers form clearly separated clusters when examined across tasks.

LG · Feb 5, 2024
Careful with that Scalpel: Improving Gradient Surgery with an EMA

Yu-Guan Hsieh, James Thornton, Eugene Ndiaye et al.

Beyond minimizing a single training loss, many deep learning estimation pipelines rely on an auxiliary objective to quantify and encourage desirable properties of the model (e.g. performance on another dataset, robustness, agreement with a prior). Although the simplest approach to incorporating an auxiliary loss is to sum it with the training loss as a regularizer, recent works have shown that one can improve performance by blending the gradients beyond a simple sum; this is known as gradient surgery. We cast the problem as a constrained minimization problem where the auxiliary objective is minimized among the set of minimizers of the training loss. To solve this bilevel problem, we follow a parameter update direction that combines the training loss gradient and the orthogonal projection of the auxiliary gradient to the training gradient. In a setting where gradients come from mini-batches, we explain how, using a moving average of the training loss gradients, we can carefully maintain this critical orthogonality property. We demonstrate that our method, Bloop, can lead to much better performances on NLP and vision experiments than other gradient surgery methods without EMA.
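
One way to read the update rule described above is sketched below: the auxiliary gradient is projected onto the orthogonal complement of an exponential moving average (EMA) of the training gradients, then added to the current training gradient. This is an interpretation of the abstract, not the paper's exact algorithm; the EMA rate `rho`, the weight `lam`, and all function names are hypothetical.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def surgery_direction(g_train, g_aux, ema, rho=0.1, lam=1.0):
    """Return (update direction, updated EMA of training gradients)."""
    # Smooth the noisy mini-batch training gradient.
    ema = [(1 - rho) * e + rho * g for e, g in zip(ema, g_train)]
    sq = dot(ema, ema)
    coef = dot(g_aux, ema) / sq if sq > 0 else 0.0
    # Remove from the auxiliary gradient its component along the EMA.
    aux_perp = [a - coef * e for a, e in zip(g_aux, ema)]
    direction = [g + lam * p for g, p in zip(g_train, aux_perp)]
    return direction, ema

ema = [0.0, 0.0]
g_train, g_aux = [1.0, 0.0], [1.0, 1.0]
direction, ema = surgery_direction(g_train, g_aux, ema)
aux_part = [d - g for d, g in zip(direction, g_train)]   # what the auxiliary loss contributed
```

In this toy step the auxiliary contribution keeps only the component orthogonal to the averaged training gradient, so to first order it does not interfere with progress on the training loss.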

LG · Oct 1, 2025
The Data-Quality Illusion: Rethinking Classifier-Based Quality Filtering for LLM Pretraining

Thiziri Nait Saada, Louis Béthune, Michal Klein et al.

Large-scale models are pretrained on massive web-crawled datasets containing documents of mixed quality, making data filtering essential. A popular method is Classifier-based Quality Filtering (CQF), which trains a binary classifier to distinguish between pretraining data and a small, high-quality set. It assigns each pretraining document a quality score defined as the classifier's score and retains only the top-scoring ones. We provide an in-depth analysis of CQF. We show that while CQF improves downstream task performance, it does not necessarily enhance language modeling on the high-quality dataset. We explain this paradox by the fact that CQF implicitly filters the high-quality dataset as well. We further compare the behavior of models trained with CQF to those trained on synthetic data of increasing quality, obtained via random token permutations, and find starkly different trends. Our results challenge the view that CQF captures a meaningful notion of data quality.

CL · Mar 11, 2025
LinEAS: End-to-end Learning of Activation Steering with a Distributional Loss

Pau Rodriguez, Michal Klein, Eleonora Gualdoni et al.

The growing use of generative models in daily life calls for efficient mechanisms to control their generation, to e.g., produce safe content or provide users with tools to explore style changes. Ideally, such mechanisms should require low volume of unpaired data (i.e., without explicit preference), and should be cheap, both at train and inference time, while preserving output quality. Recent research has shown that such mechanisms can be obtained by intervening exclusively on model activations, with the goal of correcting distributional differences between activations seen when using prompts from a source vs. a target set (e.g., toxic and non-toxic sentences). While cheap, these fast methods are inherently crude: their maps are tuned locally, not accounting for their impact on downstream layers, resulting in interventions that cause unintended shifts when used out-of-sample. We propose in this work linear end-to-end activation steering (LinEAS), an approach trained with a global loss that accounts simultaneously for all layer-wise distributional shifts. In addition to being more robust, the loss used to train LinEAS can be regularized with sparsifying norms, which can automatically carry out neuron selection. LinEAS only requires a handful of unpaired samples to be effective, and beats similar baselines on toxicity mitigation in language models, becoming competitive with oracle-dependent methods that have access to strong supervision. LinEAS is modality-agnostic and we empirically find that it outperforms existing activation steering methods at mitigating and including new concepts at the output of single-step text-to-image generation models.

ML · Jun 7, 2024
Progressive Entropic Optimal Transport Solvers

Parnian Kassraie, Aram-Alexandre Pooladian, Michal Klein et al.

Optimal transport (OT) has profoundly impacted machine learning by providing theoretical and computational tools to realign datasets. In this context, given two large point clouds of sizes $n$ and $m$ in $\mathbb{R}^d$, entropic OT (EOT) solvers have emerged as the most reliable tool to either solve the Kantorovich problem and output an $n\times m$ coupling matrix, or to solve the Monge problem and learn a vector-valued push-forward map. While the robustness of EOT couplings/maps makes them a go-to choice in practical applications, EOT solvers remain difficult to tune because of a small but influential set of hyperparameters, notably the omnipresent entropic regularization strength $\varepsilon$. Setting $\varepsilon$ can be difficult, as it simultaneously impacts various performance metrics, such as compute speed, statistical performance, generalization, and bias. In this work, we propose a new class of EOT solvers (ProgOT), that can estimate both plans and transport maps. We take advantage of several opportunities to optimize the computation of EOT solutions by dividing mass displacement using a time discretization, borrowing inspiration from dynamic OT formulations, and conquering each of these steps using EOT with properly scheduled parameters. We provide experimental evidence demonstrating that ProgOT is a faster and more robust alternative to standard solvers when computing couplings at large scales, even outperforming neural network-based approaches. We also prove statistical consistency of our approach for estimating optimal transport maps.

CV · Jan 16, 2024
Scalable Pre-training of Large Autoregressive Image Models

Alaaeldin El-Nouby, Michal Klein, Shuangfei Zhai et al.

This paper introduces AIM, a collection of vision models pre-trained with an autoregressive objective. These models are inspired by their textual counterparts, i.e., Large Language Models (LLMs), and exhibit similar scaling properties. Specifically, we highlight two key findings: (1) the performance of the visual features scales with both the model capacity and the quantity of data, (2) the value of the objective function correlates with the performance of the model on downstream tasks. We illustrate the practical implication of these findings by pre-training a 7 billion parameter AIM on 2 billion images, that achieves 84.0% on ImageNet-1k with a frozen trunk. Interestingly, even at this scale, we observe no sign of saturation in performance, suggesting that AIM potentially represents a new frontier for training large-scale vision models. The pre-training of AIM is similar to the pre-training of LLMs, and does not require any image-specific strategy to stabilize the training at scale.

LG · May 31, 2023
Unbalanced Low-rank Optimal Transport Solvers

Meyer Scetbon, Michal Klein, Giovanni Palla et al.

The relevance of optimal transport methods to machine learning has long been hindered by two salient limitations. First, the $O(n^3)$ computational cost of standard sample-based solvers (when used on batches of $n$ samples) is prohibitive. Second, the mass conservation constraint makes OT solvers too rigid in practice: because they must match \textit{all} points from both measures, their output can be heavily influenced by outliers. A flurry of recent works in OT has addressed these computational and modelling limitations, but has resulted in two separate strains of methods: While the computational outlook was much improved by entropic regularization, more recent $O(n)$ linear-time \textit{low-rank} solvers hold the promise to scale up OT further. On the other hand, modelling rigidities have been eased owing to unbalanced variants of OT, that rely on penalization terms to promote, rather than impose, mass conservation. The goal of this paper is to merge these two strains, to achieve the promise of \textit{both} versatile/scalable unbalanced/low-rank OT solvers. We propose custom algorithms to implement these extensions for the linear OT problem and its Fused-Gromov-Wasserstein generalization, and demonstrate their practical relevance to challenging spatial transcriptomics matching problems.