enable StaticCache for assisted generation #34797
base: main
Conversation
@gante , could you please take a look? Thanks.
@yao-matrix hey, gante is currently on a long vacation, so I reviewed the PR for him. Thanks for adding support for this, super cool work!
I left a few comments, and we'll also need tests in the tests/generation/test_utils.py file. I take it static cache now works with all types of candidate generators, right?
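For illustration, here is a rough standalone sketch of the kind of check such a test could make. The Qwen2.5 checkpoints, variable names, and assertion are my own assumptions, not the test that was eventually added to test_utils.py; any StaticCache-capable model plus a smaller assistant sharing its tokenizer would work the same way.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoints; a real test in test_utils.py would use the
# library's tiny random test models instead.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B")
assistant = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

inputs = tokenizer("Assisted generation with a static cache", return_tensors="pt")

# Greedy assisted generation with the default (dynamic) cache ...
dynamic_out = model.generate(
    **inputs, assistant_model=assistant, max_new_tokens=10, do_sample=False
)
# ... and with the StaticCache path this PR enables.
static_out = model.generate(
    **inputs,
    assistant_model=assistant,
    max_new_tokens=10,
    do_sample=False,
    cache_implementation="static",
)

# The cache implementation must not change greedy results.
assert torch.equal(dynamic_out, static_out)
```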
src/transformers/generation/utils.py (outdated)
if assistant_model is not None:
    assistant_model._get_cache(
        cache_implementation=generation_config.cache_implementation,
        batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
        max_cache_len=max_cache_length,
        device=device,
        model_kwargs=model_kwargs,
    )
Hmm, I think it will be called on the assistant model when we call assistant.generate(), so there is no need for this here. We can simply remove self.generation_config.cache_implementation = None in the candidate generator.
The thing is: if we leave it to assistant_model.generate (called in get_candidates) to set up the cache, then on the first call max_new_tokens is set to max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1), so the cache length would be int(self.num_assistant_tokens) + prompt_len, which is smaller than the actually required cache length max_token_length + prompt_length, and generation then asserts out. So the key point is that the assistant model's cache length must match the main model's. I also noticed that this function takes assistant_model as an argument without using it; I think it may be there exactly for cases like this. That's the rationale.
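To make the mismatch concrete, here is a tiny arithmetic sketch of the scenario described above; the numbers are made up, and the variable names mirror but are not the actual code.

```python
# Hypothetical numbers illustrating the cache-length mismatch described above.
prompt_len = 32            # tokens already in the prompt
max_new_tokens = 128       # requested on the *main* model's generate()
num_assistant_tokens = 5   # candidate tokens drafted per assistant step

# Cache length the assistant actually needs over the whole generation:
required_cache_len = prompt_len + max_new_tokens            # 160

# Cache length it would get if sized lazily on its first generate() call,
# where max_new_tokens is clamped to num_assistant_tokens:
first_call_cache_len = prompt_len + num_assistant_tokens    # 37

assert first_call_cache_len < required_cache_len
# Later drafting steps would overflow this undersized StaticCache, which is why
# the assistant cache is pre-allocated here with the main model's max_cache_len.
```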
Oh, I see, that makes sense. Then we can leave the cache init here.
LGTM! We need some tests, and then I'll request a review from the core maintainer; after that we can merge.
@zucchini-nlp , the test_utils CI pass rate is the same before and after this PR, as below, so no regressions are introduced. After:
Thanks for reviewing.
@yao-matrix no worries if some tests are failing and they're not related to the PR changes. They might just be flaky or will be fixed on …
@zucchini-nlp , any more comments for me to iterate on? Thanks.
@yao-matrix no, the only thing is the CI, which is failing now. I pointed out the relevant test in a previous comment; if you can add one more test in … At the end you need to run …
@parameterized.expand([(None, True), ("static", False)])
def test_assisted_decoding_with_num_logits_to_keep(self, cache_implementation, return_legacy_cache):
    if cache_implementation == "static":
        self.skipTest("Gemma2 has HybridCache which is not compatible with assisted decoding StaticCache")
    pass
Let's not skip entirely, but only the static_cache test, as we still need to check that assisted generation works in Gemma2 :) Maybe it will already be skipped by model._supports_static_cache as I commented above; if not, we can skip only test_assisted_decoding_with_num_logits_to_keep_1_static (it may be named a bit differently).
I switched to _supports_static_cache to skip the case. Gemma2 is a bit different: it uses HybridCache but claims _supports_static_cache = True, so I still skip it in the model test file. I will remove that skip after enabling HybridCache for assisted decoding, which I plan to do after this PR (pure StaticCache) is merged. Thanks.
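As a sketch of the guard being described (the helper name and test wiring below are illustrative; only the _supports_static_cache attribute itself comes from the library):

```python
# Hypothetical skip guard inside a generation test class; `model_class` is a placeholder.
def maybe_skip_static_cache_assisted_test(self, model_class):
    # Models that cannot allocate a plain StaticCache are skipped outright.
    if not getattr(model_class, "_supports_static_cache", False):
        self.skipTest(f"{model_class.__name__} does not support StaticCache")
    # Gemma2 advertises _supports_static_cache = True but actually uses
    # HybridCache, so it is still skipped in its own model test file until
    # HybridCache support for assisted decoding lands.
```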
Looks very nice, but we need to add a compile test to make sure this is compile compatible! The whole point of static cache is -> compile! 🤗
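A hedged sketch of what such a compile-compatibility check could look like; the checkpoints, compile mode, and prompt are illustrative assumptions, and a real test would use the library's tiny test models.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B")
assistant = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

# With a StaticCache the key/value tensors have fixed shapes, so the main
# model's forward pass can be compiled.
model.forward = torch.compile(model.forward, mode="reduce-overhead")

inputs = tokenizer("Static caches make torch.compile happy because", return_tensors="pt")

# Run assisted generation twice to make sure neither the compiled graphs nor
# the pre-allocated caches get stuck after the first run.
for _ in range(2):
    out = model.generate(
        **inputs,
        assistant_model=assistant,
        cache_implementation="static",
        max_new_tokens=16,
        do_sample=False,
    )
print(tokenizer.decode(out[0], skip_special_tokens=True))
```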
@ArthurZucker I added a compile test.
@ArthurZucker @zucchini-nlp , please let me know if there are any further comments, thanks. BTW, I checked the failed CI case; it is not related to my changes.
Thanks, re-triggered the tests; let's wait for the core maintainer.
@ArthurZucker , @zucchini-nlp , I am wondering whether we can land this PR in 2024 :)
@zucchini-nlp @ArthurZucker , any further comments on this?
@yao-matrix Arthur Zucker might be a bit busy, so tagging @gante to review if possible. We'll be able to merge after we get one more review.
Thanks. @gante , could you please help review? Thanks.
Looks good! IMO it would be better to update the test, and also to make sure we use a static cache for the assistant model when we use one for the parent model, plus auto-compile in those cases!
end = seq_length + 1
index = torch.arange(begin, end, device=self.key_cache[0].device)

self._seen_tokens = max_length
_seen_tokens should not be relied on! The static cache does not really rely on it.
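For context, a static cache can recover how many positions are filled from its own tensors, so no separate counter is needed. A hedged sketch of that idea follows; the helper and shapes are illustrative, not the library's exact implementation.

```python
import torch

def seen_tokens_from_static_cache(key_cache_layer: torch.Tensor) -> int:
    """Count filled positions in a pre-allocated [batch, heads, max_len, head_dim] key cache."""
    # One (batch, head) slice is enough, since all heads are written together:
    # a position counts as "seen" if any of its head_dim entries is non-zero.
    return int(key_cache_layer[0, 0].any(dim=-1).sum())

# Tiny illustration: room for 8 positions, 3 of them filled.
key_cache = torch.zeros(1, 2, 8, 4)
key_cache[:, :, :3] = torch.randn(1, 2, 3, 4)
print(seen_tokens_from_static_cache(key_cache))  # -> 3
```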
""" | ||
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests | ||
end-to-end compilation and forward pass compilation only. | ||
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️ | ||
""" |
In terms of performance, I think it would make more sense to test whether we can compile the forward of the model and the forward of the assistant model instead of this! Our recent focus has rather been on bridging the gap here, as compilation of generate is super, super slow!
I would thus also look at the changes introduced by #34247 to have something similar for the assistant model! 🤗
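A minimal sketch of the forward-only compilation this comment suggests, reusing the same illustrative Qwen2.5 pair as above; only the two forwards are compiled, never generate itself, and a test would compare the outputs against the uncompiled run.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B")
assistant = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

# Compile only the forward passes: much cheaper than compiling `generate`
# end to end, and StaticCache keeps the key/value shapes static on both sides.
model.forward = torch.compile(model.forward, mode="reduce-overhead")
assistant.forward = torch.compile(assistant.forward, mode="reduce-overhead")

inputs = tokenizer("Compiling both forwards for assisted generation", return_tensors="pt")
out = model.generate(
    **inputs,
    assistant_model=assistant,
    cache_implementation="static",
    max_new_tokens=16,
    do_sample=False,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```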
@gante , I implemented a version for this issue: #32946. Please help comment, and I can iterate. Thanks.