System Info
- `transformers` version: 4.46.2
- Platform: Linux-5.15.0-125-generic-x86_64-with-glibc2.35
- Python version: 3.10.15
- Huggingface_hub version: 0.26.3
- Safetensors version: 0.4.5
- Accelerate version: 1.1.1
- Accelerate config: not found
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?:
- Using GPU in script?:
- GPU type: NVIDIA A10
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Hi @ArthurZucker @GeLee-Q, thanks for #33312, which resolved the float16 inference bug for Qwen2-VL. However, the fix still has a flaw: although the current code solves the numerical overflow issue, it may break the causality of the autoregressive model. See details here.
Expected behavior
Current code:
```python
if attention_mask is not None:  # no matter the length, we just slice it
    causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
    attn_weights = attn_weights + causal_mask

# Fix precision issues in Qwen2-VL float16 inference
# Replace inf values with zeros in attention weights to prevent NaN propagation
if query_states.dtype == torch.float16:
    attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
```
Fixed: replace ±inf with zeros before adding causal_mask to attn_weights, so that the -inf values introduced by the causal mask (which enforce causality) are not zeroed out:
```python
# Fix precision issues in Qwen2-VL float16 inference
# Replace inf values with zeros in attention weights to prevent NaN propagation
if query_states.dtype == torch.float16:
    attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)

if attention_mask is not None:  # no matter the length, we just slice it
    causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
    attn_weights = attn_weights + causal_mask
```
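To make the difference concrete, here is a minimal sketch (illustrative only, not the actual Qwen2-VL modeling code; the shapes, the -100 score value, and the mask construction from `torch.finfo(torch.float16).min` are assumptions for demonstration). It shows that zeroing inf after the causal mask has been added un-masks future positions in float16, while zeroing it beforehand does not:

```python
# Illustrative sketch only: single head, seq_len=3, additive causal mask
# built from torch.finfo(torch.float16).min.
import torch

seq_len = 3
dtype = torch.float16
mask_value = torch.finfo(dtype).min  # -65504 in float16

# Additive causal mask: 0 on/below the diagonal, mask_value above it
causal_mask = torch.triu(
    torch.full((seq_len, seq_len), mask_value, dtype=dtype), diagonal=1
)

# Raw attention scores; -100 + mask_value overflows to -inf in float16,
# and one entry simulates a genuine float16 overflow in the scores themselves
attn_weights = torch.full((seq_len, seq_len), -100.0, dtype=dtype)
attn_weights[1, 0] = float("inf")

# Current order: add the mask first, then zero every inf -- including the
# -inf that the masked positions just became, so future tokens get weight
current = attn_weights + causal_mask
current = torch.where(torch.isinf(current), torch.zeros_like(current), current)
print(torch.softmax(current.float(), dim=-1)[0])  # ~[0.0, 0.5, 0.5]

# Fixed order: zero only the overflowed scores, then add the mask
fixed = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
fixed = fixed + causal_mask
print(torch.softmax(fixed.float(), dim=-1)[0])  # ~[1.0, 0.0, 0.0]
```

With the current ordering, the query at position 0 puts almost all of its attention mass on the two future positions; with the fixed ordering the mask survives, and the simulated overflow at `[1, 0]` is still zeroed out as intended.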