
Add an example that trains the torchtitan version of llama. #8400

Merged: 6 commits from hanq_hybrid_mesh into master on Nov 23, 2024

Conversation

qihqi (Collaborator) commented on Nov 20, 2024:

A few bugs were fixed along the way:

  • The silu.default lowering now goes through the direct lowering path.
  • The BlockSpec computation was moved inside flash_attention, because the query length can change when shard_map is applied (see the sketch after this list).

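To illustrate the second fix, here is a minimal Pallas-style sketch, not the PR's actual kernel: the kernel below is a plain single-pass attention rather than real streaming flash attention, and it assumes the newer pl.BlockSpec(block_shape, index_map) API. The point is that the grid and BlockSpecs are computed inside flash_attention from the shapes the call actually sees, which under shard_map are the per-shard shapes:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def _attn_kernel(q_ref, k_ref, v_ref, o_ref):
    # Plain attention over one query block; a real flash_attention kernel
    # instead streams over k/v blocks with an online softmax.
    q, k, v = q_ref[...], k_ref[...], v_ref[...]
    s = jnp.dot(q, k.T) / jnp.sqrt(jnp.asarray(q.shape[-1], q.dtype))
    o_ref[...] = jnp.dot(jax.nn.softmax(s, axis=-1), v)

def flash_attention(q, k, v, block_q=128):
    # Compute the grid and BlockSpecs *inside* the function, from the shapes
    # this call actually sees. Under shard_map, q.shape[0] is the per-shard
    # query length, which can differ from the global length.
    q_len, head_dim = q.shape
    block_q = min(block_q, q_len)
    grid = (q_len // block_q,)
    return pl.pallas_call(
        _attn_kernel,
        grid=grid,
        in_specs=[
            pl.BlockSpec((block_q, head_dim), lambda i: (i, 0)),  # one q block
            pl.BlockSpec(k.shape, lambda i: (0, 0)),              # full k
            pl.BlockSpec(v.shape, lambda i: (0, 0)),              # full v
        ],
        out_specs=pl.BlockSpec((block_q, head_dim), lambda i: (i, 0)),
        out_shape=jax.ShapeDtypeStruct((q_len, head_dim), q.dtype),
    )(q, k, v)

Under shard_map, each program instance receives only its shard of q, so q.shape[0] above is the per-shard query length; precomputing the BlockSpecs from the global shape outside the function would bake in the wrong query length.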
@qihqi qihqi changed the title from "Add hybrid mesh" to "Add an example that trains the torchtitan version of llama." Nov 22, 2024
@qihqi qihqi marked this pull request as ready for review November 22, 2024 21:41
@qihqi qihqi requested review from tengyifei and JackCaoG November 22, 2024 21:41
# Default libtpu/XLA flags enabling async collective fusion and
# compute/collective overlap on TPU; setdefault preserves any value the
# user has already exported.
tpu_args = "--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
os.environ.setdefault('LIBTPU_INIT_ARGS', tpu_args)

_setup_default_env()
A reviewer (Collaborator) commented:

Did you test that changing the environ here actually gets picked up by XLA? For example, I was wondering whether import jax or some other import would cause the TPU backend to be initialized and then ignore future changes to XLA flags.

qihqi (Collaborator, author) replied:

hmmm... moving to the top to be safe.
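For context, a minimal sketch of the resulting "flags first" layout (the abbreviated flag string and import list below are illustrative, not the PR's exact file). LIBTPU_INIT_ARGS is read when the TPU backend initializes, so it has to be in the environment before any import that can trigger that initialization:

import os

# Set libtpu/XLA flags before importing anything that might initialize the
# TPU backend; once the backend is up, changes to LIBTPU_INIT_ARGS are ignored.
os.environ.setdefault(
    'LIBTPU_INIT_ARGS',
    '--xla_tpu_enable_async_collective_fusion=true '
    '--xla_tpu_overlap_compute_collective_tc=true',  # abbreviated flag list
)

import jax  # noqa: E402 -- deliberately imported only after the env var is set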

@qihqi qihqi requested a review from tengyifei November 22, 2024 22:26
@qihqi qihqi merged commit 31d348e into master Nov 23, 2024
3 checks passed
@qihqi qihqi deleted the hanq_hybrid_mesh branch November 23, 2024 01:04