Sampling-based Distributed Training with Message Passing Neural Network
This work addresses the problem of scaling graph neural networks to large datasets in computational physics, such as fluid dynamics, although it appears incremental because it builds on existing MPNN and sampling techniques.
The study tackles the scaling of edge-based graph neural networks to large graphs by introducing a distributed training and inference approach with sampling. It achieves accuracy comparable to single-GPU methods while handling up to about 100,000 nodes, and it outperforms node-based GCNs.
In this study, we introduce a domain-decomposition-based distributed training and inference approach for message-passing neural networks (MPNNs). Our objective is to address the challenge of scaling edge-based graph neural networks as the number of nodes increases. By combining distributed training with Nyström-approximation sampling, we obtain a scalable graph neural network, referred to as DS-MPNN (D and S standing for distributed and sampled, respectively), capable of scaling up to $O(10^5)$ nodes. We validate the sampling and distributed training approach on two cases: (a) a Darcy flow dataset and (b) steady RANS simulations of 2-D airfoils, comparing against both a single-GPU implementation and node-based graph convolutional networks (GCNs). DS-MPNN achieves accuracy comparable to the single-GPU implementation, accommodates a significantly larger number of nodes than the single-GPU variant (S-MPNN), and significantly outperforms the node-based GCN.
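The abstract names three ingredients (domain decomposition, Nyström-style node sampling, and edge-based message passing) without spelling out how they combine. The following is a minimal NumPy sketch of one plausible arrangement, not the authors' DS-MPNN implementation. It assumes a 1-D slab decomposition along x with halo nodes, uniform random subsampling inside each subdomain followed by a radius graph, and a single tanh message/update layer. All function names, the halo scheme, and the weight shapes are hypothetical. Subdomains are processed in a loop here, whereas a distributed run would place each on its own GPU and synchronize gradients.

```python
import numpy as np


def partition_with_halo(coords, n_parts, halo):
    """Split nodes into n_parts slabs along x (hypothetical 1-D decomposition).

    Returns a list of (owned, local) index arrays: `owned` nodes belong to exactly
    one slab; `local` adds halo nodes within `halo` of the slab's boundaries.
    """
    x = coords[:, 0]
    bounds = np.linspace(x.min(), x.max(), n_parts + 1)
    parts = []
    for i in range(n_parts):
        lo, hi = bounds[i], bounds[i + 1]
        # Half-open slabs, except the last one, so every node is owned exactly once.
        upper = (x <= hi) if i == n_parts - 1 else (x < hi)
        owned = (x >= lo) & upper
        local = (x >= lo - halo) & (x <= hi + halo)
        parts.append((np.flatnonzero(owned), np.flatnonzero(local)))
    return parts


def sample_radius_graph(idx, coords, n_samples, radius, rng):
    """Nyström-style subsampling (assumed uniform): keep a random subset of `idx`,
    then connect every ordered pair of kept nodes within `radius`, no self-loops."""
    chosen = rng.choice(idx, size=min(n_samples, len(idx)), replace=False)
    p = coords[chosen]
    dist = np.linalg.norm(p[:, None, :] - p[None, :, :], axis=-1)
    not_self = ~np.eye(len(chosen), dtype=bool)
    src, dst = np.nonzero((dist <= radius) & not_self)
    return chosen, src, dst


def message_passing_step(h, pos, src, dst, w_msg, w_upd):
    """One edge-based MPNN layer: a message per edge built from
    (h_src, h_dst, pos_src - pos_dst), summed at the receiving node, then a node update."""
    edge_in = np.concatenate([h[src], h[dst], pos[src] - pos[dst]], axis=1)
    msg = np.tanh(edge_in @ w_msg)
    agg = np.zeros((h.shape[0], msg.shape[1]))
    np.add.at(agg, dst, msg)  # sum-aggregate messages at each destination node
    return np.tanh(np.concatenate([h, agg], axis=1) @ w_upd)


def ds_mpnn_forward(coords, feats, n_parts, halo, n_samples, radius,
                    w_msg, w_upd, seed=0):
    """Sequential stand-in for the distributed pass: each subdomain would run on
    its own device; here they are looped over. Unsampled nodes are left as NaN."""
    rng = np.random.default_rng(seed)
    out = np.full((coords.shape[0], w_upd.shape[1]), np.nan)
    for owned, local in partition_with_halo(coords, n_parts, halo):
        chosen, src, dst = sample_radius_graph(local, coords, n_samples, radius, rng)
        h = message_passing_step(feats[chosen], coords[chosen], src, dst, w_msg, w_upd)
        keep = np.isin(chosen, owned)  # halo nodes send messages but are not written back
        out[chosen[keep]] = h[keep]
    return out
```

For example, with 200 random nodes in the unit square, 3 input features, four slabs, and 60 samples per slab, `ds_mpnn_forward` returns a `(200, H)` array in which only the sampled, owned rows are filled. In this sketch the halo lets messages cross subdomain boundaries while each node is still predicted by at most one subdomain.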