Description
System Info
- transformers version: 4.46.0
- Platform: macOS-15.0-arm64-arm-64bit
- Python version: 3.11.10
- Huggingface_hub version: 0.26.1
- Safetensors version: 0.4.5
- Accelerate version: not installed
- Accelerate config: not found
- PyTorch version (GPU?): 2.5.0 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: no
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Problem: the T5Attention forward pass fails when the KV cache is not used. This is caused by `cache_position` being None here. @zucchini-nlp
Code to reproduce:

```python
import torch
from transformers.models.t5.modeling_t5 import T5Attention
from transformers.models.t5 import T5Config
import json
from huggingface_hub import hf_hub_download

T5_REPO = "google-t5/t5-small"
BATCH_SIZE = 2
tgt_len = 3
EMBED_SIZE = 512
CONFIG_NAME = "config.json"

t5_config_path = hf_hub_download(repo_id=T5_REPO, filename=CONFIG_NAME)
with open(t5_config_path, "r") as f:
    t5_config = json.load(f)
t5_config = T5Config.from_dict(t5_config)

xq = torch.randn((BATCH_SIZE, tgt_len, EMBED_SIZE))
torch_t5_mha = T5Attention(t5_config, has_relative_attention_bias=True).eval()
with torch.no_grad():
    attn_out, kv_state, pos_bias = torch_t5_mha(xq)
```
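As a temporary workaround (not a fix), the call goes through for me if I pass `cache_position` explicitly, since the forward signature shown in the traceback below accepts it. A minimal, untested sketch, assuming there is no past key/value state so the positions are simply 0..tgt_len-1:

```python
# Workaround sketch: hand the layer an explicit cache_position so that the
# `cache_position[-1] + 1` lookup in T5Attention.forward has something to index.
cache_position = torch.arange(tgt_len)
with torch.no_grad():
    attn_out, kv_state, pos_bias = torch_t5_mha(xq, cache_position=cache_position)
```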
Stack trace:

```
{
"name": "TypeError",
"message": "'NoneType' object is not subscriptable",
"stack": "---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[1], line 29
27 torch_t5_mha = T5Attention(t5_config, has_relative_attention_bias=True).eval()
28 with torch.no_grad():
---> 29 attn_out, kv_state, pos_bias = torch_t5_mha(xq)
File ~/dev/attention/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/dev/attention/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/dev/attention/.venv/lib/python3.11/site-packages/transformers/models/t5/modeling_t5.py:525, in T5Attention.forward(self, hidden_states, mask, key_value_states, position_bias, past_key_value, layer_head_mask, query_length, use_cache, output_attentions, cache_position)
523 key_length = key_states.shape[-2]
524 # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
--> 525 real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
526 if not self.has_relative_attention_bias:
527 position_bias = torch.zeros(
528 (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
529 )
TypeError: 'NoneType' object is not subscriptable"
}
```
Expected behavior
The T5Attention forward pass should not fail when the KV cache is not used. In addition, the `use_cache` flag should actually control whether the cache is used, which it currently does not.
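For what it's worth, a possible direction for a fix (only a sketch, not a tested patch): fall back to the current sequence length when both `query_length` and `cache_position` are None. The variable names below are the ones that appear in the traceback; the assumption is that a missing `cache_position` means there are no past key/value states:

```python
# Hypothetical replacement for the failing line in T5Attention.forward
# (modeling_t5.py, around line 525 in the traceback above).
if query_length is not None:
    real_seq_length = query_length
elif cache_position is not None:
    # cache position is 0-indexed, so add 1 to get the real query length
    real_seq_length = cache_position[-1] + 1
else:
    # no cache in use: the real length is just the current sequence length
    real_seq_length = seq_length
```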