Introduction
Yann is a batteries-included deep learning framework built on PyTorch.
Inspired by Django and Rails, it aims to automate the tedious steps of a machine learning project so that you can focus on the (fun) hard parts. It makes it easy to quickly get a project started and scales with you all the way to production.
It can also be viewed as an extension of torch.nn, since it includes modules from recent research that may still be too experimental to be included in torch.
Getting Started
Install
pip install yann
Quick Tour
Flexible Trainer
Yann provides a Trainer class that encapsulates your experiment state and handles common tasks such as object instantiation, checkpointing, and progress tracking.
Default Training Loop
from yann.train import Trainer
from yann.transforms import Compose, ToTensor, Resize, Normalize

train = Trainer(
  model='resnet50',  # could also be an instance
  dataset='Imagenette160',
  transform=Compose([
    Resize(224),
    Normalize('imagenet'),
    ToTensor()
  ]),
  batch_size=32,
  optimizer='AdamW',
  loss='cross_entropy',
  metrics=('accuracy', 'top_3_accuracy'),
  device='cuda:0'
)

# run training for 5 epochs
train(epochs=5)

# save checkpoint, including model and optimizer state
train.checkpoint(name='{time}-{loss}-{steps}.th')

train.history.plot()
Custom Logic
It exposes methods for common tasks such as iterating over data batches and checkpointing, making it a convenient state container for more complicated use cases.
As an example, we can implement Accelerating Deep Learning by Focusing on the Biggest Losers using an inverted loop:
train.checkpoint.load('latest')

# train.epochs() keeps track of epochs and starts iteration from the current epoch (even after loading a checkpoint)
for epoch in train.epochs(5):
  # train.batches() yields training batches from the train data loader
  for inputs, targets in train.batches():
    # yann.optim_step() calls optimizer.zero_grad() and optimizer.step() for you
    with yann.optim_step(train.optimizer):
      with torch.no_grad():
        outputs = train.model(inputs)
        losses = train.loss(outputs, targets, reduction='none')
        _, top_indices = losses.topk(12)

      outputs = train.model(inputs[top_indices])
      loss = train.loss(outputs, targets[top_indices])
      loss.backward()

  train.checkpoint()
This could also be done by passing a step function when initializing the trainer (Trainer(step=step_on_top_losses)) or by using the override decorator:
@train.override('step')
def step_on_top_losses(train: Trainer, inputs, targets):
  train.model.train()
  train.optimizer.zero_grad()

  with torch.no_grad():
    outputs = train.model(inputs)
    losses = train.loss(outputs, targets, reduction='none')
    _, top_indices = losses.topk(12)

  outputs = train.model(inputs[top_indices])
  loss = train.loss(outputs, targets[top_indices])
  loss.backward()

  train.optimizer.step()
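For reference, the constructor variant mentioned above looks like the minimal sketch below, reusing the arguments from the earlier example:

train = Trainer(
  model='resnet50',
  dataset='Imagenette160',
  optimizer='AdamW',
  loss='cross_entropy',
  step=step_on_top_losses  # custom step function used instead of the default training step
)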
Callbacks
Inspired by Keras, the trainer supports function-based or class-based callbacks that make it easy to hook additional event handlers into the training process.
Function Based
@train.on('batch_end')
def plot_scores(inputs, targets, outputs, loss, **kwargs):
  yann.plot.scores(outputs)


@train.on('batch_error')
def handle_error(error):
  ...
Class Based
from yann.callbacks import HistoryPlotter

train = Trainer(callbacks=(HistoryPlotter(),))

# or add it later
train.callbacks.append(HistoryPlotter())
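Writing your own class-based callback looks roughly like the sketch below. The Callback base class and the on_batch_end hook name are assumptions here, chosen to mirror the function-based events above; check yann.callbacks for the actual interface.

from yann.callbacks import Callback  # assumed base class

class LossLogger(Callback):
  # hook name assumed to mirror the 'batch_end' event shown above
  def on_batch_end(self, inputs=None, targets=None, outputs=None, loss=None, **kwargs):
    print(f'batch loss: {loss.item():.4f}')

train = Trainer(callbacks=(LossLogger(),))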
Experiment Tracking and Reproducibility
To help you track your experiments and keep things reproducible, the trainer automatically tracks your git hash, Python dependencies, logs, and checkpoints in train.paths.root.
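Since train.paths.root points at a directory on disk, a quick way to see what a run produced is to walk it with the standard library. This listing is plain Python, not a yann API; the only assumption is that train.paths.root is path-like.

from pathlib import Path

root = Path(train.paths.root)
print(root)

# list everything the trainer saved for this run (checkpoints, logs, etc.)
for path in sorted(root.rglob('*')):
  print(path.relative_to(root))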
Hyperparameter Definition
from yann.params import HyperParams, Choice, Range

class Params(HyperParams):
  dataset = 'MNIST'
  batch_size = 32
  epochs = 10

  optimizer: Choice(('SGD', 'Adam')) = 'SGD'
  learning_rate: Range(.01, .0001) = .01
  momentum = 0

  seed = 1
Automatic Command Line Interface
# parse command line arguments
params = Params.from_command()
usage: train_mnist.py [-h] [-o {SGD,Adam}] [-lr LEARNING_RATE] [-d DATASET]
                      [-bs BATCH_SIZE] [-e EPOCHS] [-m MOMENTUM] [-s SEED]

optional arguments:
  -h, --help            show this help message and exit
  -o {SGD,Adam}, --optimizer {SGD,Adam}
                        optimizer (default: SGD)
  -lr LEARNING_RATE, --learning_rate LEARNING_RATE
                        learning_rate (default: 0.01)
  -d DATASET, --dataset DATASET
                        dataset (default: MNIST)
  -bs BATCH_SIZE, --batch_size BATCH_SIZE
                        batch_size (default: 32)
  -e EPOCHS, --epochs EPOCHS
                        epochs (default: 10)
  -m MOMENTUM, --momentum MOMENTUM
                        momentum (default: 0)
  -s SEED, --seed SEED  seed (default: 1)
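For example, assuming the params above live in a script named train_mnist.py (as in the generated help text), the defaults can be overridden from the shell; the invocation below is illustrative:

python train_mnist.py --optimizer Adam -lr 0.001 -bs 64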