Add Llama2 7B FSDP demo #165

Status: Open. Wants to merge 76 commits into main.
Conversation

abhibyreddi (Collaborator) commented Oct 24, 2024:

The demo code performs data loading, listing, and checkpoint saving with Dataflux. A follow-up PR will add support for loading the saved checkpoints with Dataflux.

  • Tests pass - tested manually on a GCE VM with GPUs
  • Appropriate changes to documentation are included in the PR
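
For orientation, here is a minimal sketch of the listing/loading path described above, written against the plain google-cloud-storage client rather than Dataflux's optimized APIs; the bucket name, prefix, and uint16 token dtype are illustrative assumptions, not taken from this PR.

```python
import numpy as np
import torch
from google.cloud import storage

client = storage.Client()
bucket_name = "my-training-bucket"  # placeholder, not the demo's bucket
bucket = client.bucket(bucket_name)

# Listing: enumerate tokenized shards under a prefix (placeholder prefix).
shard_names = [
    blob.name
    for blob in client.list_blobs(bucket_name, prefix="redpajama-sample/")
    if blob.name.endswith(".bin")
]

# Loading: read one shard's bytes straight from GCS (no local disk) and view them
# as token ids; uint16 is an assumption about how the shards were packed.
raw = bucket.blob(shard_names[0]).download_as_bytes()
tokens = torch.from_numpy(np.frombuffer(raw, dtype=np.uint16).astype(np.int64))
print(tokens.shape)
```

The demo routes the same idea through Dataflux so that listing and downloads are accelerated; the snippet only shows where GCS replaces the local filesystem.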

@abhibyreddi abhibyreddi marked this pull request as ready for review November 2, 2024 01:05
@abhibyreddi abhibyreddi requested a review from a team as a code owner November 2, 2024 01:05
@abhibyreddi abhibyreddi requested review from Yash9060, jdnurme and MattIrv and removed request for Yash9060 November 2, 2024 01:05
abhibyreddi (Collaborator, author):

Matt & JD, all the files introduced in this PR are borrowed from here. I've called out the changes I made in the README file.

I tested this by training the 7B model on the sample dataset (~24GB) for 500 iterations. I also added the output I got when I prompted the resulting model (it's very bad, but it works!).

demo/llama/dataset.py (resolved review thread)
The code in this directory trains the Llama 7B model on [Huggingface's RedPajama dataset](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample/tree/main). All the code in this directory is borrowed from [Lightning AI's lit-llama repo](https://github.com/Lightning-AI/lit-llama). Changes have been made in appropriate places to fetch the data from GCS instead of reading from disk.

This demo has been tested on a GCE instance with `2` Nvidia `H100` GPUs.

Collaborator:

Let's scope the demo to just FSDP checkpointing, if possible.

Collaborator (author):

Made changes to support saving checkpoints with Dataflux. Updated README with details. Could you take another look?

Collaborator:

The README doesn't appear to cover checkpointing at all. Can you update it?

I think the training portion is not that interesting for README purposes, since it doesn't involve Dataflux at all (does it?)

Collaborator (author):

I added a note in the "Run the pre-training script" section about implementing a custom strategy that makes it possible to save checkpoints using Dataflux. What other details would be good to mention here?
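
For context, here is a minimal sketch of the shape such a custom strategy can take. The class, argument names, and writer wiring are illustrative assumptions rather than the PR's implementation; the sketch assumes the save path goes through torch.distributed.checkpoint's save/async_save API (recent PyTorch releases), with a GCS-backed storage writer plugged in where bytes should be redirected to GCS.

```python
import time

import torch.distributed.checkpoint as dcp
from lightning.fabric.strategies import FSDPStrategy


class GCSCheckpointFSDPStrategy(FSDPStrategy):
    """Illustrative sketch: route checkpoint saves through torch.distributed.checkpoint,
    where the storage writer (a GCS-backed writer in the real demo) decides where bytes go."""

    def __init__(self, storage_writer, use_async: bool = False, **kwargs):
        super().__init__(**kwargs)
        self._writer = storage_writer
        self._use_async = use_async

    def save_checkpoint(self, path, state, storage_options=None, filter=None):
        start_time = time.time()
        # Turn Fabric's {name: module/optimizer/other} mapping into plain state dicts.
        # A real implementation also has to handle FSDP-sharded parameters and optimizer state.
        converted_state = {
            key: obj.state_dict() if hasattr(obj, "state_dict") else obj
            for key, obj in state.items()
        }
        if self._use_async:
            dcp.async_save(converted_state, checkpoint_id=str(path), storage_writer=self._writer)
        else:
            dcp.save(converted_state, checkpoint_id=str(path), storage_writer=self._writer)
        print(f"save_checkpoint took {time.time() - start_time:.2f}s")
```

Training code would then pass an instance of such a strategy to lightning.Fabric, just as with the built-in strategies.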

jdnurme and others added 9 commits November 8, 2024 19:27
* add simple llama load benchmark and results

* address comments

* comments
* Update save/load print statements

* Account for save/load only
* Remove model and path arguments from FSDP strategy constructors

* Fix incorrect argument order in gcs writer/reader
* Added initial naive Async FSDP strategy

* refactor script to support multiple trainer.fit executions for average execution time.

* Add logging statement to async_save. Add return docstring to updated init_process() method.

* Added a section to the README with details on how to run the async demo.

* - Refactor DatafluxFSDPStrategy to use arg for async behavior instead of creating a child class.
- Refactor the save_checkpoint_helper to just modify the checkpoint dict. This allows easier control over save/async_save in DatafluxFSDPStrategy.
- Add more benchmark logging.
- Use a custom model class for adding simulated blocking workflow.

* fix typo

* get rank from env var after it's set instead of returning. Fix var naming.

* remove return type hint.

* Updated README to reflect review feedback

* Improve readme docs.

* more doc cleanup and logging improvements.

* Use env var for accessing rank. Update logging strings.

* Fix docstring accuracy.

* further docstring fix

* remove irrelevant optimizer choice
* updated readme to include async and multinode features

* Address comments
* Update multi-node benchmarking readme

* Add commands for all the strategies currently supported

* Update numbering

* Add more info

* Fix typo

* Reword

* Fix typos

* Address comments

* Resolve merge conflicts

* Fix note

* Fix note

* Fix note

* Fix note

* Fix note

* Remove note

* Fix type

* Add a note about gcsfuse deployment

* Fix headings

* Final commit
@abhibyreddi changed the title from "Add FSDP demo" to "Add Llama2 7B FSDP demo" on Nov 13, 2024
abhibyreddi (Collaborator, author):

@MattIrv, @jdnurme, @Yash9060: I want to call your attention to demo/llama/strategies.py. The DatafluxFSDPStrategy class there is similar to the class with the same name in demo/lightning/checkpoint/multinode/strategies.py.

The one introduced in this PR inherits from lightning.fabric.strategies.FSDPStrategy, while the existing one inherits from lightning.pytorch.strategies.FSDPStrategy. They are similar but not the same.
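
For quick orientation, the two base classes live at different import paths in Lightning 2.x, which is why the two strategies cannot simply be shared as-is (the aliases below are only for illustration):

```python
# Strategy added in this PR: driven by lightning.Fabric (lit-llama's training loop).
from lightning.fabric.strategies import FSDPStrategy as FabricFSDPStrategy

# Existing strategy in demo/lightning/checkpoint/multinode/strategies.py: driven by lightning.Trainer.
from lightning.pytorch.strategies import FSDPStrategy as TrainerFSDPStrategy
```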

demo/llama/README.md (outdated review thread, resolved)
demo/llama/dataset.py (outdated review thread, resolved)
    self._async_save(converted_state, path, writer)
else:
    self._save(converted_state, path, writer)
duration_ms = (time.time() - start_time) * 1000
Collaborator:

Any specific reason why we need to have the time in milliseconds? I think everywhere else in the codebase we use seconds.

Collaborator (author):

This has been copied over from demo/lightning/checkpoint/multinode/strategies.py.

        self.checkpoint_group = dist.new_group(
            default_ranks, backend=self.process_group_backend)

    def save_checkpoint(
Collaborator:

Add a link to the save_checkpoint source code? (Also, I think this is coming from Lightning Fabric; any reason why we are using Lightning Fabric's save_checkpoint instead of the plain Lightning checkpoint?)

Collaborator (author):

lit-llama's training code does not use lightning.Trainer, and Fabric only accepts custom strategies that inherit from one of the classes defined in lightning.fabric.strategies. The other option I had was to rewrite the training code to use Lightning's LightningModule; this seemed simpler.
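
As a concrete illustration of that constraint, a Fabric-driven setup takes the strategy instance directly, so the custom class only has to satisfy the lightning.fabric.strategies interface. The class below is a placeholder standing in for the PR's demo/llama/strategies.py strategy, not its actual code.

```python
import lightning as L
from lightning.fabric.strategies import FSDPStrategy


class DatafluxFSDPStrategy(FSDPStrategy):
    """Placeholder for the strategy defined in demo/llama/strategies.py (see the sketch above)."""


fabric = L.Fabric(accelerator="cuda", devices=2, strategy=DatafluxFSDPStrategy())
fabric.launch()
# lit-llama's loop then runs under Fabric, e.g.:
#   model, optimizer = fabric.setup(model, optimizer)
#   train_dataloader = fabric.setup_dataloaders(train_dataloader)
```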

Collaborator:

This appears to mostly be a copy of https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/packed_dataset.py. Is there some way to only override the pieces we need from that code without reimplementing everything?

This might apply to strategies.py and train.py here too.
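
One way to act on this suggestion is to keep the upstream class and override only the I/O hook, as sketched below. UpstreamPackedDataset is a self-contained stand-in for lit-llama's class, and the _read_bytes hook name is hypothetical; the real override point would be wherever the upstream iterator opens a shard file.

```python
from google.cloud import storage


class UpstreamPackedDataset:
    """Stand-in for lit-llama's packed dataset: everything except the I/O hook stays upstream."""

    def __init__(self, filenames):
        self.filenames = filenames

    def _read_bytes(self, filename: str) -> bytes:
        # Upstream behavior: read a shard from local disk.
        with open(filename, "rb") as f:
            return f.read()


class GCSPackedDataset(UpstreamPackedDataset):
    """Overrides only the I/O hook so shards are fetched from GCS instead of local disk."""

    def __init__(self, blob_names, bucket_name: str):
        super().__init__(blob_names)
        self._bucket = storage.Client().bucket(bucket_name)

    def _read_bytes(self, blob_name: str) -> bytes:
        return self._bucket.blob(blob_name).download_as_bytes()
```

Whether this is practical depends on how cleanly the upstream code separates file reading from chunk decoding.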

Collaborator:

Is there a way to avoid adding yet another reimplementation of these strategies and reuse the existing ones?

    gradient_accumulation_iters, devices)


def train(
Collaborator:

Can the underlying implementation be reused here? It looks like there might not be any changes.



@torch.no_grad()
def validate(fabric: L.Fabric, model: torch.nn.Module,
Collaborator:

Same comment here and for several of the functions below: does this really need to be reimplemented?
