LG | AI | Dec 30, 2021

Causal Attention for Interpretable and Generalizable Graph Classification

arXiv:2112.15089v2 | 235 citations
Originality: Incremental advance
AI Analysis

This paper addresses poor generalization in graph classification, a concern for AI/ML practitioners, and offers an incremental improvement by applying causal theory to attention-based GNNs.

The paper tackles the problem of graph neural networks (GNNs) learning spurious correlations from noncausal features, which harms generalization. It proposes a Causal Attention Learning (CAL) strategy that discovers causal patterns and mitigates the confounding effect of shortcut features, and it reports that the strategy is effective in experiments on synthetic and real-world datasets.

In graph classification, attention and pooling-based graph neural networks (GNNs) prevail to extract the critical features from the input graph and support the prediction. They mostly follow the paradigm of learning to attend, which maximizes the mutual information between the attended graph and the ground-truth label. However, this paradigm makes GNN classifiers recklessly absorb all the statistical correlations between input features and labels in the training data, without distinguishing the causal and noncausal effects of features. Instead of underscoring the causal features, the attended graphs are prone to visit the noncausal features as the shortcut to predictions. Such shortcut features might easily change outside the training distribution, thereby making the GNN classifiers suffer from poor generalization. In this work, we take a causal look at the GNN modeling for graph classification. With our causal assumption, the shortcut feature serves as a confounder between the causal feature and prediction. It tricks the classifier to learn spurious correlations that facilitate the prediction in in-distribution (ID) test evaluation, while causing the performance drop in out-of-distribution (OOD) test data. To endow the classifier with better interpretation and generalization, we propose the Causal Attention Learning (CAL) strategy, which discovers the causal patterns and mitigates the confounding effect of shortcuts. Specifically, we employ attention modules to estimate the causal and shortcut features of the input graph. We then parameterize the backdoor adjustment of causal theory -- combine each causal feature with various shortcut features. It encourages the stable relationships between the causal estimation and prediction, regardless of the changes in shortcut parts and distributions. Extensive experiments on synthetic and real-world datasets demonstrate the effectiveness of CAL.
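The two steps the abstract describes (soft attention that splits a graph into a causal part and a shortcut part, then pairing each causal feature with various shortcut features) can be illustrated with a minimal pure-Python sketch. This is not the authors' implementation. The names `softmax2`, `split_graph`, `intervene`, and `score_fn` are hypothetical, and the choices of sum pooling, additive combination, and shuffling shortcut vectors within a batch are assumptions made for illustration only; the abstract does not specify them.

```python
import math
import random


def softmax2(a, b):
    """Two-way softmax: returns (p_a, p_b) with p_a + p_b == 1."""
    m = max(a, b)
    ea, eb = math.exp(a - m), math.exp(b - m)
    return ea / (ea + eb), eb / (ea + eb)


def split_graph(node_feats, score_fn):
    """Sum-pool node features into a (causal, shortcut) pair of vectors.

    score_fn is a hypothetical scorer that maps one node's feature
    vector to two logits: (causal_logit, shortcut_logit).
    """
    dim = len(node_feats[0])
    causal = [0.0] * dim
    shortcut = [0.0] * dim
    for x in node_feats:
        a_c, a_s = softmax2(*score_fn(x))  # complementary attention weights
        for k in range(dim):
            causal[k] += a_c * x[k]
            shortcut[k] += a_s * x[k]
    return causal, shortcut


def intervene(causal_batch, shortcut_batch, rng):
    """Pair each causal vector with a shortcut vector from a shuffled batch.

    Assumption: combination is element-wise addition. A classifier trained
    on these mixed vectors with the causal graph's label is pushed to rely
    on the causal part, since the shortcut part no longer tracks the label.
    """
    perm = list(range(len(shortcut_batch)))
    rng.shuffle(perm)
    return [
        [c + s for c, s in zip(causal_batch[i], shortcut_batch[j])]
        for i, j in enumerate(perm)
    ]


if __name__ == "__main__":
    # Toy scorer: nodes with a large first feature are treated as "causal".
    toy_score = lambda x: (x[0], -x[0])
    graphs = [[[1.0, 2.0], [3.0, 4.0]], [[0.5, 0.1], [-2.0, 1.0]]]
    pairs = [split_graph(g, toy_score) for g in graphs]
    causal_batch = [c for c, _ in pairs]
    shortcut_batch = [s for _, s in pairs]
    print(intervene(causal_batch, shortcut_batch, random.Random(0)))
```

Because the two attention weights for each node sum to one, the causal and shortcut vectors of a graph always add up to its plain sum-pooled representation. The split redistributes information and does not discard any. In a trained model, `score_fn` would be a learned module and not a fixed rule.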

Code Implementations: 1 repo
