
[WIP] - Enable speculative decoding with batch size >1 #32189

Open · wants to merge 8 commits into main
Conversation

@kamilakesbi (Contributor) commented Jul 24, 2024

What does this PR do?

This PR aims to solve issue #32165.

I've started adapting the code to enable speculative decoding with batch_size > 1, reusing some of the work from the earlier PR #26875.

Main steps of the solution:

When batch size > 1:

  1. Compute the number of matching tokens between the candidate tokens and the tokens obtained from a forward pass of the main model. This yields a tensor n_matches holding the number of matches for each sequence in the batch.

  2. We keep the matching tokens. To do so, we keep every token from the main model's output whose sequence position is less than n_matches.max() + 1. In doing so, we also retain some potentially mismatched tokens, which we deal with in the next steps using padding tokens. The resulting tensor, input_ids, thus has shape (batch_size, n_matches.max() + 1).

  3. We shift each sequence i in input_ids to the right by n_matches.max() - n_matches[i]: the matching tokens are moved to the right end of input_ids, and n_matches.max() - n_matches[i] padding tokens are added on the left.

  4. Left cut: we remove all columns that contain only padding tokens. By design, these columns are on the left of input_ids. This way we keep the smallest possible `input_ids` that contains all the information needed to continue assisted generation.
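The four steps above can be sketched roughly as follows. This is a minimal standalone illustration, not the PR's actual code: the helper name `accept_and_realign`, greedy validation via argmax, and the `(batch, num_candidates + 1, vocab)` logits layout (one extra "bonus" position scored by the main model) are assumptions for the sake of the example.

```python
import torch

def accept_and_realign(candidate_tokens, main_logits, pad_token_id):
    """Steps 1-4: validate candidates, keep matches, right-align, left-cut."""
    # main_logits: (batch, num_candidates + 1, vocab) — the main model scores
    # every candidate position plus one extra "bonus" position.
    selected = main_logits.argmax(dim=-1)  # greedy validation (assumption)

    # Step 1: per-sequence count of leading matches; cumprod zeroes out
    # everything after the first mismatch.
    matches = (candidate_tokens == selected[:, :-1]).int()
    n_matches = matches.cumprod(dim=-1).sum(dim=-1)  # (batch,)

    # Step 2: keep the first n_matches.max() + 1 validated tokens per row
    # (the +1 is the main model's bonus token after the last match).
    max_m = int(n_matches.max())
    kept = selected[:, : max_m + 1]

    # Step 3: shift row i right by max_m - n_matches[i], padding on the left,
    # so every row's valid tokens end at the same (rightmost) column.
    out = torch.full_like(kept, pad_token_id)
    for i in range(kept.shape[0]):
        n_keep = int(n_matches[i]) + 1
        out[i, max_m - int(n_matches[i]):] = kept[i, :n_keep]

    # Step 4: left cut — drop columns that are padding in every row. (In this
    # isolated sketch the row with the most matches has no padding; the cut
    # matters once pad columns accumulate in the full input_ids over
    # successive assisted-generation iterations.)
    keep_cols = ~(out == pad_token_id).all(dim=0)
    first_real = int(keep_cols.float().argmax())
    return out[:, first_real:], n_matches
```

For example, with candidates `[[1, 2, 3], [1, 4, 3]]` and main-model picks `[[1, 2, 3, 4], [1, 2, 2, 4]]`, the first row matches all three candidates while the second matches only one, so the second row ends up left-padded by two positions.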

Steps 1 to 4 are the main addition to the original speculative decoding loop, described in detail in this blog, needed to enable assisted generation with BS > 1.

To make this work, we also need to adapt the computation of the attention_mask, past_key_values and position_ids to take the shifted positions of the generated tokens into account.
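For the attention_mask part, the update could look roughly like the sketch below: the newly appended block of n_matches.max() + 1 tokens starts, for each row, with the padding introduced by the per-row shift, and those slots must be masked out. `extend_attention_mask` is a hypothetical helper for illustration, not the PR's code.

```python
import torch

def extend_attention_mask(attention_mask, n_matches):
    """Append a mask block for the max_m + 1 newly accepted tokens,
    zeroing out the left padding introduced by the per-row shift."""
    max_m = int(n_matches.max())
    batch = attention_mask.shape[0]
    new_block = torch.ones(batch, max_m + 1, dtype=attention_mask.dtype)
    for i in range(batch):
        # Row i was shifted right by max_m - n_matches[i] pad tokens,
        # so the first max_m - n_matches[i] new positions are padding.
        new_block[i, : max_m - int(n_matches[i])] = 0
    return torch.cat([attention_mask, new_block], dim=-1)
```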

To do:

For now, I want to make this work with Whisper using this snippet.

I've implemented steps 1 to 4 and adapted the computation of the attention_masks and past_key_values to handle the new padding tokens.

I still need to make some adaptations with the position_ids to make this work properly. From what I can see:

  • For the main model (here WhisperForConditionalGeneration), position_ids are inferred directly from the attention_mask, as we can see here. So if we pass the right attention mask to generate, we should be fine.

  • For the assistant model (WhisperForCausalLM in our example), position_ids are currently neither computed nor passed to prepare_inputs_for_generation, and I'm not sure exactly why. I've made a first attempt at solving this, with no success so far.
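The first bullet corresponds to the common pattern of deriving positions from the mask, which plays well with the left padding introduced above: padded slots don't advance the position counter. A sketch of that pattern (not Whisper's exact code):

```python
import torch

def position_ids_from_mask(attention_mask):
    # Cumulative count of attended tokens: left padding does not advance
    # the position counter, so shifted rows still get consistent positions.
    position_ids = attention_mask.long().cumsum(-1) - 1
    # The value at masked positions is irrelevant; 1 is a common placeholder.
    return position_ids.masked_fill(attention_mask == 0, 1)
```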

cc @sanchit-gandhi @gante

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kamilakesbi kamilakesbi changed the title [WIP] - Enable speculative decoding with batch size >1 #32165 [WIP] - Enable speculative decoding with batch size >1 Jul 24, 2024
@gante (Member) commented Jul 27, 2024

Awesome 🙌 Please ping me when you have questions or when the PR is ready!

Don't forget to add tests, and, if possible, benchmarks in the PR for future reference 🙏

@xu1998hz

Very interested in this! Please ping me when it is done

@deafTim commented Oct 3, 2024

Do you have something new?

@deafTim commented Oct 3, 2024

@ylacombe I'm getting an error:

  File "D:\_LLM_project\Development\python311envassistbatchkamila\Lib\site-packages\transformers\modeling_attn_mask_utils.py", line 137, in to_4d
    expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (32) must match the size of tensor b (28) at non-singleton dimension 3

@deafTim commented Oct 8, 2024

@ylacombe
Could you please give some advice on how I can fix that?

@LysandreJik (Member)
cc @ylacombe regarding this PR

@deafTim commented Oct 9, 2024

> cc @ylacombe regarding this PR

thanks

@deafTim commented Oct 9, 2024

@ylacombe
Could you help, please?

@deafTim commented Oct 11, 2024

GPT-2 models and the Whisper model don't work.

@trotsky1997

Any update?

7 participants