Official PyTorch implementation of the ICLR 2020 paper "A Neural Dirichlet Process Mixture Model for Task-Free Continual Learning".
| Method | Split-MNIST Acc. (%) | Split-MNIST (Gen.) bits/dim | MNIST-SVHN Acc. (%) | Split-CIFAR10 Acc. (%) | Split-CIFAR100 Acc. (%) |
|---|---|---|---|---|---|
| iid-offline | 98.63 | 0.1806 | 96.69 | 93.17 | 73.80 |
| iid-online | 96.18 | 0.2156 | 95.24 | 62.79 | 20.46 |
| Fine-tune | 19.43 | 0.2817 | 83.35 | 18.08 | 2.43 |
| Reservoir | 85.69 | 0.2234 | 94.12 | 44.00 | 10.01 |
| CN-DPM | 93.23 | 0.2110 | 94.46 | 45.21 | 20.10 |
- Python >= 3.6.1
- A GPU that supports CUDA >= 9.0 and has at least 10GB of memory
- Install PyTorch 1.0.1 and TorchVision 0.2.2 for your environment, following the official PyTorch installation instructions.
- Install other required packages.
$ pip install -r requirements.txt
$ python main.py --help
usage: main.py [-h] [--config CONFIG] [--episode EPISODE] [--log-dir LOG_DIR] [--override OVERRIDE]
optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG, -c CONFIG
  --episode EPISODE, -e EPISODE
  --log-dir LOG_DIR, -l LOG_DIR
  --override OVERRIDE
We provide a quick and easy solution to compose continual learning scenarios. You can configure a scenario by writing a YAML file. Here is an example of Split-CIFAR10 where each stage is repeated for ten epochs:
- subsets: [['cifar10', 0], ['cifar10', 1]]
  epochs: 10
- subsets: [['cifar10', 2], ['cifar10', 3]]
  epochs: 10
- subsets: [['cifar10', 4], ['cifar10', 5]]
  epochs: 10
- subsets: [['cifar10', 6], ['cifar10', 7]]
  epochs: 10
- subsets: [['cifar10', 8], ['cifar10', 9]]
  epochs: 10
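Writing such a file by hand gets tedious for datasets with many classes. The following is a minimal sketch, not part of this repository, of how a split scenario like the one above could be generated. The helper names `split_scenario` and `to_yaml` are hypothetical, and the sketch assumes classes are grouped in consecutive order.

```python
def split_scenario(dataset, num_classes, classes_per_stage, epochs):
    """Group consecutive classes into stages of `classes_per_stage` subsets."""
    stages = []
    for start in range(0, num_classes, classes_per_stage):
        stop = min(start + classes_per_stage, num_classes)
        subsets = [[dataset, cls] for cls in range(start, stop)]
        stages.append({"subsets": subsets, "epochs": epochs})
    return stages


def to_yaml(stages):
    """Render stages in the same layout as the example scenario."""
    lines = []
    for stage in stages:
        subsets = ", ".join(
            "['{}', {}]".format(name, cls) for name, cls in stage["subsets"]
        )
        lines.append("- subsets: [{}]".format(subsets))
        lines.append("  epochs: {}".format(stage["epochs"]))
    return "\n".join(lines)


# Reproduces the Split-CIFAR10 scenario: 5 stages, 2 classes each, 10 epochs.
print(to_yaml(split_scenario("cifar10", 10, 2, 10)))
```

Redirecting the printed text to a file gives a scenario that can be passed to `--episode`.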
Basic rules:
- Each scenario consists of a list of stages.
- Each stage defines a list of subsets.
- A subset is a two-element list `[dataset_name, subset_name]`. By default, each class is defined as a subset with the class number as its name.
- Each stage may optionally define one of `epochs`, `steps`, or `samples` to set the length of the stage. Otherwise, the default length is 1 epoch.
The main logic is implemented in the `DataScheduler` class in `data.py`.
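The basic rules can be checked before launching a long run. The sketch below is an illustration under stated assumptions, not the logic of `DataScheduler`: it takes an already parsed scenario (a list of dicts), and it assumes that a stage setting more than one length key should be rejected. The names `stage_length` and `check_scenario` are hypothetical.

```python
LENGTH_KEYS = ("epochs", "steps", "samples")


def stage_length(stage):
    """Return (unit, amount) for one stage, defaulting to a single epoch."""
    given = [key for key in LENGTH_KEYS if key in stage]
    if len(given) > 1:
        # Assumption: at most one length key per stage is allowed.
        raise ValueError("stage sets more than one of %s" % (LENGTH_KEYS,))
    if not given:
        return ("epochs", 1)
    return (given[0], stage[given[0]])


def check_scenario(scenario):
    """Validate a parsed scenario and return the length of each stage."""
    if not isinstance(scenario, list) or not scenario:
        raise ValueError("a scenario must be a non-empty list of stages")
    lengths = []
    for index, stage in enumerate(scenario):
        subsets = stage.get("subsets")
        if not subsets:
            raise ValueError("stage %d defines no subsets" % index)
        for subset in subsets:
            if not isinstance(subset, (list, tuple)) or len(subset) != 2:
                raise ValueError(
                    "stage %d: %r is not [dataset_name, subset_name]"
                    % (index, subset)
                )
        lengths.append(stage_length(stage))
    return lengths


scenario = [
    {"subsets": [["cifar10", 0], ["cifar10", 1]], "epochs": 10},
    {"subsets": [["cifar10", 2], ["cifar10", 3]]},  # falls back to 1 epoch
]
print(check_scenario(scenario))  # [('epochs', 10), ('epochs', 1)]
```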
Run the commands below to reproduce our experimental results. You can check the summaries on TensorBoard.
$ python main.py \
--config configs/mnist_gen-iid_offline.yaml \
--episode episodes/mnist-iid-100epochs.yaml \
--log-dir log/mnist_gen-iid_offline
$ python main.py \
--config configs/mnist_gen-iid_online.yaml \
--episode episodes/mnist-iid-online.yaml \
--log-dir log/mnist_gen-iid_online
$ python main.py \
--config configs/mnist_gen-iid_online.yaml \
--episode episodes/mnist-split-online.yaml \
--log-dir log/mnist_gen-finetune
$ python main.py \
--config configs/mnist_gen-reservoir.yaml \
--episode episodes/mnist-split-online.yaml \
--log-dir log/mnist_gen-reservoir
$ python main.py \
--config configs/mnist_gen-cndpm.yaml \
--episode episodes/mnist-split-online.yaml \
--log-dir log/mnist_gen-cndpm
$ python main.py \
--config configs/mnist-iid_offline.yaml \
--episode episodes/mnist-iid-100epochs.yaml \
--log-dir log/mnist-iid_offline
$ python main.py \
--config configs/mnist-iid_online.yaml \
--episode episodes/mnist-iid-online.yaml \
--log-dir log/mnist-iid_online
$ python main.py \
--config configs/mnist-iid_online.yaml \
--episode episodes/mnist-split-online.yaml \
--log-dir log/mnist-finetune
$ python main.py \
--config configs/mnist-reservoir.yaml \
--episode episodes/mnist-split-online.yaml \
--log-dir log/mnist-reservoir
$ python main.py \
--config configs/mnist-cndpm.yaml \
--episode episodes/mnist-split-online.yaml \
--log-dir log/mnist-cndpm
$ python main.py \
--config configs/mnist_svhn-iid_offline.yaml \
--episode episodes/mnist_svhn-iid-10epochs.yaml \
--log-dir log/mnist_svhn-iid_offline
$ python main.py \
--config configs/mnist_svhn-iid_online.yaml \
--episode episodes/mnist_svhn-iid-online.yaml \
--log-dir log/mnist_svhn-iid_online
$ python main.py \
--config configs/mnist_svhn-iid_online.yaml \
--episode episodes/mnist_svhn-online.yaml \
--log-dir log/mnist_svhn-finetune
$ python main.py \
--config configs/mnist_svhn-reservoir.yaml \
--episode episodes/mnist_svhn-online.yaml \
--log-dir log/mnist_svhn-reservoir
$ python main.py \
--config configs/mnist_svhn-cndpm.yaml \
--episode episodes/mnist_svhn-online.yaml \
--log-dir log/mnist_svhn-cndpm
$ python main.py \
--config configs/cifar10-iid_offline.yaml \
--episode episodes/cifar10-iid-100epochs.yaml \
--log-dir log/cifar10-iid_offline
$ python main.py \
--config configs/cifar10-iid_online.yaml \
--episode episodes/cifar10-iid-online.yaml \
--log-dir log/cifar10-iid_online
$ python main.py \
--config configs/cifar10-iid_online.yaml \
--episode episodes/cifar10-split-online.yaml \
--log-dir log/cifar10-finetune
$ python main.py \
--config configs/cifar10-reservoir.yaml \
--episode episodes/cifar10-split-online.yaml \
--log-dir log/cifar10-reservoir
$ python main.py \
--config configs/cifar10-cndpm.yaml \
--episode episodes/cifar10-split-online.yaml \
--log-dir log/cifar10-cndpm
$ python main.py \
--config configs/cifar10-cndpm.yaml \
--episode episodes/cifar10-split-0.2epoch.yaml \
--log-dir log/cifar10-cndpm-0.2epoch
$ python main.py \
--config configs/cifar10-cndpm.yaml \
--episode episodes/cifar10-split-10epochs.yaml \
--log-dir log/cifar10-cndpm-10epoch
$ python main.py \
--config configs/cifar100-iid_offline.yaml \
--episode episodes/cifar100-iid-100epochs.yaml \
--log-dir log/cifar100-iid_offline
$ python main.py \
--config configs/cifar100-iid_online.yaml \
--episode episodes/cifar100-iid-online.yaml \
--log-dir log/cifar100-iid_online
$ python main.py \
--config configs/cifar100-iid_online.yaml \
--episode episodes/cifar100-split-online.yaml \
--log-dir log/cifar100-finetune
$ python main.py \
--config configs/reservoir-resnet_classifier-cifar100.yaml \
--episode episodes/cifar100-split-online.yaml \
--log-dir log/cifar100-reservoir
$ python main.py \
--config configs/cifar100-cndpm.yaml \
--episode episodes/cifar100-split-online.yaml \
--log-dir log/cifar100-cndpm
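All of the commands above share the same three flags, so a group of them can be launched from a small driver script. The following is a minimal sketch, not part of this repository: `build_command` and `run_all` are hypothetical helpers, and the `RUNS` list simply copies the five Split-MNIST classification runs from the commands above.

```python
import shlex
import subprocess

# (config, episode, log dir) triples copied from the Split-MNIST commands.
RUNS = [
    ("configs/mnist-iid_offline.yaml", "episodes/mnist-iid-100epochs.yaml", "log/mnist-iid_offline"),
    ("configs/mnist-iid_online.yaml", "episodes/mnist-iid-online.yaml", "log/mnist-iid_online"),
    ("configs/mnist-iid_online.yaml", "episodes/mnist-split-online.yaml", "log/mnist-finetune"),
    ("configs/mnist-reservoir.yaml", "episodes/mnist-split-online.yaml", "log/mnist-reservoir"),
    ("configs/mnist-cndpm.yaml", "episodes/mnist-split-online.yaml", "log/mnist-cndpm"),
]


def build_command(config, episode, log_dir):
    """Assemble the argument list for one run of main.py."""
    return [
        "python", "main.py",
        "--config", config,
        "--episode", episode,
        "--log-dir", log_dir,
    ]


def run_all(runs, dry_run=True):
    """Print each command; execute it sequentially only if dry_run is False."""
    for config, episode, log_dir in runs:
        command = build_command(config, episode, log_dir)
        print(" ".join(shlex.quote(part) for part in command))
        if not dry_run:
            # check=True stops the batch at the first failing run.
            subprocess.run(command, check=True)
```

Calling `run_all(RUNS)` only prints the commands; `run_all(RUNS, dry_run=False)` runs them one after another from the repository root.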