Add an example that trains the torchtitan version of llama. #8400
tpu_args = "--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
os.environ.setdefault('LIBTPU_INIT_ARGS', tpu_args)

_setup_default_env()
Did you test that changing the environ here actually gets picked up by XLA? For example, I was wondering whether `import jax` or some other import will cause the TPU backend to get initialized and ignore future changes to XLA flags.
hmmm... moving to the top to be safe.
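For reference, a minimal sketch of the "set it before anything else" approach discussed above. The helper name `setup_default_env`, its parameters, and the `sys.modules` check are illustrative assumptions, not code from this PR. The check is deliberately conservative: an imported module does not necessarily mean the TPU backend is already initialized, so it only warns.

```python
import os
import sys
import warnings


# Hypothetical helper for illustration; not the PR's actual implementation.
def setup_default_env(flags, backend_modules=("jax", "torch_xla")):
    """Set LIBTPU_INIT_ARGS to a default unless the user already set it.

    Returns True if the default was applied, False if an existing value
    was left untouched.
    """
    # Assumption: if one of these modules is already imported, the backend
    # may have read its flags already, so a later change could be ignored.
    already_loaded = [m for m in backend_modules if m in sys.modules]
    if already_loaded:
        warnings.warn(
            "LIBTPU_INIT_ARGS is being set after importing "
            + ", ".join(already_loaded)
            + "; the flags may not take effect."
        )

    applied = "LIBTPU_INIT_ARGS" not in os.environ
    # setdefault keeps any value the user exported before launching.
    os.environ.setdefault("LIBTPU_INIT_ARGS", " ".join(flags))
    return applied


# Call at the very top of the script, before importing jax or torch_xla.
setup_default_env(["--xla_enable_async_all_gather=true"])
```

Using `setdefault` means a value exported in the shell before launch still wins over the script's defaults.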
A few bugs were fixed along the way: