PolicyClusterGCN: Identifying Efficient Clusters for Training Graph Convolutional Networks
This work addresses efficient GCN training, a problem relevant to researchers and practitioners in graph-based machine learning. It offers a novel method that outperforms existing state-of-the-art models on node classification, although the contribution is incremental, since it builds on prior subgraph-based sampling approaches.
The authors tackle the problem of identifying efficient clusters for training Graph Convolutional Networks (GCNs) by proposing PolicyClusterGCN, a reinforcement learning (RL) framework that learns to predict edge weights for clustering. This yields improved node classification performance across multiple datasets.
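To make "predicting edge weights for clustering" concrete, here is a minimal sketch of how learned edge weights could drive a clustering step. It is modeled on the coarsening phase of Graclus, which visits vertices in random order and merges each unmarked vertex with the unmarked neighbour that maximizes the normalized edge weight w(u,v)/d(u) + w(u,v)/d(v). Repeated coarsening until at most `k` super-nodes remain is a simplification assumed here: the real Graclus also runs a base clustering and a refinement phase, both omitted. The names `heavy_edge_matching` and `coarsen_to_clusters` are hypothetical, and this is not the authors' implementation.

```python
import random
from collections import defaultdict


def heavy_edge_matching(weights, node_weight, rng, eps=1e-12):
    """One coarsening pass: pair each unmatched node with the unmatched
    neighbour maximising w(u,v)/d(u) + w(u,v)/d(v)."""
    adj = defaultdict(dict)
    for (u, v), w in weights.items():
        adj[u][v] = adj[u].get(v, 0.0) + w
        adj[v][u] = adj[v].get(u, 0.0) + w
    nodes = list(node_weight)
    rng.shuffle(nodes)  # Graclus-style random visiting order
    match = {}
    for u in nodes:
        if u in match:
            continue
        best, best_score = None, -1.0
        for v, w in adj[u].items():
            if v in match or v == u:
                continue
            score = w / max(node_weight[u], eps) + w / max(node_weight[v], eps)
            if score > best_score:
                best, best_score = v, score
        if best is None:
            match[u] = u  # no free neighbour: node stays a singleton
        else:
            match[u], match[best] = best, u
    return match


def coarsen_to_clusters(edge_weights, num_nodes, k, seed=0):
    """Contract matched pairs until at most k super-nodes remain.
    edge_weights: {(u, v): w} with positive, e.g. policy-predicted, weights."""
    rng = random.Random(seed)
    members = {i: [i] for i in range(num_nodes)}
    degree = defaultdict(float)
    for (u, v), w in edge_weights.items():
        degree[u] += w
        degree[v] += w
    node_weight = {i: degree[i] for i in range(num_nodes)}
    weights = dict(edge_weights)
    while len(members) > k:
        match = heavy_edge_matching(weights, node_weight, rng)
        rep = {u: min(u, v) for u, v in match.items()}  # super-node id
        if len(set(rep.values())) == len(members):
            break  # no edges left to contract
        new_members, new_weight = defaultdict(list), defaultdict(float)
        for u in members:
            new_members[rep[u]].extend(members[u])
            new_weight[rep[u]] += node_weight[u]
        new_weights = defaultdict(float)
        for (u, v), w in weights.items():
            a, b = rep[u], rep[v]
            if a != b:  # drop edges that became internal to a super-node
                new_weights[(min(a, b), max(a, b))] += w
        members, node_weight, weights = dict(new_members), dict(new_weight), dict(new_weights)
    return list(members.values())
```

For example, `coarsen_to_clusters({(0, 1): 1.0, (2, 3): 1.0, (1, 2): 0.01}, num_nodes=4, k=2)` keeps the weakly weighted edge (1, 2) as the cut. Because greedy matching alone can still merge across weak edges in less clear-cut graphs, refinement matters in practice. In a ClusterGCN-style pipeline, each returned cluster would then serve as one minibatch.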
Graph convolutional networks (GCNs) have achieved great success in several machine learning (ML) tasks on graph-structured data. Recently, several sampling techniques have been proposed to train GCNs efficiently and to improve their performance on ML tasks. In particular, subgraph-based sampling approaches such as ClusterGCN and GraphSAINT have achieved state-of-the-art performance on node classification tasks. These approaches rely on heuristics, such as graph partitioning via edge cuts, to identify clusters that are then treated as minibatches during GCN training. In this work, we hypothesize that rather than relying on such heuristics, one can learn a reinforcement learning (RL) policy that computes efficient clusters leading to effective GCN performance. To that end, we propose PolicyClusterGCN, an online RL framework that identifies good clusters for GCN training. We develop a novel Markov Decision Process (MDP) formulation that allows the policy network to predict "importance" weights on the edges, which a clustering algorithm (Graclus) then uses to compute the clusters. We train the policy network with a standard policy gradient algorithm, where the rewards are computed from the classification accuracy obtained when training the GCN on the clusters given by the policy. Experiments on six real-world datasets and several synthetic datasets show that PolicyClusterGCN outperforms existing state-of-the-art models on the node classification task.
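The policy-gradient step can be illustrated with a minimal REINFORCE sketch. It assumes a linear-Gaussian policy over per-edge weights (mean `edge_feats @ theta`, fixed standard deviation `sigma`) and a batch-mean reward baseline. The abstract does not specify the policy architecture, the MDP state, or the action distribution, so these choices are assumptions. `reward_fn` is a hypothetical stand-in for "classification accuracy of a GCN trained on the clusters produced from these weights"; here it can be any scalar function of the sampled weights.

```python
import numpy as np


def reinforce_edge_weights(edge_feats, reward_fn, iters=400, batch=32,
                           lr=0.05, sigma=0.5, seed=0):
    """REINFORCE with a batch-mean baseline for a linear-Gaussian edge-weight policy.

    edge_feats: (E, d) array of per-edge features (hypothetical input).
    reward_fn:  maps an (E,) weight vector to a scalar reward; in the paper's
                setting this would be the accuracy of a GCN trained on the
                clusters computed from those weights (assumed stand-in here).
    """
    rng = np.random.default_rng(seed)
    num_edges, dim = edge_feats.shape
    theta = np.zeros(dim)
    for _ in range(iters):
        mu = edge_feats @ theta                          # policy mean per edge
        noise = rng.standard_normal((batch, num_edges))
        actions = mu + sigma * noise                     # sampled edge weights
        rewards = np.array([reward_fn(a) for a in actions])
        adv = rewards - rewards.mean()                   # baseline reduces variance
        # grad_theta log N(a; mu, sigma^2) = X^T (a - mu) / sigma^2 = X^T noise / sigma
        score = (noise / sigma) @ edge_feats             # shape (batch, dim)
        theta += lr * (adv[:, None] * score).mean(axis=0)  # gradient ascent on reward
    return theta
```

As a usage example, with a toy reward `-np.mean((w - target) ** 2)` the learned `theta` approaches the parameters that generated `target`. In a real pipeline each reward evaluation would involve clustering plus GCN training, which is why variance reduction (the baseline, batched episodes) matters.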