
⚠️ NOTE: this has been moved to https://github.com/pytorch-labs/torchft ⚠️

torchft

Prototype repo for PyTorch fault tolerance

This implements a lighthouse server that coordinates across the different replica groups, as well as a per-replica-group manager and fault tolerance library that can be used in a standard PyTorch training loop.

This allows for membership changes at the training-step granularity, which can greatly improve efficiency by avoiding stop-the-world recovery on errors.
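To make the idea concrete, here is a toy, runnable simulation of step-granularity membership (illustrative only; in torchft the lighthouse and per-replica-group managers do this across real replica groups):

# Toy simulation: a failed replica changes the next step's membership
# instead of forcing a job restart. Not the torchft API.
replicas = {"group0", "group1", "group2"}

def quorum(min_replicas=1):
    # In torchft, the lighthouse computes a real quorum across replica groups.
    return sorted(replicas) if len(replicas) >= min_replicas else None

for step in range(4):
    members = quorum()
    if members is None:
        continue  # no quorum: skip this step rather than stop the world
    print(f"step {step}: all-reduce across {members}")
    if step == 1:
        replicas.discard("group2")  # a failure only affects later steps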

Installation

$ pip install .

This uses pyo3+maturin to build the package; you'll need maturin installed.
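If maturin isn't already available, it can be installed from PyPI:

$ pip install maturin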

To install in editable mode with the Rust extensions, use the normal pip install command:

$ pip install -e .

Lighthouse

The lighthouse is used for fault tolerance across replicated workers (DDP/FSDP) when using synchronous training.

You can start a lighthouse server by running:

$ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 1000
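Based on the flag names: --min_replicas is the minimum number of replica groups required to form a quorum, --quorum_tick_ms controls how often the lighthouse re-evaluates the quorum, and --join_timeout_ms bounds how long a joining replica group waits for quorum. Run torchft_lighthouse --help for the authoritative descriptions.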

Example Training Loop (DDP)

See train_ddp.py for the full example.

Invoke with:

$ TORCHFT_MANAGER_PORT=29512 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 train.py
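Here TORCHFT_MANAGER_PORT sets the port the replica group's manager listens on, and TORCHFT_LIGHTHOUSE tells the workers where to find the lighthouse started in the previous section.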

train.py:

import torch
from torch import nn, optim

from torchft import Manager, DistributedDataParallel, Optimizer, ProcessGroupGloo

device = "cpu"  # or "cuda" when training on GPUs

# load_state_dict/state_dict are callbacks the manager uses to transfer
# state to recovering replicas; see train_ddp.py for full implementations.
manager = Manager(
    pg=ProcessGroupGloo(),
    load_state_dict=...,
    state_dict=...,
)

m = nn.Linear(2, 3)

# Fault tolerant wrappers around the standard DDP module and optimizer.
m = DistributedDataParallel(manager, m)
optimizer = Optimizer(manager, optim.AdamW(m.parameters()))

for i in range(1000):
    batch = torch.rand(2, 2, device=device)

    optimizer.zero_grad()

    out = m(batch)
    loss = out.sum()

    loss.backward()

    optimizer.step()
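The load_state_dict and state_dict arguments are left elided above; they are the callbacks the manager uses to save and restore training state for recovering replicas. One plausible shape for this model and optimizer (a sketch; see train_ddp.py for the actual version):

def state_dict():
    # Everything a recovering replica needs to rejoin at the current step.
    return {"model": m.state_dict(), "optim": optimizer.state_dict()}

def load_state_dict(sd):
    # Restore state received from a healthy replica.
    m.load_state_dict(sd["model"])
    optimizer.load_state_dict(sd["optim"])

These are ordinary closures over m and optimizer, so they can be defined before the training loop and are only invoked when a replica needs to catch up.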

Example Parameter Server

torchft has a fault tolerant parameter server implementation built on its reconfigurable ProcessGroups. This does not require or use a Lighthouse server.

See parameter_server_test.py for an example.

Running Tests / Lint

$ cargo fmt
$ cargo test

License

BSD 3-Clause -- see LICENSE for more details.

Copyright (c) Tristan Rice 2024
