[WIP] - Enable speculative decoding with batch size >1 #32189
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Awesome 🙌 plz ping me when you have questions or when the PR is ready! Don't forget to add tests, and, if possible, benchmarks in the PR for future reference 🙏
Very interested in this! Please ping me when it is done
Do you have something new?
I have an error
@ylacombe
cc @ylacombe regarding this PR
thanks
@ylacombe
GPT-2 models and the Whisper model don't work
Any update?
What does this PR do?
This PR aims to solve issue #32165.
I've started adapting the code to enable speculative decoding with `batch_size > 1`, reusing some of the work from the earlier PR #26875.
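For context, here is a minimal sketch of the batched usage this PR aims to enable; the checkpoints and the dummy audio inputs are illustrative assumptions, not part of the PR.

```python
import numpy as np
from transformers import AutoProcessor, WhisperForCausalLM, WhisperForConditionalGeneration

# Illustrative main/assistant pair for speculative decoding with Whisper.
processor = AutoProcessor.from_pretrained("openai/whisper-large-v2")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")
assistant_model = WhisperForCausalLM.from_pretrained("distil-whisper/distil-large-v2")

# A batch of several audio clips (dummy 1-second clips here), i.e. batch_size > 1.
audio_batch = [np.random.randn(16_000).astype(np.float32) for _ in range(4)]
inputs = processor(audio_batch, sampling_rate=16_000, return_tensors="pt")

# Assisted (speculative) generation over the whole batch: this is the call that
# currently only supports batch_size == 1 and that this PR works on generalising.
generated_ids = model.generate(**inputs, assistant_model=assistant_model)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```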
Main steps of the solution:
When batch size > 1:
1. Compute the number of similar tokens between the candidate tokens and the tokens obtained after a forward pass with the main model. This results in a tensor `n_matches` with the number of matches for each sequence in the batch.
2. We keep the matching tokens. For that, we keep all tokens from the main model's output with a sequence position lower than `n_matches.max() + 1`. In doing so, we also retain some potentially mismatched tokens, which we deal with in the next steps using padding tokens. The resulting tensor, `input_ids`, thus has shape (`batch_size`, `n_matches.max() + 1`).
3. We shift each sequence `i` in `input_ids` by `n_matches.max() - n_matches[i]`. The matching tokens are displaced to the right of `input_ids` and `n_matches.max() - n_matches[i]` padding tokens are added to the left.
4. Left cut: we cut all columns that contain only padding tokens. By design, these columns are at the left of `input_ids`. In this way we keep the smallest possible `input_ids` that contains all the information needed to continue assisted generation.

Steps 1 to 4 are the main addition to the original speculative decoding loop described in detail in this blog to enable assisted generation with BS > 1; a sketch of these steps is given below.
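Here is a minimal, self-contained sketch of steps 1 to 4. The function and argument names (`realign_after_verification`, `candidate_ids`, `main_model_ids`) and the padding convention are assumptions for illustration only; the actual change lives inside the assisted-generation loop of `generate`.

```python
import torch

def realign_after_verification(candidate_ids, main_model_ids, pad_token_id):
    """Hypothetical helper sketching steps 1-4 on the newly generated chunk.

    candidate_ids:  (batch_size, num_candidates) tokens proposed by the assistant model.
    main_model_ids: (batch_size, num_candidates + 1) tokens selected by the main model
                    at the same positions, plus one extra "bonus" token.
    """
    # Step 1: number of leading candidate tokens accepted by the main model, per sequence.
    matches = (candidate_ids == main_model_ids[:, : candidate_ids.shape[1]]).long()
    n_matches = matches.cumprod(dim=-1).sum(dim=-1)               # shape (batch_size,)

    # Step 2: keep every position up to n_matches.max() + 1; some rows now carry
    # mismatched leftovers, which the padding added in step 3 will neutralise.
    keep_len = int(n_matches.max()) + 1
    kept = main_model_ids[:, :keep_len]                           # (batch_size, n_matches.max() + 1)

    # Step 3: shift sequence i right by n_matches.max() - n_matches[i], filling the
    # freed left positions with padding tokens so the valid tokens are right-aligned.
    shifted = torch.full_like(kept, pad_token_id)
    for i in range(kept.shape[0]):
        shift = keep_len - (int(n_matches[i]) + 1)                # = n_matches.max() - n_matches[i]
        shifted[i, shift:] = kept[i, : keep_len - shift]

    # Step 4 (left cut): drop the leftmost columns that contain only padding tokens,
    # keeping the smallest tensor that still holds all the needed information.
    non_pad_cols = (shifted != pad_token_id).any(dim=0)
    first_kept = int(non_pad_cols.long().argmax())
    return shifted[:, first_kept:], n_matches
```

Note that on this isolated chunk the left cut is a no-op, since the row with the most matches has no left padding; presumably it pays off over the full `input_ids`, where padding can accumulate across assisted-generation iterations.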
To make this work, we also need to adapt the computation of the `attention_masks`, `past_key_values` and `position_ids` to take into account the shifted positions of the generated tokens.

To do:
- For now, I want to make this work with Whisper using this snippet.
- I've implemented steps 1 to 4 and adapted the computation of the `attention_masks` and `past_key_values` to handle the new padding tokens.
- I still need to make some adaptations to the `position_ids` to make this work properly. From what I can see:
  - For the main model (here `WhisperForConditionalGeneration`), `position_ids` are inferred directly from the `attention_mask`, as we can see here. So if we pass the right attention mask to `generate`, we should be good (see the sketch after this list).
  - For the assistant model (`WhisperForCausalLM` in our example), `position_ids` are currently neither computed nor passed to `prepare_inputs_for_generation`, and I'm not sure exactly why. I've made a first attempt at solving this, with no success so far.

cc @sanchit-gandhi @gante
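For reference, the attention-mask-based inference of `position_ids` mentioned above boils down to the cumulative-sum pattern used in several models of the library; below is a minimal sketch with a left-padded toy batch (the values are illustrative, and the exact Whisper call site is the one linked above).

```python
import torch

# Left-padded batch: 0 marks the padding inserted by the shift step, 1 marks real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Positions count only the non-padding tokens; padded slots get a dummy value,
# which is harmless because those positions are masked out by the attention mask anyway.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```

So as long as the attention mask marks the padding tokens introduced in step 3 with zeros, the main model should compute correct positions for the shifted tokens.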