LG · Dec 16, 2024

On the Ability of Deep Networks to Learn Symmetries from Data: A Neural Kernel Theory

arXiv:2412.11521v2 · 9 citations · h-index: 3 · Has Code
Originality: Incremental advance
AI Analysis

The paper asks whether deep networks can exploit symmetries in data to improve prediction, and shows that standard networks cannot learn symmetries that are not already embedded in their architecture. The contribution is rated an incremental advance for theory and architecture design.

The paper investigates when deep networks can learn symmetries from data, finding that generalization fails unless the architecture matches the inherent symmetry, as shown through a neural kernel theory and experiments on rotated-MNIST.

Symmetries (transformations by group actions) are present in many datasets, and leveraging them holds considerable promise for improving predictions in machine learning. In this work, we aim to understand when and how deep networks -- with standard architectures trained in a standard, supervised way -- learn symmetries from data. Inspired by real-world scenarios, we study a classification paradigm where data symmetries are only partially observed during training: some classes include all transformations of a cyclic group, while others -- only a subset. In the infinite-width limit, where kernel analogies apply, we derive a neural kernel theory of symmetry learning. The group-cyclic nature of the dataset allows us to analyze the Gram matrix of neural kernels in the Fourier domain; here we find a simple characterization of the generalization error as a function of class separation (signal) and class-orbit density (noise). This characterization reveals that generalization can only be successful when the local structure of the data prevails over its non-local, symmetry-induced structure, in the kernel space defined by the architecture. We extend our theoretical treatment to any finite group, including non-abelian groups. Our framework also applies to equivariant architectures (e.g., CNNs), and recovers their success in the special case where the architecture matches the inherent symmetry of the data. Empirically, our theory reproduces the generalization failure of finite-width networks (MLP, CNN, ViT) trained on partially observed versions of rotated-MNIST. We conclude that conventional deep networks lack a mechanism to learn symmetries that have not been explicitly embedded in their architecture a priori. Our framework could be extended to guide the design of architectures and training procedures able to learn symmetries from data.
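The step from cyclic-group structure to a Fourier-domain analysis can be illustrated with a small numerical sketch. It assumes a toy setup that is not taken from the paper: 1-D vectors acted on by cyclic shifts, and an RBF kernel standing in for a neural kernel. The names `orbit` and `rbf_gram` are hypothetical helpers. Because the kernel is invariant under applying the same shift to both inputs, the Gram matrix of a full orbit is circulant, so the discrete Fourier transform of its first row gives its eigenvalues.

```python
import numpy as np

def orbit(x):
    """All cyclic shifts of a 1-D vector x (its orbit under the cyclic group)."""
    return np.stack([np.roll(x, k) for k in range(len(x))])

def rbf_gram(X, gamma=0.5):
    """RBF Gram matrix; an assumed stand-in for a neural kernel."""
    sq = np.sum(X**2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T
    return np.exp(-gamma * np.maximum(d2, 0.0))

rng = np.random.default_rng(0)
x = rng.normal(size=8)
X = orbit(x)          # 8 points: x shifted by 0, 1, ..., 7
K = rbf_gram(X)       # Gram matrix of one fully observed orbit

# Circulant check: row i is the first row cyclically shifted by i.
first_row = K[0]
is_circulant = all(
    np.allclose(K[i], np.roll(first_row, i)) for i in range(K.shape[0])
)

# For a circulant matrix, the DFT of the first row gives the eigenvalues.
fourier_eigs = np.fft.fft(first_row)
direct_eigs = np.linalg.eigvalsh(K)   # ascending order

print("circulant:", is_circulant)
print("eigenvalues match:",
      np.allclose(np.sort(fourier_eigs.real), direct_eigs))
```

Keeping only a subset of the orbit (for example `X[:3]`) mimics a partially observed class; the resulting sub-matrix is in general no longer circulant, which is one way to see why the partially observed case needs separate treatment.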

Code Implementations: 1 repo