-
Notifications
You must be signed in to change notification settings - Fork 27.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ESM
] Add support for sdpa.
#34954
base: main
Are you sure you want to change the base?
[ESM
] Add support for sdpa.
#34954
Conversation
Thank you for this! Rather than skipping the SDPA test, though, can you write even a simple test that uses the SDPA path? It's okay if it can't compare hidden states deeply because of issues in the ESM model, but if it could compare that output logits are similar that'd give us a lot more confidence in the SDPA code! |
Thanks for your reply, I will add relevant test cases soon. |
8f7773d
to
996880a
Compare
@Rocketknight1 Hello, the sdpa inference tests for ESMFold has been added. Could you please review it? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall this looks like a good SDPA addition to me! I'll also set up slow tests in a sec.
Hi @wzf03, I ran the full test suite for ESM and I'm seeing one or two test failures. Can you see if you can reproduce those locally? They may just be flaky tests, but it might also be caused by changes in this PR. |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Hello @Rocketknight1, I found the test failures were due to the device mismatching of the I will report this to |
Hello @Rocketknight1 , I made a quick fix according to other model's test, the test cases should work normally now. |
Yes, looks good to me now! cc @ArthurZucker @LysandreJik for core maintainer review |
@ArthurZucker @LysandreJik Hello! Can you please help review this pr? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey super sorry for the delay, waited a bit because #35235 changes the interface! Do you mind updating this PR ? Hope it's not too much of a burden! 😿
Sure, I will do it soon. |
What does this PR do?
Add support for SDPA (scaled dot product attention) for ESM. More context in #28802 (And this pr mainly reused the code from this pr as the ESM is Bert-based model) and #28005 .
This is my first time contributing to this project, please point out if there is any mistakes.
And revert a change in #29329 as the dtype-mismatching issue for bitsandbytes is actually caused by the rotary embedding.
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker