LG · AI · Jun 7, 2024

Learning Divergence Fields for Shift-Robust Graph Representations

arXiv:2406.04963v1 · 4 citations
Originality: Incremental advance
AI Analysis

This work addresses distribution shift in graph data for machine learning, offering an incremental improvement in the form of generalized versions of existing models (GCN, GAT, and Transformers).

The paper tackles the challenge of generalizing learning models to out-of-distribution data in interdependent graph structures by proposing a geometric diffusion model with learnable divergence fields, demonstrating promising efficacy on diverse real-world datasets.

Real-world data generation often involves certain geometries (e.g., graphs) that induce instance-level interdependence. This characteristic makes the generalization of learning models more difficult due to the intricate interdependent patterns that impact data-generative distributions and can vary from training to testing. In this work, we propose a geometric diffusion model with learnable divergence fields for the challenging generalization problem with interdependent data. We generalize the diffusion equation with stochastic diffusivity at each time step, which aims to capture the multi-faceted information flows among interdependent data. Furthermore, we derive a new learning objective through causal inference, which can guide the model to learn generalizable patterns of interdependence that remain stable across domains. Regarding practical implementation, we introduce three model instantiations that can be considered as the generalized versions of GCN, GAT, and Transformers, respectively, and that offer improved robustness against distribution shifts. We demonstrate their promising efficacy for out-of-distribution generalization on diverse real-world datasets.
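The abstract describes the model only at a high level. As a rough illustration, and not the authors' implementation, the sketch below assumes an explicit Euler discretization of the diffusion equation dz_i/dt = sum_j S_ij (z_j - z_i), where the diffusivity matrix S at each step is computed from one of K candidate parameter matrices drawn at random. The three diffusivity functions only mimic how GCN-, GAT-, and Transformer-style weights differ (fixed graph weights, attention restricted to edges, attention over all node pairs). The names `stochastic_diffusion`, `ws`, and `probs` are hypothetical; the sampling probabilities are fixed rather than learned, and the causal-inference training objective is omitted entirely.

```python
import numpy as np

def row_normalize(w):
    """Scale nonnegative weights so that every row sums to 1."""
    return w / w.sum(axis=1, keepdims=True)

def gcn_diffusivity(z, adj, w):
    """GCN-like: weights fixed by the graph (plus self-loops); ignores z and w."""
    return row_normalize(adj + np.eye(adj.shape[0]))

def gat_diffusivity(z, adj, w):
    """GAT-like: feature-dependent softmax weights restricted to edges and self-loops."""
    scores = (z @ w) @ z.T
    mask = (adj + np.eye(adj.shape[0])) > 0
    scores = np.where(mask, scores, -np.inf)               # non-edges get zero weight
    scores = scores - scores.max(axis=1, keepdims=True)    # numerical stability
    return row_normalize(np.exp(scores))

def transformer_diffusivity(z, adj, w):
    """Transformer-like: feature-dependent softmax weights over all node pairs."""
    scores = (z @ w) @ z.T
    scores = scores - scores.max(axis=1, keepdims=True)
    return row_normalize(np.exp(scores))

DIFFUSIVITY = {"gcn": gcn_diffusivity, "gat": gat_diffusivity,
               "transformer": transformer_diffusivity}

def stochastic_diffusion(z0, adj, ws, probs, kind="gat", tau=0.5, steps=4, seed=0):
    """Hypothetical sketch: explicit-Euler diffusion with a sampled diffusivity per step.

    ws:    list of K (d, d) candidate parameter matrices (assumption of this sketch).
    probs: length-K sampling probabilities (fixed here, not learned).
    """
    rng = np.random.default_rng(seed)
    s_fn = DIFFUSIVITY[kind]
    z = np.array(z0, dtype=float)
    for _ in range(steps):
        k = rng.choice(len(ws), p=probs)   # stochastic diffusivity at this time step
        s = s_fn(z, adj, ws[k])
        # z_i <- z_i + tau * sum_j S_ij (z_j - z_i), valid because rows of S sum to 1
        z = z + tau * (s @ z - z)
    return z

# Usage: a 4-node path graph with 3-dimensional node features.
rng = np.random.default_rng(1)
adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]], dtype=float)
z0 = rng.normal(size=(4, 3))
ws = [rng.normal(scale=0.1, size=(3, 3)) for _ in range(2)]
z = stochastic_diffusion(z0, adj, ws, probs=[0.7, 0.3], kind="transformer")
print(z.shape)  # (4, 3)
```

Because each row of S sums to 1 and tau is at most 1, every step replaces a node's features with a convex combination of current features, so values stay within the initial per-feature range. This follows from the sketch's normalization choice and is not a claim about the paper's model.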

Code Implementations: 1 repo
