Conversation

@jmamou jmamou commented Oct 14, 2024

What does this PR do?

Following #33258 and #33657.

The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number of unnecessary draft and target forward passes. The costs are estimated based on the ROC curve, which considers the probability of the draft token and its match with the target. A cost of 25% is assigned to false positives and 75% to false negatives.
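As an illustrative sketch of the idea (an assumption about the mechanism, not the PR's actual implementation): given logged draft-token probabilities and whether the target model accepted each token, one can pick the confidence threshold that minimizes a weighted misclassification cost, with false positives (drafted token rejected by the target) weighted at 25% and false negatives (drafting stopped although the target would have accepted) at 75%:

```python
def best_threshold(probs, accepted, fp_cost=0.25, fn_cost=0.75):
    """Pick the confidence threshold minimizing the weighted FP/FN cost.

    probs: draft-model probability of each drafted token.
    accepted: whether the target model accepted that token.
    """
    best_t, best_cost = 0.0, float("inf")
    for t in sorted(set(probs)):
        # False positive: confidence above threshold, but target rejected.
        fp = sum(1 for p, a in zip(probs, accepted) if p >= t and not a)
        # False negative: confidence below threshold, but target would accept.
        fn = sum(1 for p, a in zip(probs, accepted) if p < t and a)
        cost = fp_cost * fp + fn_cost * fn
        if cost < best_cost:
            best_t, best_cost = t, cost
    return best_t
```

In the PR this trade-off is evaluated over the speculative iterations seen so far, so the threshold adapts to the model pair and decoding setup instead of staying fixed.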

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@gante @amyeroberts

@gante gante left a comment

In addition to the comments: can you share the benchmark results here as well, for future reference?


jmamou commented Oct 16, 2024

> In addition to the comments: can you share the benchmark results here as well, for future reference?

A100
target: starcoder
draft: tiny_starcoder
dataset: MBPP

- Heuristics: mean_inference_time = 16.33 ms
- Fixed threshold for dynamic SL (#33258): mean_inference_time = 14.03 ms
- Adaptive threshold for dynamic SL (current PR): mean_inference_time = 13.42 ms

I will later run the benchmarks from https://huggingface.co/blog/dynamic_speculation_lookahead

gante commented Oct 17, 2024

@jmamou yeah, let's please run more benchmarks before (iterating on the PR and) merging

In the odd chance it ends up being beneficial only in very specific circumstances, I'd rather not merge the technique to avoid adding complexity (which usually reduces our team's ability to work on more projects 🤗 )

jmamou commented Oct 31, 2024

> @jmamou yeah, let's please run more benchmarks before (iterating on the PR and) merging
>
> In the odd chance it ends up being beneficial only in very specific circumstances, I'd rather not merge the technique to avoid adding complexity (which usually reduces our team's ability to work on more projects 🤗 )

@gante
I have run benchmarks from https://huggingface.co/spaces/joaogante/assisted_generation_benchmarks
https://github.com/gante/huggingface-demos/tree/main/experiments/faster_generation
Evaluated metric: time per token in ms (lower is better)
Device: RTX 3090; dtype applies to both models

| Model | Assistant | dtype | task | sampling? | w/o assistant | disco 0.4 | adaptive disco | disco speedup | adaptive disco speedup |
|---|---|---|---|---|---|---|---|---|---|
| openai/whisper-large-v2 | openai/whisper-tiny | fp16 | automatic speech recognition | no | 20.02 | 14.59 | 13.81 | 1.37 | 1.45 |
| facebook/opt-6.7b | facebook/opt-125m | bf16 | summarization | no | 23.81 | 8.73 | 8.72 | 2.73 | 2.73 |
| facebook/opt-6.7b | facebook/opt-125m | bf16 | summarization | yes (t=0.6) | 24.21 | 12.01 | 10.55 | 2.02 | 2.29 |
| facebook/opt-6.7b | facebook/opt-125m | bf16 | open-ended generation | no | 22.14 | 14.19 | 14.14 | 1.56 | 1.57 |
| facebook/opt-6.7b | facebook/opt-125m | bf16 | open-ended generation | yes (t=0.7) | 22.13 | 14.16 | 14.09 | 1.56 | 1.57 |
| Salesforce/codegen-6B-mono | Salesforce/codegen-350M-mono | bf16 | code generation (python) | no | 30.88 | 26.8 | 26.95 | 1.15 | 1.15 |
| Salesforce/codegen-6B-mono | Salesforce/codegen-350M-mono | bf16 | code generation (python) | yes (t=0.4) | 37.02 | 35.88 | 33.79 | 1.03 | 1.1 |
| google/flan-t5-xl | google/flan-t5-small | bf16 | summarization | no | 24.76 | 20.11 | 20.1 | 1.23 | 1.23 |
| google/flan-t5-xl | google/flan-t5-small | bf16 | summarization | yes (t=0.6) | 24.44 | 26.78 | 25.15 | 0.91 | 0.97 |
| meta-llama/Llama-3.1-8B | meta-llama/Llama-3.2-1B | bf16 | summarization | no | 33.06 | 19.27 | 19.29 | 1.72 | 1.71 |
| meta-llama/Llama-3.1-8B | meta-llama/Llama-3.2-1B | bf16 | summarization | yes (t=0.6) | 33.6 | 24.35 | 21.69 | 1.38 | 1.55 |
| meta-llama/Llama-3.1-8B | meta-llama/Llama-3.2-1B | bf16 | open-ended generation | no | 31.25 | 33.2 | 33.1 | 0.94 | 0.94 |
| meta-llama/Llama-3.1-8B | meta-llama/Llama-3.2-1B | bf16 | open-ended generation | yes (t=0.7) | 31.35 | 42.29 | 39.02 | 0.74 | 0.8 |
| meta-llama/Llama-3.1-8B | meta-llama/Llama-3.2-1B | bf16 | code generation (python) | no | 27.98 | 19.49 | 19.72 | 1.44 | 1.42 |
| meta-llama/Llama-3.1-8B | meta-llama/Llama-3.2-1B | bf16 | code generation (python) | yes (t=0.4) | 28.6 | 24.23 | 20.85 | 1.18 | 1.37 |

The adaptive threshold shows its largest gains when do_sample=True, likely because the fixed threshold of 0.4 was tuned for greedy decoding. A lower threshold seems to be needed when sampling, which highlights the value of adapting the threshold as this PR proposes.

gante commented Nov 4, 2024

@jmamou I'm convinced :D The benchmarks do show a consistent upgrade

@gante gante left a comment


LGTM, thank you for the thorough benchmark 🤗

@gante gante requested a review from ArthurZucker November 4, 2024 10:34

jmamou commented Nov 20, 2024

@ArthurZucker could you please review it?

@ArthurZucker ArthurZucker left a comment


LGTM! It's just missing a test!



jmamou commented Dec 4, 2024

> LGTM! It's just missing a test!

@ArthurZucker
done!

@ArthurZucker

`from transformers.generation.candidate_generator import AssistedCandidateGenerator` needs to guard its import of torch! candidate_generator.py needs to check torch availability!

@ArthurZucker

Let's GOOOOOOO! 🚀

@ArthurZucker ArthurZucker merged commit e27465c into huggingface:main Dec 5, 2024
22 checks passed
@jmamou jmamou deleted the adaptive-SL branch May 28, 2025 13:00