A Galois theorem for machine learning: Functions on symmetric matrices and point clouds via lightweight invariant features
This work addresses the challenge of designing efficient invariant representations for machine learning on structured data such as graphs and point clouds; its contribution is incremental, applying mathematical theory to improve model expressivity.
The authors tackle the problem of learning functions on symmetric matrices and point clouds that are invariant to permutations and, for point clouds, also to rotations and reflections. They construct lightweight invariant features inspired by Galois theory that generically separate orbits using O(n^2) features for n×n symmetric matrices and O(n) features for point clouds of n points in a fixed dimension, and they demonstrate feasibility on molecule property regression and point cloud distance prediction tasks.
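The kind of invariance involved can be illustrated with a minimal sketch, assuming NumPy. The function `invariant_features` is hypothetical and does not reproduce the paper's generating set: it uses power sums of the diagonal entries, $\sum_i X_{ii}^a$, together with mixed sums $\sum_{i\neq j} X_{ii}^a X_{ij}^b$, all of which are unchanged when rows and columns are jointly permuted ($X \mapsto PXP^\top$). With the default degree bound their count grows as O(n^2), but no claim is made here that this particular set generically separates orbits.

```python
import numpy as np
from typing import Optional


def invariant_features(X: np.ndarray, max_deg: Optional[int] = None) -> np.ndarray:
    """Illustrative features of a symmetric matrix X that are invariant under
    X -> P X P^T for any permutation matrix P (hypothetical, not the paper's set)."""
    X = np.asarray(X, dtype=float)
    if X.ndim != 2 or X.shape[0] != X.shape[1] or not np.allclose(X, X.T):
        raise ValueError("X must be a square symmetric matrix")
    n = X.shape[0]
    k = n if max_deg is None else max_deg  # degree bound; default gives O(n^2) features
    d = np.diag(X)          # diagonal entries (e.g. node weights)
    off = X - np.diag(d)    # off-diagonal part (e.g. edge weights), zero diagonal

    # Power sums of the diagonal: sum_i X_ii^a for a = 1..k
    feats = [np.sum(d ** a) for a in range(1, k + 1)]

    # Mixed sums: sum_{i != j} X_ii^a X_ij^b for a = 0..k-1, b = 1..k.
    # Because `off` has a zero diagonal and b >= 1, the i == j terms vanish.
    for a in range(k):
        for b in range(1, k + 1):
            feats.append(np.sum(d[:, None] ** a * off ** b))
    return np.array(feats)


rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
X = A + A.T                         # random symmetric matrix (a dense weighted graph)
P = np.eye(5)[rng.permutation(5)]   # random permutation matrix
Y = P @ X @ P.T                     # same graph with relabelled nodes
print(np.allclose(invariant_features(X), invariant_features(Y)))  # True
print(invariant_features(X).shape)                                # (30,)
```

Tying the degree bound to n is what makes the count quadratic; fixing `max_deg` instead gives a feature vector whose length does not depend on the matrix size.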
In this work, we present a mathematical formulation for machine learning of (1) functions on symmetric matrices that are invariant with respect to the action of permutations by conjugation, and (2) functions on point clouds that are invariant with respect to rotations, reflections, and permutations of the points. To achieve this, we provide a general construction of generically separating invariant features using ideas inspired by Galois theory. We construct $O(n^2)$ invariant features derived from generators for the field of rational functions on $n\times n$ symmetric matrices that are invariant under joint permutations of rows and columns. We show that these invariant features can separate all distinct orbits of symmetric matrices outside a set of measure zero; consequently, such features can be used to universally approximate invariant functions on almost all weighted graphs. For point clouds in a fixed dimension, we prove that the number of invariant features can be reduced, generically without losing expressivity, to $O(n)$, where $n$ is the number of points. We combine these invariant features with DeepSets to learn functions on symmetric matrices and point clouds of varying sizes. We empirically demonstrate the feasibility of our approach on molecule property regression and point cloud distance prediction.
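For point clouds, one standard route to rotation and reflection invariance is the Gram matrix $G = V^\top V$ of a $d\times n$ point matrix $V$: it is unchanged by $V \mapsto QV$ for any orthogonal $Q$, and reordering the points conjugates it by a permutation matrix. The sketch below, assuming NumPy, combines this with DeepSets-style sum pooling over per-point descriptors, so the output length depends only on the degree bound `max_deg` and not on the number of points. The names `per_point_descriptors` and `point_cloud_embedding` are hypothetical; this illustrates the invariances only, not the paper's $O(n)$ construction, and the learned network that DeepSets would apply after pooling is omitted.

```python
import numpy as np


def per_point_descriptors(G: np.ndarray, max_deg: int) -> np.ndarray:
    """Row i holds invariants of point i built from the Gram matrix G.
    Reordering the points only reorders the rows."""
    d = np.diag(G)              # squared norms |v_i|^2
    off = G - np.diag(d)        # inner products <v_i, v_j> for i != j, zero diagonal
    cols = [d ** a for a in range(1, max_deg + 1)]
    for a in range(max_deg):
        for b in range(1, max_deg + 1):
            # d_i^a * sum_{j != i} <v_i, v_j>^b (the zero diagonal drops j == i)
            cols.append(d ** a * np.sum(off ** b, axis=1))
    return np.stack(cols, axis=1)  # shape (n, max_deg + max_deg**2)


def point_cloud_embedding(V: np.ndarray, max_deg: int = 3) -> np.ndarray:
    """V has shape (dim, n), one column per point; n may vary between inputs."""
    G = V.T @ V                 # unchanged by V -> Q V for orthogonal Q
    phi = per_point_descriptors(G, max_deg)
    # DeepSets-style sum pooling over points gives permutation invariance;
    # a learned network rho would normally be applied to this vector.
    return phi.sum(axis=0)


rng = np.random.default_rng(1)
V = rng.normal(size=(3, 6))                    # 6 points in R^3
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))   # random orthogonal matrix
V2 = Q @ V[:, rng.permutation(6)]              # rotate/reflect and reorder the points
print(np.allclose(point_cloud_embedding(V), point_cloud_embedding(V2)))  # True
print(point_cloud_embedding(rng.normal(size=(3, 10))).shape)             # (12,)
```

Summing the per-point rows reproduces the same kind of diagonal and mixed sums as in the matrix case, which is why pooling preserves invariance while letting clouds of different sizes share one fixed-length representation.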