LG | ML | Oct 25, 2022

Proximal Mean Field Learning in Shallow Neural Networks

arXiv:2210.13879v3 | 1 citation | h-index: 17
Originality: Incremental advance
AI Analysis

This work addresses computational challenges in training shallow neural networks for classification, but it is incremental as it builds on existing mean field interpretations.

The authors tackled the problem of learning in shallow over-parameterized neural networks by proposing a computational algorithm based on mean field theory, resulting in a meshless method for binary and multi-class classification that performs gradient descent on the free energy.

We propose a custom learning algorithm for shallow over-parameterized neural networks, i.e., networks with a single hidden layer having infinite width. The infinite width of the hidden layer serves as an abstraction for the over-parameterization. Building on the recent mean field interpretations of learning dynamics in shallow neural networks, we realize mean field learning as a computational algorithm, rather than as an analytical tool. Specifically, we design a Sinkhorn regularized proximal algorithm to approximate the distributional flow for the learning dynamics over weighted point clouds. In this setting, a contractive fixed point recursion computes the time-varying weights, numerically realizing the interacting Wasserstein gradient flow of the parameter distribution supported over the neuronal ensemble. An appealing aspect of the proposed algorithm is that the measure-valued recursions allow meshless computation. We demonstrate the proposed computational framework of interacting weighted particle evolution on binary and multi-class classification. Our algorithm performs gradient descent on the free energy associated with the risk functional.
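
To give a feel for the kind of computation the abstract refers to, here is a minimal sketch of the generic Sinkhorn fixed-point iteration for entropy-regularized optimal transport between two weighted point clouds. It is not the authors' proximal recursion, which additionally involves the free energy of the risk functional. The function name `sinkhorn_plan`, the squared Euclidean cost, and the values of `eps` and `n_iter` are all assumptions chosen for illustration.

```python
import numpy as np


def sinkhorn_plan(a, b, C, eps=0.5, n_iter=500):
    """Generic Sinkhorn iteration (illustrative, not the paper's recursion).

    a, b : weights of the two point clouds (each sums to 1)
    C    : pairwise cost matrix, shape (len(a), len(b))
    eps  : entropic regularization strength (assumed value)
    Returns an approximate transport plan whose marginals are a and b.
    """
    K = np.exp(-C / eps)          # Gibbs kernel built from the cost
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)           # rescale rows toward marginal a
        v = b / (K.T @ u)         # rescale columns toward marginal b
    return u[:, None] * K * v[None, :]


# Two small weighted point clouds in the plane (hypothetical data).
rng = np.random.default_rng(0)
x = rng.random((5, 2))
y = rng.random((7, 2))
a = np.full(5, 1 / 5)
b = np.full(7, 1 / 7)
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)

P = sinkhorn_plan(a, b, C)
print(P.sum(axis=1))  # approximately equal to a
print(P.sum(axis=0))  # approximately equal to b
```

The sketch works directly on the sample points and their weights, with no grid over parameter space, which is the sense in which such recursions are meshless.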

Code Implementations: 1 repo
