This repository contains a PyTorch implementation of the paper 🧠 From Barlow Twins to Triplet Training: Differentiating Dementia with Limited Data, published at MIDL 2024.
If you find this repository useful, please consider giving a star 🌟 and citing the paper:
```bibtex
@inproceedings{li2024triplet,
  title={From Barlow Twins to Triplet Training: Differentiating Dementia with Limited Data},
  author={Li, Yitong and Wolf, Tom Nuno and P{\"o}lsterl, Sebastian and Yakushev, Igor and Hedderich, Dennis M and Wachinger, Christian},
  booktitle={Medical Imaging with Deep Learning},
  year={2024}
}
```
We used data from UK Biobank, Alzheimer's Disease Neuroimaging Initiative (ADNI), and the Frontotemporal Lobar Degeneration Neuroimaging Initiative (NIFD) for self-supervised learning and self-distillation. Since we are not allowed to share our data, you would need to process the data yourself. Data for training, validation, and testing should be stored in separate HDF5 files, using the following hierarchical format:
- First level: A unique identifier, e.g. image ID.
- The second level always has the following entries:
  - A group named `MRI/T1`, containing the T1-weighted 3D MRI data.
  - A string attribute `DX` containing the diagnosis label: `CN`, `AD`, or `FTD`, if available.
  - A scalar attribute `RID` with the patient ID, if available.
  - Additional attributes depending on the task, such as `Sex` and `Age`, if available.
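For reference, a minimal sketch of writing one image in this layout with h5py is shown below. The file name, image ID, volume shape, and attribute values are illustrative assumptions, not values required by the repository.

```python
import h5py
import numpy as np

# Sketch of the expected HDF5 hierarchy; shapes and values are placeholders.
with h5py.File("train.h5", "w") as f:
    subj = f.create_group("image_0001")  # first level: unique image ID
    # "MRI/T1" holds the T1-weighted 3D volume (shape here is an assumption)
    subj.create_dataset("MRI/T1", data=np.zeros((113, 137, 113), dtype=np.float32))
    subj.attrs["DX"] = "CN"   # diagnosis label: CN, AD, or FTD
    subj.attrs["RID"] = 42    # patient ID
    subj.attrs["Sex"] = "F"   # additional task-specific attributes
    subj.attrs["Age"] = 71.0
```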
- Create environment: `conda env create -n triplet --file requirements.yaml`
- Activate environment: `conda activate triplet`
- Install the `addiagnosis` package in development mode: `pip install --no-deps -e .`
The package uses PyTorch, PyTorch Lightning and Hydra. PyTorch Lightning is a lightweight PyTorch wrapper. Hydra's key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line. It allows you to conveniently manage experiments.
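As an illustration, any value in the composed config can be overridden from the command line; the keys below (`trainer.max_epochs`, `model.lr`) are hypothetical and depend on what `configs/train.yaml` actually defines:

```bash
# Hypothetical Hydra overrides; substitute the keys defined in configs/train.yaml.
python train.py trainer.max_epochs=100 model.lr=1e-4
```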
The Python modules are located in the `src/addiagnosis` directory, and the Hydra configuration files are in the `configs` directory, where `configs/train.yaml` is the main config for training (self-supervised learning and self-distillation) and `configs/train_transfer.yaml` is the config for transfer learning on your downstream tasks. Specify the `pretrained_model` path in the config files to continue training the next step with the pretrained backbone from the previous step.
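Since Hydra supports command-line overrides, the path can presumably also be passed when launching a run instead of editing the config file; the checkpoint path below is a placeholder:

```bash
# pretrained_model is the config key named above; the path is a placeholder.
python train.py pretrained_model=/path/to/pretrained_backbone.ckpt
```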
After specifying the config files, simply start training (self-supervised learning or self-distillation) with:

```bash
python train.py
```

transfer learning on your downstream task with:

```bash
python transfer_learning.py
```

and testing on your downstream task with:

```bash
python test.py
```
For any questions, please contact: Yitong Li ([email protected])
The self-supervised learning part of the code was adapted to 3D from the implementations of Barlow Twins, VICReg, SupContrast, and DiRA. We used rfs as a reference for the self-distillation implementation.