Add Llama2 7B FSDP demo #165
base: main
Conversation
Matt & JD, all the files introduced in this PR are borrowed from here. I've called out the changes I made in the README file. I tested this by training the 7B model on the sample dataset (~
The code in this directory trains the Llama 7B model on [Huggingface's RedPajama dataset](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample/tree/main). All the code in this directory is borrowed from [Lightning AI's lit-llama repo](https://github.com/Lightning-AI/lit-llama). Changes have been made in appropriate places to fetch the data from GCS instead of reading from disk.

This demo has been tested on a GCE instance with `2` Nvidia `H100` GPUs.
Let's scope the demo to just FSDP checkpointing, if possible.
Made changes to support saving checkpoints with Dataflux. Updated README with details. Could you take another look?
The README doesn't appear to cover checkpointing at all. Can you update it?
I think the training portion is not that interesting for README purposes, since it doesn't involve Dataflux at all (does it?)
I added a note in the "Run the pre-training script" section about implementing a custom strategy that makes it possible to save checkpoints using Dataflux. What other details would be good to mention here?
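For context, here is a minimal sketch of what such a custom strategy could look like: a Fabric `FSDPStrategy` subclass whose `save_checkpoint` hands the sharded save off to `torch.distributed.checkpoint` with a GCS-backed writer. The `GCSDistributedWriter` import path and constructor, the `save_checkpoint` signature, and the constructor arguments are assumptions and may not match this PR (`dcp.save` also requires a recent PyTorch; older releases expose `save_state_dict` instead).

```python
import time

import torch.distributed.checkpoint as dcp
from lightning.fabric.strategies import FSDPStrategy

# Assumed import path for the GCS-backed writer used by this demo.
from dataflux_pytorch.lightning.gcs_filesystem import GCSDistributedWriter


class DatafluxFSDPStrategy(FSDPStrategy):
    def __init__(self, project_name, bucket_name, use_async=False, **kwargs):
        super().__init__(**kwargs)
        self.project_name = project_name
        self.bucket_name = bucket_name
        self.use_async = use_async

    def save_checkpoint(self, path, state, storage_options=None, filter=None):
        # Turn the Fabric `state` mapping into plain state dicts, then write the
        # sharded checkpoint through a writer that targets GCS instead of local disk.
        converted_state = {
            k: v.state_dict() if hasattr(v, "state_dict") else v
            for k, v in state.items()
        }
        writer = GCSDistributedWriter(path, self.project_name, self.bucket_name)  # assumed signature
        start_time = time.time()
        dcp.save(converted_state, storage_writer=writer)
        print(f"Checkpoint save took {time.time() - start_time:.2f} seconds")
```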
* add simple llama load benchmark and results
* address comments
* comments
* Update save/load print statements
* Account for save/load only
* Remove model and path arguments from FSDP strategy constructors
* Fix incorrect argument order in gcs writer/reader
* Added initial naive Async FSDP strategy
* refactor script to support multiple trainer.fit executions for average execution time.
* Add logging statement to async_save. Add return docstring to updated init_process() method.
* Added a section to the README with details on how to run the async demo.
* Refactor DatafluxFSDPStrategy to use arg for async behavior instead of creating a child class.
* Refactor the save_checkpoint_helper to just modify the checkpoint dict. This allows easier control over save/async_save in DatafluxFSDPStrategy.
* Add more benchmark logging.
* Use a custom model class for adding simulated blocking workflow.
* fix typo
* get rank from env var after it's set instead of returning. Fix var naming.
* remove return type hint.
* Updated README to reflect review feedback
* Improve readme docs.
* more doc cleanup and logging improvements.
* Use env var for accessing rank. Update logging strings.
* Fix docstring accuracy.
* further docstring fix
* remove irrelevant optimizer choice
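For readers skimming these notes, a minimal sketch of the "flag instead of a child class" idea is below, assuming the actual write is delegated to `torch.distributed.checkpoint`. The class and method here are illustrative, not the PR's real `DatafluxFSDPStrategy` internals, and a production async path should track the in-flight save rather than fire and forget.

```python
import threading

import torch.distributed.checkpoint as dcp


class CheckpointSaver:
    """Illustrative only: one class whose behavior is toggled by a use_async flag."""

    def __init__(self, use_async: bool = False):
        self.use_async = use_async
        self._inflight = None  # handle to the most recent background save, if any

    def save(self, state_dict, writer):
        if self.use_async:
            # Naive async save: run the blocking save on a background thread so the
            # training loop keeps going; join/poll self._inflight before the next
            # save or before process exit.
            self._inflight = threading.Thread(
                target=dcp.save, args=(state_dict,), kwargs={"storage_writer": writer}
            )
            self._inflight.start()
        else:
            dcp.save(state_dict, storage_writer=writer)
```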
* updated readme to include async and multinode features
* Address comments
* Update multi-node benchmarking readme
* Add commands for all the strategies currently supported
* Update numbering
* Add more info
* Fix typo
* Reword
* Fix typos
* Address comments
* Resolve merge conflicts
* Fix note
* Fix note
* Fix note
* Fix note
* Fix note
* Remove note
* Fix type
* Add a note about gcsfuse deployment
* Fix headings
* Final commit
@MattIrv, @jdnurme, @Yash9060, want to call your attention to this. The one introduced in this PR inherits
    self._async_save(converted_state, path, writer)
else:
    self._save(converted_state, path, writer)
duration_ms = (time.time() - start_time) / 1000
Any specific reason why we need the time in milliseconds? I think everywhere else in the codebase we use seconds?
This has been copied over from demo/lightning/checkpoint/multinode/strategies.py.
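One note on the arithmetic while we're here: `time.time()` already returns seconds, so dividing the elapsed value by 1000 does not yield milliseconds (multiplying would). Below is a small sketch of the same branch reporting seconds, to match the rest of the codebase; the surrounding names are taken from the diff above, the `use_async` flag name comes from the commit notes, and the helper itself is purely illustrative.

```python
import time


class _TimingMixin:
    """Illustrative mixin mirroring the diff above; reports durations in seconds."""

    def _save_with_timing(self, converted_state, path, writer):
        start_time = time.time()
        if self.use_async:  # flag name assumed from the commit notes
            self._async_save(converted_state, path, writer)
        else:
            self._save(converted_state, path, writer)
        # time.time() is in seconds, so the difference is already in seconds;
        # multiply by 1000 (rather than dividing) if milliseconds are wanted.
        duration_s = time.time() - start_time
        print(f"Checkpoint save took {duration_s:.2f} seconds")
```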
        self.checkpoint_group = dist.new_group(
            default_ranks, backend=self.process_group_backend)

    def save_checkpoint(
Add a link to the save_checkpoint source code? (Also, I think this is coming from Lightning Fabric; any reason why we are using Lightning Fabric's save_checkpoint instead of the simple Lightning checkpoint?)
lit-llama's training code does not use `lightning.Trainer`, and `fabric` takes only custom strategies that inherit one of the classes defined in `lightning.fabric.strategies`. Another option I had was to re-write the training code to use Lightning's `LightningModule`. This seemed simpler to do.
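For anyone less familiar with the Fabric flow, this is roughly how lit-llama-style training code hands a custom strategy to Fabric. The import and constructor arguments for `DatafluxFSDPStrategy` are assumptions for illustration; see the strategy definition in this PR for the real interface.

```python
import lightning as L

from strategies import DatafluxFSDPStrategy  # assumed module path within this demo

strategy = DatafluxFSDPStrategy(
    project_name="my-gcp-project",       # illustrative arguments
    bucket_name="my-checkpoint-bucket",
)
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
# model, optimizer = fabric.setup(model, optimizer), then the usual lit-llama train
# loop; fabric.save(...) routes through the strategy's save_checkpoint at intervals.
```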
This appears to mostly be a copy of https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/packed_dataset.py. Is there some way to only override the pieces we need from that code without reimplementing everything?
This might apply to strategies.py and train.py here too.
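A rough illustration of that suggestion: subclass the upstream dataset and override only the chunk I/O so data comes from GCS. The `_open_chunk` hook is hypothetical; the real override point depends on the internals of lit_llama/packed_dataset.py.

```python
import io

from google.cloud import storage
from lit_llama.packed_dataset import PackedDataset


class GCSPackedDataset(PackedDataset):
    """Reuses the upstream packed-dataset logic, swapping only the chunk I/O."""

    def __init__(self, bucket_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._bucket = storage.Client().bucket(bucket_name)

    def _open_chunk(self, filename):
        # Hypothetical hook: fetch the chunk from GCS into memory rather than
        # memory-mapping a local file.
        blob = self._bucket.blob(filename)
        return io.BytesIO(blob.download_as_bytes())
```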
Is there a way to avoid adding yet another reimplementation of these strategies and reuse the existing ones?
                gradient_accumulation_iters, devices)

def train( |
Can the underlying implementation be reused here? It looks like there might not be any changes.

@torch.no_grad()
def validate(fabric: L.Fabric, model: torch.nn.Module, |
Same comment here and for several of the functions below: does this really need to be reimplemented?
Demo code performs data loading, listing, and checkpoint saving with Dataflux. A follow-up PR will add support for loading the saved checkpoints with Dataflux.
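As a rough usage sketch of the data-loading side, training objects can be listed and read straight from a bucket; the class and argument names below are approximated from the dataflux-pytorch README and may differ from this repo's actual API.

```python
from torch.utils.data import DataLoader

from dataflux_pytorch import dataflux_mapstyle_dataset  # names approximated, may differ

# List and read training objects directly from GCS.
dataset = dataflux_mapstyle_dataset.DataFluxMapStyleDataset(
    project_name="my-gcp-project",
    bucket_name="my-training-data-bucket",
    config=dataflux_mapstyle_dataset.Config(prefix="redpajama/"),
)
loader = DataLoader(dataset, batch_size=8, num_workers=4)
for batch in loader:
    ...  # each element is raw object bytes; decode as the training code expects
```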