Towards training GNNs using explanation feedbacks

Dr Chirag Agarwal

Abstract

Introduction. Graph Neural Networks (GNNs) are increasingly used as powerful tools for representing graph-structured data, such as social, information, chemical, and biological networks. As these models are deployed in critical applications (e.g., drug repurposing and crime forecasting), it becomes essential to ensure that the relevant stakeholders understand and trust their decisions. To this end, several approaches have been proposed in recent literature to explain the predictions of GNNs. Depending on the technique employed, these approaches fall into three broad categories: perturbation-based, gradient-based, and surrogate-based methods. Although several classes of GNN explanation methods exist, little to no work shows how to use their explanations to improve GNN performance. In particular, there is no framework that leverages these explanations on the fly to aid the training of a GNN. This gap mainly stems from the fact that there is very little work that systematically analyzes how explanations generated by state-of-the-art GNN explanation methods can be used.


Proposal. Previous research in GraphXAI has focused on developing post-hoc explanation methods. In this work, we propose in-hoc GNN explanations that act as feedback, on the fly, during the training phase of a GNN model, and we aim to use the generated explanations to improve GNN training. Using explanations, we plan to define local neighborhoods for neural message passing. For example, for a correctly classified node u, we can identify the most important nodes in its local neighborhood N_u and then either use them as a prior or generate augmented samples to guide the message passing of similar nodes in subsequent training stages. To this end, we propose to place an explanation layer after every message-passing layer. This layer acts as a unit buffer that passes all the information to the upper layers during the forward pass, but propagates the explanation information back to the graph representations.
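As a purely illustrative sketch of the neighborhood idea, the snippet below restricts one mean-aggregation step for a node u to the k neighbors in N_u with the highest explanation scores. It is not the proposed method. The function names (`top_k_neighbors`, `mean_aggregate`), the toy graph, the importance scores, and the choice of mean aggregation are all assumptions made for the example. In practice the scores would come from a GNN explanation method.

```python
from typing import Dict, List

Graph = Dict[int, List[int]]        # adjacency list: node -> neighbors
Features = Dict[int, List[float]]   # node -> feature vector


def top_k_neighbors(node: int, graph: Graph,
                    importance: Dict[int, float], k: int) -> List[int]:
    """Keep the k neighbors of `node` with the highest explanation score."""
    ranked = sorted(graph[node],
                    key=lambda v: importance.get(v, 0.0),
                    reverse=True)
    return ranked[:k]


def mean_aggregate(node: int, neighbors: List[int],
                   feats: Features) -> List[float]:
    """One mean-aggregation step over `neighbors` plus the node itself."""
    members = [node] + neighbors
    dim = len(feats[node])
    return [sum(feats[v][d] for v in members) / len(members)
            for d in range(dim)]


# Toy graph (assumed): node 0 has three neighbors.
graph = {0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}
feats = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [2.0, 2.0], 3: [9.0, 9.0]}

# Hypothetical importance scores for the neighbors of node 0,
# standing in for the output of an explanation method.
importance = {1: 0.9, 2: 0.7, 3: 0.1}

kept = top_k_neighbors(0, graph, importance, k=2)   # explanation-defined N_u
h0_guided = mean_aggregate(0, kept, feats)          # uses nodes 0, 1, 2
h0_full = mean_aggregate(0, graph[0], feats)        # uses all of N_u
```

In this toy example the low-importance neighbor (node 3) is dropped, so its features no longer influence the updated representation of node 0. Whether such a hard top-k cut or a softer prior over neighbors works better is an open design choice.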