⚠️ NOTE: this has been moved to https://github.com/pytorch-labs/torchft ⚠️
Prototype repo for PyTorch fault tolerance
This implements a lighthouse server that coordinates across the different replica groups, plus a per-replica-group manager and fault tolerance library that can be used in a standard PyTorch training loop.
This allows for membership changes at training-step granularity, which can greatly improve efficiency by avoiding stop-the-world restarts on errors.
To install from source:
$ pip install .
This uses pyo3+maturin to build the package, so you'll need maturin installed (it's available from PyPI via pip install maturin).
To install in editable mode with the Rust extensions, you can use the normal pip install command:
$ pip install -e .
The lighthouse is used for fault tolerance across replicated workers (DDP/FSDP) when using synchronous training.
You can start a lighthouse server by running:
$ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 1000
See train_ddp.py for the full example.
Invoke with:
$ TORCHFT_MANAGER_PORT=29512 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 train.py
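Fault tolerance comes from running more than one replica group against the same lighthouse. To try this on a single machine you can launch a second copy of the command with different ports, e.g. (the port values below are arbitrary placeholders, and depending on your setup additional per-group settings may be needed):
$ TORCHFT_MANAGER_PORT=29513 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29502 --nnodes 1 --nproc_per_node 1 train.py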
train.py:
import torch
from torch import nn, optim

from torchft import Manager, DistributedDataParallel, Optimizer, ProcessGroupGloo

manager = Manager(
    pg=ProcessGroupGloo(),
    # user-provided callbacks for saving/restoring training state on recovery
    load_state_dict=...,
    state_dict=...,
)

m = nn.Linear(2, 3)
m = DistributedDataParallel(manager, m)
optimizer = Optimizer(manager, optim.AdamW(m.parameters()))

device = "cpu"  # the example uses Gloo, so run on CPU tensors

for i in range(1000):
    batch = torch.rand(2, 2, device=device)

    optimizer.zero_grad()
    out = m(batch)
    loss = out.sum()
    loss.backward()
    optimizer.step()
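The load_state_dict and state_dict arguments passed to the Manager above are user-provided callbacks the manager uses to save and restore training state when a replica recovers or joins. Continuing the example, a minimal sketch of what they might look like (an assumption, not a required shape; it also assumes the Optimizer wrapper forwards state_dict()/load_state_dict() to the wrapped AdamW):

def load_state_dict(state_dict):
    # restore model and optimizer state received from a healthy replica
    m.load_state_dict(state_dict["model"])
    optimizer.load_state_dict(state_dict["optim"])

def state_dict():
    # capture the state another replica would need to catch up
    return {
        "model": m.state_dict(),
        "optim": optimizer.state_dict(),
    }

In practice these would be defined before constructing the Manager so they can be passed in place of the ... placeholders above.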
torchft has a fault tolerant parameter server implementation built on its reconfigurable ProcessGroups. This does not require/use a Lighthouse server.
See parameter_server_test.py for an example.
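The key building block is a process group that can be reconfigured at runtime as workers come and go, instead of tearing the job down and restarting it. A hypothetical sketch of the idea (the configure() call and its arguments are assumptions about the API, not confirmed; parameter_server_test.py shows the real usage):

from torchft import ProcessGroupGloo

pg = ProcessGroupGloo()

# join the current set of workers (store address, rank, and world size are
# assumed reconfiguration parameters)
pg.configure("localhost:29600", rank=0, world_size=4)

# ... a worker drops out; reconfigure the same group in place ...
pg.configure("localhost:29600", rank=0, world_size=3)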
To format the Rust code:
$ cargo fmt
To run the Rust tests:
$ cargo test
BSD 3-Clause -- see LICENSE for more details.
Copyright (c) Tristan Rice 2024