
Official PyTorch Implementation for From Barlow Twins to Triplet Training: Differentiating Dementia with Limited Data - MIDL 2024


ai-med/TripletTraining


TripletTraining

This repository contains a PyTorch implementation of the paper 🧠 From Barlow Twins to Triplet Training: Differentiating Dementia with Limited Data, published at MIDL 2024.

If you find this repository useful, please consider giving a star 🌟 and citing the paper:

@inproceedings{li2024triplet,
  title={From Barlow Twins to Triplet Training: Differentiating Dementia with Limited Data},
  author={Li, Yitong and Wolf, Tom Nuno and P{\"o}lsterl, Sebastian and Yakushev, Igor and Hedderich, Dennis M and Wachinger, Christian},
  booktitle={Medical Imaging with Deep Learning},
  year={2024}
}

Data

We used data from UK Biobank, the Alzheimer's Disease Neuroimaging Initiative (ADNI), and the Frontotemporal Lobar Degeneration Neuroimaging Initiative (NIFD) for self-supervised learning and self-distillation. Since we are not allowed to share our data, you will need to process the data yourself. Data for training, validation, and testing should be stored in separate HDF5 files with the following hierarchical format:

  1. First level: A unique identifier, e.g. image ID.
  2. The second level always has the following entries:
    1. A group named MRI/T1, containing the T1-weighted 3D MRI data.
    2. A string attribute DX containing the diagnosis label (CN, AD, or FTD), if available.
    3. A scalar attribute RID with the patient ID, if available.
    4. Additional attributes depending on the task, such as Sex and Age, if available.
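A minimal sketch of writing and reading one split in this layout with h5py. It assumes the T1 volume is stored as a dataset named data inside the MRI/T1 group; that dataset name, the helper names write_split and read_split, and the synthetic example values are hypothetical, so check the repository's dataset class for the exact keys it reads.

```python
import tempfile
from pathlib import Path

import h5py
import numpy as np


def write_split(path, records):
    """Write one split (train, validation or test) to its own HDF5 file.

    `records` maps a unique image ID to a dict holding the T1 volume under
    "t1" plus any of the optional attributes DX, RID, Sex, Age.
    """
    with h5py.File(str(path), "w") as hf:
        for image_id, rec in records.items():
            grp = hf.create_group(image_id)  # first level: unique identifier
            t1 = grp.require_group("MRI").require_group("T1")  # group MRI/T1
            # ASSUMPTION: the volume lives in a dataset called "data".
            t1.create_dataset("data", data=rec["t1"], compression="gzip")
            # Attributes are attached to the image group, only if available.
            for key in ("DX", "RID", "Sex", "Age"):
                if key in rec:
                    grp.attrs[key] = rec[key]


def read_split(path):
    """Return {image_id: (volume, attributes)} for every entry in the file."""
    entries = {}
    with h5py.File(str(path), "r") as hf:
        for image_id, grp in hf.items():
            volume = grp["MRI/T1/data"][()]  # load the full 3D array
            entries[image_id] = (volume, dict(grp.attrs))
    return entries


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    records = {  # synthetic placeholder data, not real subjects
        "img_0001": {"t1": rng.random((8, 8, 8), dtype=np.float32),
                     "DX": "AD", "RID": 42, "Age": 71.5},
        "img_0002": {"t1": rng.random((8, 8, 8), dtype=np.float32), "DX": "CN"},
    }
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "train.h5"
        write_split(path, records)
        for image_id, (volume, attrs) in read_split(path).items():
            print(image_id, volume.shape, attrs.get("DX"))
```

Calling write_split once per split (for example with hypothetical file names train.h5, val.h5 and test.h5) keeps training, validation, and testing data in separate files, and optional attributes are simply left out when they are not available.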

Installation

  1. Create environment: conda env create -n triplet --file requirements.yaml
  2. Activate environment: conda activate triplet
  3. Install the addiagnosis package in development mode: pip install --no-deps -e .

Usage

The package uses PyTorch, PyTorch Lightning and Hydra. PyTorch Lightning is a lightweight PyTorch wrapper. Hydra's key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line. It allows you to conveniently manage experiments.
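To make the command-line override mechanism concrete, here is a toy, dependency-free illustration of what a Hydra-style dotted override such as trainer.max_epochs=50 does to a nested configuration. It is not Hydra's implementation (Hydra has its own override grammar and builds on OmegaConf), and the config keys used with it are hypothetical.

```python
import ast
import copy


def parse_value(text):
    """Interpret '50', '1e-3', 'true', 'null'; otherwise keep the raw string."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    if lowered == "null":
        return None
    try:
        return ast.literal_eval(text)  # numbers, lists, quoted strings
    except (ValueError, SyntaxError):
        return text  # bare words and paths stay strings


def apply_overrides(config, overrides):
    """Return a copy of `config` with each 'a.b.c=value' override applied.

    Like Hydra, overriding a key that is not already in the config is an
    error (Hydra needs a '+' prefix to add keys; this toy does not support it).
    """
    cfg = copy.deepcopy(config)  # never mutate the base config
    for item in overrides:
        dotted, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        *parents, leaf = dotted.split(".")
        node = cfg
        for key in parents:
            node = node[key]  # KeyError for an unknown parent path
        if leaf not in node:
            raise KeyError(dotted)
        node[leaf] = parse_value(raw)
    return cfg
```

With Hydra itself the same overrides are passed after the script name, for instance python train.py trainer.max_epochs=50, where trainer.max_epochs stands in for whatever keys configs/train.yaml actually defines. As in the toy, Hydra rejects a key that is not in the config unless it is prefixed with +.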

The Python modules are located in src/addiagnosis and the Hydra configuration files in configs. configs/train.yaml is the main config for training (self-supervised learning and self-distillation), and configs/train_transfer.yaml is the config for transfer learning on your downstream tasks. To start the next training step from the backbone pretrained in the previous step, set the pretrained_model path in the corresponding config file.
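One way to picture how a pretrained_model path can be consumed is the sketch below, which copies the backbone weights from a PyTorch Lightning checkpoint into a fresh backbone module. It assumes the LightningModule keeps its encoder in an attribute named backbone, so the relevant keys in the checkpoint's state_dict start with "backbone."; that prefix and the helper names are hypothetical, so check the repository's model code for how it actually loads pretrained_model.

```python
from collections import OrderedDict


def extract_backbone_state(state_dict, prefix="backbone."):
    """Keep only the entries under `prefix` and strip the prefix from their keys.

    ASSUMPTION: "backbone." is a hypothetical prefix; inspect
    checkpoint["state_dict"].keys() to find the real one.
    """
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )


def load_pretrained_backbone(backbone, ckpt_path, prefix="backbone."):
    """Copy the previous step's backbone weights into `backbone` (a torch.nn.Module)."""
    import torch  # imported here so extract_backbone_state works without torch

    # Lightning checkpoints store the LightningModule's weights under "state_dict".
    # Note: PyTorch >= 2.6 defaults torch.load to weights_only=True; a checkpoint
    # that also pickles config objects may need weights_only=False (trusted files only).
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state = extract_backbone_state(checkpoint["state_dict"], prefix)
    if not state:
        raise KeyError(f"no parameters with prefix {prefix!r} in {ckpt_path}")
    # strict=True raises if the checkpoint and the architecture do not match.
    backbone.load_state_dict(state, strict=True)
    return backbone
```

A call would look like load_pretrained_backbone(my_encoder, "path/to/previous_step.ckpt"), with both arguments hypothetical. Using strict=True makes a prefix or architecture mismatch fail loudly instead of silently leaving layers randomly initialized.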

After setting up the config files, start training (self-supervised learning or self-distillation) with:

python train.py

Then run transfer learning on your downstream task with:

python transfer_learning.py

Finally, test the model on your downstream task with:

python test.py

Contacts

For any questions, please contact: Yitong Li ([email protected])

Acknowledgements

The self-supervised learning code was adapted to 3D from the Barlow Twins, VICReg, SupContrast, and DiRA implementations. I used rfs as a reference for the self-distillation implementation.
