Skip to content

Commit

Permalink
update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
siqim committed May 15, 2022
1 parent 7d3cfad commit 37a9c36
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# GSAT
The official implementation of Graph Stochastic Attention (GSAT) for our paper: [Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism](https://arxiv.org/abs/2201.12987).
The official implementation of Graph Stochastic Attention (GSAT) for our paper: [Interpretable and Generalizable Graph Learning via Stochastic Attention Mechanism](https://arxiv.org/abs/2201.12987), to appear in ICML 2022.

## Introduction
Commonly used attention mechanisms do not impose any constraints during training, and thus may lack interpretability. GSAT is a novel attention mechanism to build interpretable graph learning models. It injects stochasticity to learn attention, where a higher attention weight means a higher probability of the corresponding edge being kept during training. Such a mechanism will push the model to learn higher attention weights for edges that are important for prediction accuracy, which provides interpretability. To further improve the interpretability for graph learning tasks and avoid trivial solutions, we derive regularization terms for GSAT based on the information bottleneck (IB) principle. As a by-product, IB also helps model generalization. Figure 1 shows the architecture of GSAT.
Commonly used attention mechanisms do not impose any constraints during training (besides normalization), and thus may lack interpretability. GSAT is a novel attention mechanism for building interpretable graph learning models. It injects stochasticity to learn attention, where a higher attention weight means a higher probability of the corresponding edge being kept during training. Such a mechanism will push the model to learn higher attention weights for edges that are important for prediction accuracy, which provides interpretability. To further improve the interpretability for graph learning tasks and avoid trivial solutions, we derive regularization terms for GSAT based on the information bottleneck (IB) principle. As a by-product, IB also helps model generalization. Fig. 1 shows the architecture of GSAT.

<p align="center"><img src="./data/arch.png" width=85% height=85%></p>
<p align="center"><em>Figure 1.</em> The architecture of GSAT.</p>
Expand Down Expand Up @@ -32,12 +32,12 @@ pip install -r requirements.txt


# Run Examples
We provide examples with minimal code to run GSAT in `./example/example.ipynb`. We have tested the provided examples on `Ba-2Motifs (GIN)`, `Mutag (GIN)` and `OGBG-Molhiv (PNA)`.
We provide examples with minimal code to run GSAT in `./example/example.ipynb`. We have tested the provided examples on `Ba-2Motifs (GIN)`, `Mutag (GIN)` and `OGBG-Molhiv (PNA)`. Yet, to implement GSAT* one needs to load a pre-trained model first in the provided example.

It should be able to run on other datasets as well, but some hard-coded hyperparameters might need to be changed accordingly. To reproduce results for other datasets, please follow the instructions in the following section.

# Reproduce Results
We provide the source code to reproduce the results in our paper.
We provide the source code to reproduce the results in our paper. The results of GSAT can be reproduced by running `run_gsat.py`. To reproduce GSAT*, one needs to run `pretrain_clf.py` first and change the configuration file accordingly (`from_scratch: false`).

To pre-train a classifier:
```
Expand Down Expand Up @@ -96,10 +96,17 @@ All settings can be found in `./src/configs`.

# FAQ
#### Does GSAT encourage sparsity?
No, GSAT doesn't encourage generating sparse subgraphs. We find `r = 0.7` (Eq.9 in our paper) can generally work well for all datasets in our experiments, which means during training roughly `70%` of edges will be kept (kind of still dense). This is because GSAT doesn't try to provide interpretability by finding a small/sparse subgraph of the original input graph (this is what previous works normally do). Instead, it provides interpretability by pushing the critical edges to have relatively lower stochasticity during training.
No, GSAT doesn't encourage generating sparse subgraphs. We find `r = 0.7` (Eq.(9) in our paper) can generally work well for all datasets in our experiments, which means during training roughly `70%` of edges will be kept (kind of still dense). This is because GSAT doesn't try to provide interpretability by finding a small/sparse subgraph of the original input graph (this is what previous works normally do). Instead, it provides interpretability by pushing the critical edges to have relatively lower stochasticity during training.

#### How to choose the value of `r`?
A grid search in `[0.5, 0.6, 0.7, 0.8, 0.9]` is recommended, but `r = 0.7` is a good starting point. Note that in practice we would decay the value of `r` gradually during training from `0.9` to the chosen value.

#### `p` or `α` in Eq.(9)?
Recall in Fig. 1, `p` is the probability of dropping an edge, while `α` is the sampled result from `Bern(p)`. In our provided implementation, as an empirical choice, `α` is used to implement Eq.(9). We find that when `α` is used it may provide more regularization and makes the model more robust to hyperparameters. Nonetheless, using `p` can achieve the same performance, but it needs some more tuning.

#### Can you show an example of how GSAT works?
Below we show an example from the `ba_2motifs` dataset, which is to distinguish five-node cycle motifs (left) and house motifs (right).
To make good predictions (minimize the cross-entropy loss), GSAT will push the attention weights of those critical edges to be relatively large (ideally close to `1`). Otherwise, those critical edges may be dropped too frequently and thus result in a large cross-entropy loss. Meanwhile, to minimize the regularization loss (the KL divergence term in Eq.(9) of the paper), GSAT will push the attention weights of other non-critical edges to be close to `r`, which is set to be `0.7` in the example. This mechanism of injecting stochasticity makes the learned attention weights from GSAT directly interpretable, since the more critical an edge is, the larger its attention weight will be (the less likely it can be dropped).
To make good predictions (minimize the cross-entropy loss), GSAT will push the attention weights of those critical edges to be relatively large (ideally close to `1`). Otherwise, those critical edges may be dropped too frequently and thus result in a large cross-entropy loss. Meanwhile, to minimize the regularization loss (the KL divergence term in Eq.(9) of the paper), GSAT will push the attention weights of other non-critical edges to be close to `r`, which is set to be `0.7` in the example. This mechanism of injecting stochasticity makes the learned attention weights from GSAT directly interpretable, since the more critical an edge is, the larger its attention weight will be (the less likely it can be dropped). Note that `ba_2motifs` satisfies our Thm. 4.1 with no noise, and GSAT achieves perfect interpretation performance on it.

<p align="center"><img src="./data/example.png" width=85% height=85%></p>
<p align="center"><em>Figure 2.</em> An example of the learned attention weights.</p>
<p align="center"><em>Figure 2.</em> An example of the learned attention weights.</p>

0 comments on commit 37a9c36

Please sign in to comment.