Differentiable Expectation-Maximisation and Applications to Gaussian Mixture Model Optimal Transport
This work addresses a bottleneck for machine learning researchers and practitioners who need differentiable latent-variable models, though it is incremental in that it builds on existing EM and optimal transport methods.
The paper tackles the non-differentiability of the Expectation-Maximisation algorithm by presenting and comparing differentiation strategies that enable its use in end-to-end learning pipelines. It then applies these strategies to compute the Mixture Wasserstein distance for tasks such as image generation and style transfer, with numerical experiments demonstrating practical results.
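For reference, here is a sketch of the standard definition of $\mathrm{MW}_2$ between two GMMs as it appears in the GMM optimal transport literature. The notation ($a$, $b$, $m$, $\Sigma$, $\Pi(a,b)$) is ours and may differ from the paper's. $\mathrm{MW}_2$ is a discrete optimal transport problem between mixture components, whose ground cost is the closed-form squared Wasserstein distance between Gaussians:

```latex
\mu_0 = \sum_{k=1}^{K_0} a_k\,\mathcal{N}(m^0_k,\Sigma^0_k), \qquad
\mu_1 = \sum_{l=1}^{K_1} b_l\,\mathcal{N}(m^1_l,\Sigma^1_l),

\mathrm{MW}_2^2(\mu_0,\mu_1) = \min_{w\in\Pi(a,b)} \sum_{k=1}^{K_0}\sum_{l=1}^{K_1} w_{kl}\,
  W_2^2\big(\mathcal{N}(m^0_k,\Sigma^0_k),\,\mathcal{N}(m^1_l,\Sigma^1_l)\big),
\qquad
\Pi(a,b) = \{\, w \in \mathbb{R}_+^{K_0\times K_1} : w\mathbf{1} = a,\ w^\top\mathbf{1} = b \,\},

W_2^2\big(\mathcal{N}(m_0,\Sigma_0),\mathcal{N}(m_1,\Sigma_1)\big)
  = \lVert m_0 - m_1\rVert^2
  + \operatorname{tr}\!\Big(\Sigma_0 + \Sigma_1 - 2\big(\Sigma_0^{1/2}\Sigma_1\Sigma_0^{1/2}\big)^{1/2}\Big).
```

When the GMMs are themselves fitted to samples by EM, any gradient of $\mathrm{MW}_2$ with respect to those samples must pass through both the transport problem and the EM iterations. This is where a differentiable EM becomes necessary.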
The Expectation-Maximisation (EM) algorithm is a central tool in statistics and machine learning, widely used for latent-variable models such as Gaussian Mixture Models (GMMs). Despite its ubiquity, EM is typically treated as a non-differentiable black box, preventing its integration into modern learning pipelines where end-to-end gradient propagation is essential. In this work, we present and compare several differentiation strategies for EM, from full automatic differentiation to approximate methods, assessing their accuracy and computational efficiency. As a key application, we leverage this differentiable EM in the computation of the Mixture Wasserstein distance $\mathrm{MW}_2$ between GMMs, allowing $\mathrm{MW}_2$ to be used as a differentiable loss in imaging and machine learning tasks. To complement our practical use of $\mathrm{MW}_2$, we contribute a novel stability result that provides theoretical justification for using $\mathrm{MW}_2$ with EM, and we introduce a novel unbalanced variant of $\mathrm{MW}_2$. Numerical experiments on barycentre computation, colour and style transfer, image generation, and texture synthesis illustrate the versatility of the proposed approach across these settings.
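The "full automatic differentiation" end of the spectrum can be illustrated by unrolling a fixed number of EM iterations and letting reverse-mode autodiff differentiate through all of them. The following is a minimal JAX sketch under stated assumptions: a 1D GMM, a fixed initialisation, a fixed iteration count, and small numerical guards. All function names (`em_step`, `fit_gmm`, `loss_fn`) and the toy data are hypothetical, and this is not the authors' implementation. The cheaper approximate strategies compared in the paper are not reproduced here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_joint(params, x):
    """log(w_k) + log N(x_i | mu_k, var_k) for a 1D GMM, shape (N, K)."""
    w, mu, var = params
    return (jnp.log(w)[None, :]
            - 0.5 * jnp.log(2.0 * jnp.pi * var)[None, :]
            - 0.5 * (x[:, None] - mu[None, :]) ** 2 / var[None, :])


def log_likelihood(params, x):
    """Average log-likelihood of the data under the GMM."""
    return jnp.mean(logsumexp(log_joint(params, x), axis=1))


def em_step(params, x, var_floor=1e-6):
    """One EM iteration built only from differentiable JAX operations."""
    # E-step: responsibilities r_ik = p(k | x_i), computed in log space.
    lj = log_joint(params, x)
    r = jnp.exp(lj - logsumexp(lj, axis=1, keepdims=True))
    # M-step: closed-form weight, mean and variance updates.
    nk = r.sum(axis=0) + 1e-10  # guard against empty components
    w = nk / x.shape[0]
    mu = (r * x[:, None]).sum(axis=0) / nk
    var = (r * (x[:, None] - mu[None, :]) ** 2).sum(axis=0) / nk + var_floor
    return w, mu, var


def fit_gmm(x, init, n_iter=50):
    """Unrolled EM: reverse-mode gradients flow back through all n_iter steps."""
    def body(params, _):
        return em_step(params, x), None
    params, _ = jax.lax.scan(body, init, None, length=n_iter)
    return params


# Toy data (two clusters) and a fixed initialisation; both are hypothetical.
x = jnp.concatenate([jnp.linspace(-3.0, -1.0, 50), jnp.linspace(1.0, 3.0, 50)])
init = (jnp.array([0.5, 0.5]), jnp.array([-0.5, 0.5]), jnp.array([1.0, 1.0]))


def loss_fn(x):
    """Hypothetical downstream loss on the fitted GMM parameters."""
    _, _, var = fit_gmm(x, init)
    return jnp.sum(var)


# Gradient of the downstream loss with respect to the raw data, shape (100,).
grad_x = jax.grad(loss_fn)(x)
```

Unrolling is simple and exact for the truncated iteration, but its memory cost grows with `n_iter`, because the backward pass needs the intermediate parameters of every step. That cost is one motivation for approximate differentiation strategies.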