ML | LG | Jun 8, 2020

Optimal Transport Graph Neural Networks

arXiv:2006.04804v6 | 49 citations
AI Analysis

This work addresses a bottleneck in graph representation learning for domains such as molecular property prediction, though it is an incremental improvement that combines existing techniques.

The paper observes that naive aggregation in graph neural networks (GNNs) can lose structural information. It introduces OT-GNN, which computes graph embeddings from optimal transport distances to parametric prototypes, and reports outperforming popular methods on molecular property prediction tasks while producing smoother representations.

Current graph neural network (GNN) architectures naively average or sum node embeddings into an aggregated graph representation -- potentially losing structural or semantic information. We here introduce OT-GNN, a model that computes graph embeddings using parametric prototypes that highlight key facets of different graph aspects. Towards this goal, we successfully combine optimal transport (OT) with parametric graph models. Graph representations are obtained from Wasserstein distances between the set of GNN node embeddings and "prototype" point clouds as free parameters. We theoretically prove that, unlike traditional sum aggregation, our function class on point clouds satisfies a fundamental universal approximation theorem. Empirically, we address an inherent collapse optimization issue by proposing a noise contrastive regularizer to steer the model towards truly exploiting the OT geometry. Finally, we outperform popular methods on several molecular property prediction tasks, while exhibiting smoother graph representations.
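
The prototype-based readout described in the abstract can be illustrated with a minimal NumPy sketch. It is not the authors' implementation. It assumes uniform weights on all points, a squared Euclidean ground cost, and an entropy-regularised (Sinkhorn) approximation in place of an exact Wasserstein solver. The names `sinkhorn_cost` and `ot_graph_embedding` are hypothetical, and the prototypes here are fixed arrays rather than trained parameters.

```python
import numpy as np


def _logsumexp(M, axis):
    # Numerically stable log(sum(exp(M))) along one axis.
    mx = M.max(axis=axis, keepdims=True)
    return (mx + np.log(np.exp(M - mx).sum(axis=axis, keepdims=True))).squeeze(axis)


def sinkhorn_cost(X, Y, eps=0.1, n_iters=200):
    """Approximate OT cost between point clouds X (n, d) and Y (m, d).

    Assumptions (not from the paper): uniform point weights, squared
    Euclidean ground cost, entropic regularisation of strength eps.
    """
    n, m = X.shape[0], Y.shape[0]
    # Pairwise squared Euclidean distances, shape (n, m).
    C = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
    log_a = np.full(n, -np.log(n))  # log of uniform weights 1/n
    log_b = np.full(m, -np.log(m))  # log of uniform weights 1/m
    f = np.zeros(n)  # dual potential for X
    g = np.zeros(m)  # dual potential for Y
    for _ in range(n_iters):
        # Alternate updates that enforce the row and column marginals.
        f = eps * (log_a - _logsumexp((g[None, :] - C) / eps, axis=1))
        g = eps * (log_b - _logsumexp((f[:, None] - C) / eps, axis=0))
    # Transport plan implied by the potentials.
    P = np.exp((f[:, None] + g[None, :] - C) / eps)
    return float((P * C).sum())


def ot_graph_embedding(node_embeddings, prototypes, eps=0.1):
    """Graph embedding with one coordinate per prototype point cloud.

    Coordinate k is the approximate OT cost between the set of node
    embeddings and prototype k.
    """
    return np.array([sinkhorn_cost(node_embeddings, Q, eps=eps) for Q in prototypes])


# Toy usage: 5 nodes with 3-dimensional embeddings and two prototypes.
rng = np.random.default_rng(0)
nodes = rng.normal(size=(5, 3))
prototypes = [nodes[:4] + 0.01, rng.normal(size=(4, 3)) + 10.0]
embedding = ot_graph_embedding(nodes, prototypes)
```

In this sketch the resulting vector has one entry per prototype, and a prototype that lies close to the node embeddings yields a smaller entry than a distant one. In a trained model the prototype point clouds would be learned jointly with the GNN, which this example does not attempt.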

Code Implementations | 2 repos