
QuantizedCache first token processing is counterintuitive / worse than in papers #35185

@goodevening13

Description


System Info

transformers==4.46.3
torch==2.5.1
(though the issue does not depend on library versions)

Who can help?

@SunMarc @zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

TL;DR: the current quantized KV cache implementation handles the first token poorly, which causes quality problems because that token acts as an attention sink. I describe this in detail in 'Expected behavior'.

To reproduce the problem, simply run LLM inference (e.g. with Llama 3.2 3B) using QuantoQuantizedCache or another QuantizedCache descendant, for example:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="cuda:0")
inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"backend": "quanto", "nbits": 4})

Note that in this example, the quantized cache should only quantize tokens that are at least a residual-buffer length (residual_length) behind the current position, i.e. only once the cache grows beyond that buffer. However, it mistakenly quantizes the first token right away. This is important because the first token often acts as an attention sink, and quantizing it substantially reduces model accuracy.
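
One way to see this directly is to pass an explicit cache object to generate and inspect it afterwards. The snippet below reuses model and inputs from the example above and peeks at the private attributes _quantized_key_cache / key_cache of the transformers 4.46 implementation, so treat it as a debugging illustration rather than supported API:

from transformers import QuantizedCacheConfig, QuantoQuantizedCache

# Reuses `model` and `inputs` from the snippet above.
cache = QuantoQuantizedCache(QuantizedCacheConfig(backend="quanto", nbits=4))
model.generate(**inputs, do_sample=False, max_new_tokens=20, past_key_values=cache)

# With the default residual_length of 128 and only a few dozen cached tokens,
# nothing should have been quantized yet. In practice, the prompt (first token
# included) already sits in the quantized cache, while only the generated tokens
# are held in the FP16 residual buffer.
print(cache._quantized_key_cache[0].shape[-2])  # number of quantized positions (equals the prompt length)
print(cache.key_cache[0].shape[-2])             # number of full-precision positions (only the new tokens)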

I describe the problem and demonstrate the quality drawdown in detail in the "Expected Behavior" section (below).

Expected behavior

What exactly is wrong

The problem occurs every time inference is run with QuantizedCache or its descendants.

The current implementation of QuantizedCache always quantizes the very first token, adding it to the q_cache instead of storing it in the full precision cache. However, according to the docstring, quantization should only occur "when the length goes beyond maximum capacity, and the original precision cache is discarded and moved into the quantized cache." This implies that the first token should only be quantized when the cache length reaches its maximum capacity.

This issue may affect model quality, as recent papers on KV quantization emphasize the importance of the first token, also known as the "attention sink" ( https://arxiv.org/abs/2309.17453 ). In fact, some papers suggest not quantizing this token at all (e.g. https://arxiv.org/abs/2410.05265 ). Failing to store the first token in full precision could therefore be problematic.

Proposed fix

The fix involves modifying the "update" method so it keeps the first token in the FP16 buffers:

https://www.diffchecker.com/Rj5qaOCj/

Here's the full code of the fixed QuantizedCache that I used for the benchmarks below.
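
As an illustration of the idea only, a minimal sketch of such a change could look like the subclass below. It assumes the QuantizedCache internals of transformers 4.46 (attributes such as _quantized_key_cache, key_cache and residual_length, and the _quantize/_dequantize helpers); the exact patch in the diff link above may differ, and error handling from the original update is omitted for brevity.

from typing import Any, Dict, Optional, Tuple

import torch
from transformers import QuantoQuantizedCache


class PatchedQuantoQuantizedCache(QuantoQuantizedCache):
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if len(self.key_cache) == layer_idx:
            # First call for this layer: keep the prefill (including the attention-sink
            # first token) in the FP16 residual buffers instead of quantizing it right away.
            self._quantized_key_cache.append(None)
            self._quantized_value_cache.append(None)
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            return key_states, value_states

        if self._quantized_key_cache[layer_idx] is not None:
            dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
            dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])
        else:
            # Nothing has been quantized yet for this layer -> empty quantized prefix.
            dequant_key = key_states.new_zeros(key_states.shape[:-2] + (0, key_states.shape[-1]))
            dequant_value = value_states.new_zeros(value_states.shape[:-2] + (0, value_states.shape[-1]))

        keys_to_return = torch.cat([dequant_key, self.key_cache[layer_idx], key_states], dim=-2)
        values_to_return = torch.cat([dequant_value, self.value_cache[layer_idx], value_states], dim=-2)

        if self.key_cache[layer_idx].dim() == 4 and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length:
            # Residual buffer is full: move everything into the quantized cache,
            # as the class docstring describes.
            self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
            self._quantized_value_cache[layer_idx] = self._quantize(values_to_return.contiguous(), axis=self.axis_value)
            self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
            self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return keys_to_return, values_to_return

With such a subclass, the cache can be passed to generate explicitly, e.g. past_key_values=PatchedQuantoQuantizedCache(QuantizedCacheConfig(backend="quanto", nbits=4)), instead of cache_implementation="quantized".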

Effect on model quality:

To verify that this fix works, I compared the two implementations in terms of WikiText-2 perplexity (lower is better). In the table below, 'Before fix' evaluates QuantoQuantizedCache with the current code (see versions above) and 'After fix' uses the patch described in the previous section. These were measured with a Llama-3.2-3B backbone using 2-bit and 4-bit quantization with all default hyperparameters (see details below).

WikiText-2 PPL    FP16      4-bit     2-bit
Before fix        6.9865    7.0052    12.2152
After fix         6.9865    7.0005    11.4623

The full code for obtaining these results can be found in the gist link. If any additional evaluations or explanations are required, I am happy to provide them.
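
The gist itself is not reproduced here; as a rough sketch of how such a comparison can be run, the snippet below decodes WikiText-2 token by token so that the KV cache (and therefore its quantization policy) is actually exercised. The dataset handling, the 2048-token sequence length, and the manual decoding loop are my assumptions, not the exact evaluation code:

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, QuantizedCacheConfig, QuantoQuantizedCache

model_name = "meta-llama/Llama-3.2-3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="cuda:0")
model.eval()

text = "\n\n".join(load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"])
input_ids = tokenizer(text, return_tensors="pt").input_ids[:, :2048].to(model.device)

# Swap in the patched cache class (or change nbits) to fill the other table cells;
# use a plain DynamicCache for the FP16 column.
cache = QuantoQuantizedCache(QuantizedCacheConfig(backend="quanto", nbits=4))

nlls = []
with torch.no_grad():
    for i in range(input_ids.shape[1] - 1):
        out = model(input_ids[:, i : i + 1], past_key_values=cache, use_cache=True)
        cache = out.past_key_values
        log_probs = torch.log_softmax(out.logits[:, -1, :].float(), dim=-1)
        nlls.append(-log_probs[0, input_ids[0, i + 1]])

print("perplexity:", torch.exp(torch.stack(nlls).mean()).item())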


Thanks to surkovvv for his help with the evaluation codebase.
