System Info
transformers==4.46.3
torch==2.5.1
(though the issue does not depend on library versions)
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
TL;DR: the current quantized KV cache implementation handles the first token incorrectly (it quantizes it right away), which causes problems due to attention sinks. I describe this in detail in 'Expected behavior' below.
To reproduce the problem, simply run LLM inference with QuantoQuantizedCache (or another QuantizedCache subclass); the snippet below uses Llama 2 7B, and the same happens with e.g. Llama 3.2 3B:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="cuda:0")

inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)
model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"backend": "quanto", "nbits": 4})
```
Note that in this example, the quantized cache should only start quantizing once the cache length exceeds the residual buffer size (`residual_length`, 128 by default). However, it mistakenly quantizes the first token right away. This is important because the first token often acts as an attention sink, and quantizing it substantially reduces model accuracy.
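One way to see this directly is to run a single forward pass over the short prompt and inspect the cache. This is a hedged sketch: `key_cache` / `_quantized_key_cache` are internal attribute names taken from the transformers 4.46 source and may differ in other versions.

```python
# Sketch: run one forward pass over a short prompt and inspect the cache.
# Requires optimum-quanto; `key_cache` / `_quantized_key_cache` are internal
# attributes of QuantizedCache in transformers 4.46 and may change.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, QuantizedCacheConfig, QuantoQuantizedCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="cuda:0"
)

inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)
cache = QuantoQuantizedCache(QuantizedCacheConfig(backend="quanto", nbits=4))

with torch.no_grad():
    model(**inputs, past_key_values=cache, use_cache=True)

# The prompt is only a handful of tokens, far below the default residual_length of 128,
# so nothing should have been quantized yet. Instead, the quantized cache already holds
# the whole prompt (including token 0), while the FP16 residual buffer is empty.
print(type(cache._quantized_key_cache[0]))  # quantized tensor, already populated
print(cache.key_cache[0].shape)             # torch.Size([0]) -- empty residual buffer
```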
I describe the problem and demonstrate the quality drawdown in detail in the "Expected Behavior" section (below).
Expected behavior
What exactly is wrong
The problem occurs every time inference is run with QuantizedCache or its descendants.
The current implementation of QuantizedCache always quantizes the very first token, adding it to the quantized cache instead of storing it in the full-precision cache. However, according to the docstring, quantization should only occur "when the length goes beyond maximum capacity, and the original precision cache is discarded and moved into the quantized cache." This implies that the first token should only be quantized once the cache length reaches its maximum capacity.
This issue may affect model quality, as recent papers on KV quantization emphasize the importance of the first token, also known as the "attention sink" ( https://arxiv.org/abs/2309.17453 ). In fact, some papers suggest not quantizing this token at all (e.g. https://arxiv.org/abs/2410.05265 ). Failing to store the first token in full precision could therefore be problematic.
Proposed fix
The fix involves modifying the `update` method so that it keeps the first token in the FP16 buffers:
https://www.diffchecker.com/Rj5qaOCj/
Here is the full code of the fixed QuantizedCache that I used for the benchmarks below.
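To make the intended policy concrete, here is a minimal standalone sketch (an illustration only, not the actual patch above; the separate sink buffer and the `fake_quantize` placeholder are my own simplifications):

```python
import torch

def fake_quantize(x: torch.Tensor) -> torch.Tensor:
    """Placeholder for the real quanto/HQQ quantization step."""
    return x

class SketchKVBuffer:
    """Toy KV buffer: the first token (attention sink) and the most recent
    `residual_length` tokens stay in full precision; older tokens are quantized."""

    def __init__(self, residual_length: int = 128):
        self.residual_length = residual_length
        self.sink = None        # first token, never quantized
        self.quantized = None   # older tokens, quantized
        self.residual = None    # recent tokens, full precision

    def update(self, new_states: torch.Tensor) -> torch.Tensor:
        # new_states: [batch, heads, seq_len, head_dim]
        if self.sink is None:
            self.sink, new_states = new_states[..., :1, :], new_states[..., 1:, :]
        if new_states.shape[-2] > 0:
            self.residual = (
                new_states if self.residual is None
                else torch.cat([self.residual, new_states], dim=-2)
            )
        # Flush only the overflow into the quantized store; token 0 never moves there.
        if self.residual is not None and self.residual.shape[-2] > self.residual_length:
            overflow = self.residual[..., : -self.residual_length, :]
            q = fake_quantize(overflow)
            self.quantized = q if self.quantized is None else torch.cat([self.quantized, q], dim=-2)
            self.residual = self.residual[..., -self.residual_length :, :]
        # Return the full sequence the attention layer would see.
        parts = [p for p in (self.sink, self.quantized, self.residual) if p is not None]
        return torch.cat(parts, dim=-2)
```

In the real cache, the same idea boils down to `update` never moving the first token out of the full-precision buffers.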
Effect on model quality:
To verify that this fix works, I compared the two implementations in terms of WikiText-2 perplexity (lower is better). In the table below, 'Before fix' evaluates QuantoQuantizedCache with the current code (see versions above) and 'After fix' uses the patch described in the previous section. Both were measured with a Llama-3.2-3B backbone at 4-bit and 2-bit with all default hyperparameters (see details below).
| WikiText-2 PPL | FP16 | 4-bit | 2-bit |
|---|---|---|---|
| Before fix | 6.9865 | 7.0052 | 12.2152 |
| After fix | 6.9865 | 7.0005 | 11.4623 |
The full code for obtaining these results can be found in the gist link. If any additional evaluations or explanations are required, I am happy to provide them.
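For reference, a rough sketch of how such a perplexity measurement can be set up (a generic token-by-token loop that exercises the cache, not the gist's code; the 4096-token context length and other details are assumptions):

```python
# Sketch: WikiText-2 perplexity with the quantized KV cache, decoding token by token
# so that the cache is actually exercised. Hyperparameters (e.g. the 4096-token
# context) are assumptions, not necessarily the settings used for the table above.
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, QuantizedCacheConfig, QuantoQuantizedCache

model_name = "meta-llama/Llama-3.2-3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="cuda:0")

text = "\n\n".join(load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"])
input_ids = tokenizer(text, return_tensors="pt").input_ids[:, :4096].to(model.device)

cache = QuantoQuantizedCache(QuantizedCacheConfig(backend="quanto", nbits=4))
nlls = []
with torch.no_grad():
    for t in range(input_ids.shape[1] - 1):
        out = model(input_ids[:, t : t + 1], past_key_values=cache, use_cache=True)
        cache = out.past_key_values
        log_probs = torch.log_softmax(out.logits[:, -1, :].float(), dim=-1)
        nlls.append(-log_probs[0, input_ids[0, t + 1]])

print(f"WikiText-2 PPL: {torch.exp(torch.stack(nlls).mean()).item():.4f}")
```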
Thanks to surkovvv for his help with the evaluation codebase.