LG · ML · Feb 18, 2021

A Mathematical Principle of Deep Learning: Learn the Geodesic Curve in the Wasserstein Space

arXiv:2102.09235v2 · 10 citations
AI Analysis

This paper proposes a mathematical principle for deep learning. If the principle holds, it would help explain why architectures such as ResNet optimize and generalize better than plain networks.

The paper characterizes the optimization and generalization behavior of deep neural networks (DNNs) by connecting them to the continuity equation and to optimal transport theory. It argues that DNNs trained with weight decay learn the geodesic curve in Wasserstein space, and that ResNet approximates this curve better than a plain network does, which would account for its better optimization and generalization. The experiments support this with two metrics: the line-shape score (LSS) and the optimal transport score (OTS).
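For reference, the standard optimal-transport facts behind this summary can be written as follows. This is textbook material (the continuity equation, the Benamou–Brenier formula and McCann's displacement interpolation, assuming $\mu_0$ has a density so that an optimal map $T$ exists), not a reproduction of the paper's derivation. Reading a residual block as one Euler step is the usual dynamical-systems interpretation and is stated here as an assumption.

```latex
% Continuity equation: density \rho_t is carried by velocity field v_t; mass is conserved
\partial_t \rho_t + \nabla \cdot (\rho_t v_t) = 0,
\qquad \dot{x}(t) = v_t\big(x(t)\big).

% Residual block read as one forward-Euler step of size h (assumed interpretation)
x_{l+1} = x_l + h\, v_l(x_l).

% Benamou--Brenier: W_2 is the least kinetic energy over all curves joining \mu_0 to \mu_1
W_2^2(\mu_0,\mu_1) = \min_{(\rho_t, v_t)} \int_0^1 \!\! \int \lVert v_t(x) \rVert^2 \, d\rho_t(x)\, dt
\quad \text{s.t. } \partial_t \rho_t + \nabla\cdot(\rho_t v_t) = 0,\ \rho_0 = \mu_0,\ \rho_1 = \mu_1.

% Geodesic (displacement interpolation) induced by the optimal transport map T
\rho_t = \big((1-t)\,\mathrm{Id} + t\,T\big)_{\#}\,\mu_0,
\qquad x(t) = (1-t)\,x_0 + t\,T(x_0).
```

Along this geodesic every particle moves on a straight line at constant speed, so straight data tracks are a natural signature to test for, as the paper's line-shape score does.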

Recent studies have revealed a mathematical connection between deep neural networks (DNNs) and dynamical systems. However, the fundamental principle of DNNs has not been fully characterized through dynamical systems in terms of optimization and generalization. To this end, we connect DNNs with the continuity equation, in which measure is conserved, to model the forward propagation of a DNN, which has not been addressed before. A DNN learns to transform the input distribution into the output distribution. However, in the space of measures there are infinitely many curves connecting two distributions. Which one leads to good optimization and generalization for a DNN? Drawing on optimal transport theory, we find that a DNN with weight decay attempts to learn the geodesic curve in the Wasserstein space, which is induced by the optimal transport map. Compared with a plain network, ResNet is a better approximation to the geodesic curve, which explains why ResNet can be optimized and generalizes better. Numerical experiments show that the data tracks of both plain networks and ResNets tend to be line-shaped in terms of the line-shape score (LSS), and that the map learned by ResNet is closer to the optimal transport map in terms of the optimal transport score (OTS). In short, we conclude that a mathematical principle of deep learning is to learn the geodesic curve in the Wasserstein space, and that deep learning is a great engineering realization of continuous transformation in high-dimensional space.
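To make the abstract's quantities concrete, here is a minimal pure-Python sketch (Python 3.8+ for `math.dist`), assuming a ResNet block acts as a forward-Euler step x_{l+1} = x_l + h v_l(x_l), whereas a plain network would replace the whole map x_l -> x_{l+1} without the identity shortcut. The helpers `euler_track`, `line_shape_score`, `ot_cost` and `ot_score` are hypothetical names, and the two scores are stand-ins chosen for illustration: they are not necessarily the paper's exact LSS and OTS definitions. The brute-force OT solver enumerates permutations and is only feasible for a handful of points; for larger samples, `scipy.optimize.linear_sum_assignment` on the squared-distance cost matrix solves the same assignment problem.

```python
import itertools
import math
from typing import Callable, List, Sequence

Vec = List[float]
Track = List[List[Vec]]  # track[k][i]: point i after layer k (track[0] is the input)


def euler_track(x0: Sequence[Vec], fields: Sequence[Callable[[Vec], Vec]], h: float) -> Track:
    """Push every point through residual updates x_{k+1} = x_k + h * v_k(x_k).

    Each update is one forward-Euler step of the ODE dx/dt = v_t(x); `fields`
    stands in for learned residual branches (hypothetical, for illustration).
    """
    track: Track = [[list(p) for p in x0]]
    for v in fields:
        track.append([[a + h * b for a, b in zip(p, v(p))] for p in track[-1]])
    return track


def line_shape_score(track: Track) -> float:
    """Hypothetical LSS: mean over points of chord length / path length.

    Equals 1.0 when every data track is a straight segment and drops below 1.0
    as tracks bend. The paper's exact definition may differ.
    """
    ratios = []
    for i in range(len(track[0])):
        path = sum(math.dist(track[k][i], track[k + 1][i]) for k in range(len(track) - 1))
        chord = math.dist(track[0][i], track[-1][i])
        ratios.append(1.0 if path == 0 else chord / path)
    return sum(ratios) / len(ratios)


def ot_cost(xs: Sequence[Vec], ys: Sequence[Vec]) -> float:
    """Exact discrete OT cost between two equal-size, uniformly weighted point clouds.

    Squared-Euclidean cost, minimized by brute force over all permutations,
    so this is only usable for a handful of points.
    """
    n = len(xs)
    return min(
        sum(math.dist(xs[i], ys[perm[i]]) ** 2 for i in range(n))
        for perm in itertools.permutations(range(n))
    )


def ot_score(xs: Sequence[Vec], mapped: Sequence[Vec]) -> float:
    """Hypothetical OTS: optimal cost / cost of the learned pairing x_i -> mapped[i].

    Equals 1.0 when the learned map already realizes an optimal coupling of the
    sample, and is smaller the more the learned pairing costs relative to the
    optimum. Not the paper's formula.
    """
    learned = sum(math.dist(x, y) ** 2 for x, y in zip(xs, mapped))
    return 1.0 if learned == 0 else ot_cost(xs, mapped) / learned
```

A constant velocity field (a pure translation) produces straight tracks and an LSS of 1. A translation is also an optimal transport map for squared-Euclidean cost, so its OTS is 1 as well. A rotational field bends the tracks and lowers the LSS, and an order-reversing map in 1D scores 0 on this OTS.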
