Skip to content

Conversation

@winglian
Copy link
Contributor

@winglian winglian commented Dec 9, 2024

What does this PR do?

Deepspeed 0.16 has assertions preventing the use of no_sync with zero 2/3. see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/engine.py#L1986-L2004

it seems people are reporting this here deepspeedai/DeepSpeed#6793, and I'm assuming that everyone is using accelerate/transformers as downgrading to deepspeed 0.15.4 makes it "work" for them.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@winglian
Copy link
Contributor Author

winglian commented Dec 9, 2024

might have to broaden deepspeed for all zero cases? https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/engine.py#L2208-L2209

Copy link
Contributor

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we should disable it for all DS (apparently), let's just go ahead and do that. I'll apply a similar fix in Accelerator.

cc @SunMarc

context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
if i != len(batch_samples) - 1 and not self.accelerator.distributed_type == DistributedType.DEEPSPEED

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

may i ask which version of transformers support this fix-up? mine is 4.46.0. same problem with deepspeed 0.16

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with zach suggestion !

Copy link
Contributor

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! We've also confirmed this fixes up all the fails users reported wrt deepspeed. cc @ArthurZucker for final post wing doing quality ;)

@SunMarc SunMarc requested a review from ArthurZucker December 13, 2024 13:17
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry for breaking this ... and thanks for the fix!

@ArthurZucker
Copy link
Collaborator

can you just run make fixup

@ArthurZucker ArthurZucker merged commit add53e2 into huggingface:main Dec 13, 2024
1 of 5 checks passed
ArthurZucker pushed a commit that referenced this pull request Dec 13, 2024
…stages (#35157)

* don't use no_sync when deepspeed doesn't support it for certain zero stages

* chore: lint

* fix no_sync context for deepspeed across all zero types

* chore: lint
inkcherry pushed a commit to inkcherry/transformers that referenced this pull request Jan 15, 2025
…stages (huggingface#35157)

* don't use no_sync when deepspeed doesn't support it for certain zero stages

* chore: lint

* fix no_sync context for deepspeed across all zero types

* chore: lint
@AetherPrior
Copy link

Strangely, this issue still exists on deepspeed==0.16.2, has this fix been pushed to a stable release yet?

@ArthurZucker
Copy link
Collaborator

This is on 4.48 !

@fangpings
Copy link

With transformers==4.48.0, accelerate==1.2.1 and deepspeed==0.16.3, still see this issue

@jianguoz
Copy link

jianguoz commented Mar 4, 2025

Hi @ArthurZucker @muellerzr , we still face the same issue with transformers==4.48.0/4.49.0, accelerate==1.2.1 and deepspeed==0.16.3. Could you check this?

error: no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3

@SunMarc
Copy link
Member

SunMarc commented Mar 4, 2025

Can you share the full traceback @jianguoz and with a minimal reproducer @jianguoz thanks !

@jianguoz
Copy link

jianguoz commented Mar 4, 2025

Hi @SunMarc @ArthurZucker @muellerzr , Below is the my zero 3 config and error output when fine-tuning a mistral_small_24b or llama model on 8 GPUs. It can only work with gradient_accumulation=1. I believe people will face same issue when training a model using the latest transformer>=4.48.0 and deepspeed>=0.16.0 and accelerate==1.2.1/1.4.0.

[2025-03-04 19:03:35,551] [INFO] [config.py:991:print_user_config]   json = {
    "bf16": {
        "enabled": true
    }, 
    "optimizer": {
        "type": "AdamW", 
        "params": {
            "lr": 2e-05, 
            "weight_decay": 0
        }
    }, 
    "scheduler": {
        "type": "WarmupDecayLR", 
        "params": {
            "warmup_min_lr": 1e-06, 
            "warmup_max_lr": 2e-05, 
            "warmup_num_steps": 100, 
            "total_num_steps": 1.259000e+03
        }
    }, 
    "zero_optimization": {
        "stage": 3, 
        "overlap_comm": true, 
        "contiguous_gradients": true, 
        "sub_group_size": 1.000000e+09, 
        "reduce_bucket_size": 2.621440e+07, 
        "stage3_prefetch_bucket_size": 0, 
        "stage3_param_persistence_threshold": 5.120000e+04, 
        "stage3_max_live_parameters": 1.000000e+09, 
        "stage3_max_reuse_distance": 1.000000e+09, 
        "stage3_gather_16bit_weights_on_model_save": true
    }, 
    "gradient_accumulation_steps": 3, 
    "gradient_clipping": 1.0, 
    "steps_per_print": inf, 
    "train_batch_size": 48, 
    "train_micro_batch_size_per_gpu": 2, 
    "wall_clock_breakdown": false, 
    "fp16": {
        "enabled": false
    }
}

output

step = 0 -- rank 0: error: no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3
  0%|                                                                                                                                     | 1/10072 [01:15<210:10:00, 75.13s/it]step = 1 -- rank 0: error: step = 0 -- rank 1: error: no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3
no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3
  0%|                                                                                                                                     | 1/10072 [01:15<210:30:16, 75.25s/it]step = 0 -- rank 2: error:step = 0 -- rank 5: error:  step = 0 -- rank 3: error:no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3step = 0 -- rank 6: error: 
step = 0 -- rank 7: error:
no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3  
no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3

step = 0 -- rank 4: error: no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3
step = 1 -- rank 1: error: no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3
  0%|                                                                                                                                     | 1/10072 [01:15<210:23:32, 75.21s/it]step = 1 -- rank 2: error: step = 1 -- rank 5: error:no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3 
no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3
step = 1 -- rank 4: error: no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3
step = 1 -- rank 7: error:step = 1 -- rank 6: error:  no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage 3

@jianguoz
Copy link

jianguoz commented Mar 6, 2025

Hi @SunMarc , do we have updates for this issue? I saw others have same issues on #34984

@SunMarc
Copy link
Member

SunMarc commented Mar 13, 2025

cc @XuehaiPan

@jianguoz
Copy link

Hi @SunMarc , any update for this incompatible issues between Huggingface and Deepspeed?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

10 participants