Don't PANIC: Prototypical Additive Neural Network for Interpretable Classification of Alzheimer's Disease
This repository contains the code for the paper "Don't PANIC: Prototypical Additive Neural Network for Interpretable Classification of Alzheimer's Disease".
```bibtex
@inproceedings{Wolf2023-IPMI,
  doi = {10.1007/978-3-031-34048-2_7},
  author = {Wolf, Tom Nuno and P{\"o}lsterl, Sebastian and Wachinger, Christian},
  title = {Don't PANIC: Prototypical Additive Neural Network for Interpretable Classification of Alzheimer's Disease},
  booktitle = {Information Processing in Medical Imaging},
  pages = {82--94},
  year = {2023}
}
```
If you are using this code, please cite the paper above.
Use conda to create an environment called `panic` with all dependencies:

```bash
conda env create -n panic --file requirements.yaml
```
Additionally, install the package `torchpanic` from this repository with:

```bash
pip install --no-deps -e .
```
We used data from the Alzheimer's Disease Neuroimaging Initiative (ADNI). Since we are not allowed to share our data, you will need to process the data yourself. Data for training, validation, and testing should be stored in separate HDF5 files, using the following hierarchical format:
- First level: A unique identifier.
- The second level always has the following entries:
  - A group named `PET` with the subgroup `FDG`, which itself has the dataset named `data` as child: the FDG-PET volume of size (113,117,113). Additionally, the subgroup `FDG` has an attribute `imageuid`, which is the unique image identifier.
  - A group named `tabular`, which has two datasets called `data` and `missing`, each of size 41: `data` contains the tabular data values, while `missing` is a missing-value indicator that marks whether a tabular feature was not acquired at this visit.
  - A scalar attribute `RID` with the patient ID.
  - A string attribute `VISCODE` with ADNI's visit code.
  - A string attribute `DX` containing the diagnosis (`CN`, `MCI` or `Dementia`).
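As a rough illustration of this layout, the following Python sketch writes a single visit with h5py. It is not part of `torchpanic`: the function name `write_entry`, the float32/uint8 dtypes, and the convention that a nonzero `missing` entry means "not acquired" are assumptions made for this example.

```python
import h5py
import numpy as np

# Diagnosis labels listed in this README.
VALID_DX = ("CN", "MCI", "Dementia")


def write_entry(h5file, key, rid, viscode, dx, image_uid, pet, tab_data, tab_missing):
    """Write one visit into an open h5py.File using the layout described above.

    Hypothetical helper; the dtypes (float32 values, uint8 flags) are assumptions.
    """
    if dx not in VALID_DX:
        raise ValueError(f"DX must be one of {VALID_DX}, got {dx!r}")
    grp = h5file.create_group(key)        # first level: unique identifier
    grp.attrs["RID"] = int(rid)           # scalar patient ID
    grp.attrs["VISCODE"] = viscode        # Python str -> variable-length UTF-8 string
    grp.attrs["DX"] = dx
    fdg = grp.create_group("PET").create_group("FDG")
    fdg.attrs["imageuid"] = str(image_uid)
    fdg.create_dataset("data", data=np.asarray(pet, dtype=np.float32))
    tab = grp.create_group("tabular")
    tab.create_dataset("data", data=np.asarray(tab_data, dtype=np.float32))
    # Assumed encoding: 1 = feature not acquired at this visit, 0 = observed.
    tab.create_dataset("missing", data=np.asarray(tab_missing, dtype=np.uint8))
    return grp
```

Call it once per visit on a file opened with `h5py.File(path, "w")`, using one file per split (training, validation, testing).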
One entry in the resulting HDF5 file should have the following structure:
```
/1010012 Group
    Attribute: RID scalar
        Type: native long
        Data: 1234
    Attribute: VISCODE scalar
        Type: variable-length null-terminated UTF-8 string
        Data: "bl"
    Attribute: DX scalar
        Type: variable-length null-terminated UTF-8 string
        Data: "CN"
/1010012/PET Group
/1010012/PET/FDG Group
    Attribute: imageuid scalar
        Type: variable-length null-terminated UTF-8 string
        Data: "12345"
/1010012/PET/FDG/data Dataset {113, 137, 133}
/1010012/tabular Group
/1010012/tabular/data Dataset {41}
/1010012/tabular/missing Dataset {41}
```
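Before training, it may help to sanity-check each file against this layout. The sketch below is hypothetical (not part of `torchpanic`) and only checks the entries named in this README; `check_entry` and `check_file` are illustrative names, and the expected tabular length of 41 is taken from the description above.

```python
import h5py

VALID_DX = {"CN", "MCI", "Dementia"}
N_TABULAR = 41  # number of tabular features described in this README


def check_entry(entry):
    """Return a list of problems found in one visit group (empty list = OK)."""
    problems = []
    for name in ("RID", "VISCODE", "DX"):
        if name not in entry.attrs:
            problems.append(f"missing attribute {name}")
    if entry.attrs.get("DX") not in VALID_DX:
        problems.append(f"unexpected DX {entry.attrs.get('DX')!r}")
    if "PET/FDG/data" not in entry:
        problems.append("missing PET/FDG/data")
    else:
        if len(entry["PET/FDG/data"].shape) != 3:
            problems.append("PET/FDG/data is not a 3D volume")
        if "imageuid" not in entry["PET/FDG"].attrs:
            problems.append("missing PET/FDG attribute imageuid")
    for name in ("data", "missing"):
        path = f"tabular/{name}"
        if path not in entry:
            problems.append(f"missing {path}")
        elif entry[path].shape != (N_TABULAR,):
            problems.append(f"{path} has shape {entry[path].shape}")
    return problems


def check_file(path):
    """Map each visit key to its problems; the 'stats' group is skipped."""
    with h5py.File(path, "r") as f:
        return {key: check_entry(f[key]) for key in f if key != "stats"}
```

An empty list for every key means the visit matches the structure above; it does not validate the values themselves.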
Finally, the HDF5 file should also contain the following meta-information in a separate group named `stats`:

```
/stats/tabular Group
/stats/tabular/columns Dataset {41}
/stats/tabular/mean Dataset {41}
/stats/tabular/stddev Dataset {41}
```
These datasets contain the names of the tabular features, their means, and their standard deviations.
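One way to produce the `stats` group is sketched below. It assumes `missing` is nonzero where a value was not acquired and ignores those entries when computing the statistics; whether PANIC expects statistics computed this way, and from which split, is not stated here (computing them on the training data only is a common choice). `write_tabular_stats` is a hypothetical helper.

```python
import h5py
import numpy as np


def write_tabular_stats(h5file, columns, data, missing):
    """Compute per-feature mean/stddev and store them under /stats/tabular.

    data, missing: arrays of shape (n_visits, n_features).
    Assumption: missing != 0 marks a value that was not acquired.
    """
    data = np.asarray(data, dtype=np.float64)
    observed = np.asarray(missing) == 0
    masked = np.where(observed, data, np.nan)  # hide unobserved values
    mean = np.nanmean(masked, axis=0)          # per-feature mean over observed values
    stddev = np.nanstd(masked, axis=0)         # population std (ddof=0)
    grp = h5file.require_group("stats").require_group("tabular")
    grp.create_dataset("columns", data=list(columns), dtype=h5py.string_dtype())
    grp.create_dataset("mean", data=mean)
    grp.create_dataset("stddev", data=stddev)
    return mean, stddev
```

A feature that is missing for every visit yields `nan` statistics (with a NumPy warning), which is worth checking before training.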
PANIC processes tabular data depending on its data type. Therefore, it is necessary to tell PANIC how to process each tabular feature. The following indices must be given to the model in the config file `configs/model/panic.yaml`:

- `idx_real_features`: indices of real-valued features within the tabular `data`.
- `idx_cat_features`: indices of categorical features within the tabular `data`.
- `idx_real_has_missing`: indices of real-valued features whose missing-value indicator in `missing` should be taken into account.
- `idx_cat_has_missing`: indices of categorical features whose missing-value indicator in `missing` should be taken into account.

Similarly, missing tabular inputs to DAFT (`configs/model/daft.yaml`) need to be specified with `idx_tabular_has_missing`.
To train PANIC, or any of the baseline models, adapt the config files (mainly `train.yaml`) and execute the `train.py` script to begin training. Model checkpoints will be written to the `outputs` folder by default.
We provide utility functions to create the plots and visualizations required to interpret the model. You can find them under `torchpanic/viz`.