IA-GCN: Interpretable Attention based Graph Convolutional Network for Disease prediction
This work addresses the need for interpretable AI in medical diagnostics to assist clinical experts, though its contribution is incremental because it builds on existing GCN interpretability methods.
The paper tackles interpretability in Graph Convolutional Networks (GCNs) for disease prediction. It proposes an interpretable attention module that explains feature relevance and improves model performance, with reported average accuracy gains of 3.2% on Tadpole, 1.6% on UKBB gender prediction, and 2% on UKBB age prediction.
Interpretability in Graph Convolutional Networks (GCNs) has been explored to some extent in computer vision, but it requires further examination in the medical domain. Moreover, most interpretability approaches for GCNs, especially in the medical domain, interpret the model in a post hoc fashion. In this paper, we propose an interpretable graph learning-based model which 1) interprets the clinical relevance of the input features towards the task, 2) uses the explanation to improve model performance, and 3) learns a population-level latent graph that may be used to interpret the cohort's behavior. In a clinical scenario, such a model can assist clinical experts in better decision-making for diagnosis and treatment planning. The main novelty lies in the interpretable attention module (IAM), which operates directly on multi-modal features. Our IAM learns the attention for each feature based on unique interpretability-specific losses. We show the application on two publicly available datasets, Tadpole and UKBB, for three tasks: disease, age, and gender prediction. Our proposed model shows superior performance with respect to the compared methods, with average accuracy increases of 3.2% for Tadpole, 1.6% for UKBB gender prediction, and 2% for UKBB age prediction. Further, we show exhaustive validation and clinical interpretation of our results.
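To make the pipeline described above more concrete, the following is a minimal NumPy sketch of the general idea: per-feature attention weights, a population-level latent graph built from the attended features, and one graph convolution. It is not the authors' implementation. The function names (`feature_attention`, `latent_graph`, `gcn_layer`, `forward`), the softmax attention, the Gaussian-kernel adjacency, and the `temperature` parameter are all assumptions chosen for illustration. The interpretability-specific losses are not specified in the abstract and are therefore omitted.

```python
import numpy as np

def feature_attention(logits):
    """Turn per-feature logits (hypothetical learnable parameters)
    into attention weights that are positive and sum to 1."""
    z = logits - logits.max()          # shift for numerical stability
    w = np.exp(z)
    return w / w.sum()

def latent_graph(x, temperature=1.0):
    """Assumed graph construction: soft adjacency from pairwise squared
    Euclidean distances between patients, row-normalised."""
    d2 = ((x[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1)
    a = np.exp(-d2 / temperature)
    return a / a.sum(axis=1, keepdims=True)

def gcn_layer(a, x, w):
    """One graph convolution: aggregate neighbours, project, apply ReLU."""
    return np.maximum(a @ x @ w, 0.0)

def forward(x, logits, w):
    """Attend over features, build the latent graph, run one GCN layer."""
    att = feature_attention(logits)    # shape: (n_features,)
    x_att = x * att                    # re-weight every feature column
    a = latent_graph(x_att)            # shape: (n_patients, n_patients)
    h = gcn_layer(a, x_att, w)         # shape: (n_patients, n_hidden)
    return h, att, a

# Toy usage: 5 patients, 4 features, 3 hidden units (random data).
rng = np.random.default_rng(0)
x = rng.normal(size=(5, 4))
logits = np.zeros(4)                   # uniform attention before training
w = rng.normal(size=(4, 3))
h, att, a = forward(x, logits, w)
```

In this sketch the vector `att` is what one would inspect for feature relevance, and the matrix `a` plays the role of the population-level graph. In a trained model both would be learned jointly with the prediction task.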