
T5Attention forward pass failing when not using KV cache #34448

Closed
@gardberg

Description

System Info

  • transformers version: 4.46.0
  • Platform: macOS-15.0-arm64-arm-64bit
  • Python version: 3.11.10
  • Huggingface_hub version: 0.26.1
  • Safetensors version: 0.4.5
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.0 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no

Who can help?

@zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Problem: the `T5Attention` forward pass fails when no KV cache is used.

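This is caused by `cache_position` being `None` in `T5Attention.forward` (line 525 of `modeling_t5.py`, see the stack trace below). When `query_length` is not passed, that line evaluates `cache_position[-1] + 1`, which fails if no `cache_position` was given. @zucchini-nlp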
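To make the failure mode concrete, here is a minimal stdlib-only sketch that mirrors the length computation from line 525, next to a hypothetical guarded variant. The function names and the fallback to `past_length + seq_length` are assumptions for illustration only, not the fix adopted in transformers.

```python
from typing import Optional, Sequence


def real_seq_length_as_in_4_46(
    query_length: Optional[int], cache_position: Optional[Sequence[int]]
) -> int:
    # Mirrors the expression at modeling_t5.py line 525 (transformers 4.46.0):
    # cache_position is subscripted whenever query_length is None.
    return query_length if query_length is not None else cache_position[-1] + 1


def real_seq_length_guarded(
    seq_length: int,
    query_length: Optional[int] = None,
    cache_position: Optional[Sequence[int]] = None,
    past_length: int = 0,
) -> int:
    # Hypothetical guard: prefer an explicit query_length, then cache_position,
    # and otherwise fall back to the cached length plus the current input length.
    if query_length is not None:
        return query_length
    if cache_position is not None:
        return int(cache_position[-1]) + 1
    return past_length + seq_length


# No cache, no query_length, no cache_position: the original expression fails.
try:
    real_seq_length_as_in_4_46(None, None)
except TypeError as err:
    print(err)  # 'NoneType' object is not subscriptable

print(real_seq_length_guarded(seq_length=3))  # 3
```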

Code to reproduce:

```python
import json

import torch
from huggingface_hub import hf_hub_download
from transformers.models.t5 import T5Config
from transformers.models.t5.modeling_t5 import T5Attention

T5_REPO = "google-t5/t5-small"

BATCH_SIZE = 2
tgt_len = 3
EMBED_SIZE = 512  # d_model of t5-small

CONFIG_NAME = "config.json"

# Download and load the t5-small config
t5_config_path = hf_hub_download(repo_id=T5_REPO, filename=CONFIG_NAME)

with open(t5_config_path, "r") as f:
    t5_config = json.load(f)

t5_config = T5Config.from_dict(t5_config)

# Random input of shape (batch, seq_len, d_model)
xq = torch.randn((BATCH_SIZE, tgt_len, EMBED_SIZE))

# Standalone self-attention layer; no past_key_value, query_length or cache_position is passed
torch_t5_mha = T5Attention(t5_config).eval()
with torch.no_grad():
    # Raises TypeError: 'NoneType' object is not subscriptable (transformers 4.46.0)
    attn_out, kv_state, pos_bias = torch_t5_mha(xq)
```
Stack trace

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 29
     27 torch_t5_mha = T5Attention(t5_config, has_relative_attention_bias=True).eval()
     28 with torch.no_grad():
---> 29     attn_out, kv_state, pos_bias = torch_t5_mha(xq)

File ~/dev/attention/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/dev/attention/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/dev/attention/.venv/lib/python3.11/site-packages/transformers/models/t5/modeling_t5.py:525, in T5Attention.forward(self, hidden_states, mask, key_value_states, position_bias, past_key_value, layer_head_mask, query_length, use_cache, output_attentions, cache_position)
    523 key_length = key_states.shape[-2]
    524 # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
--> 525 real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
    526 if not self.has_relative_attention_bias:
    527     position_bias = torch.zeros(
    528         (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
    529     )

TypeError: 'NoneType' object is not subscriptable
```
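The comment on line 524 of `modeling_t5.py` in the trace above explains the `+ 1`: `cache_position` holds the 0-indexed absolute positions of the current query tokens, so its last entry plus one is the query length including cached tokens. A small sketch of that arithmetic, assuming positions are laid out as `range(past_length, past_length + seq_length)` (the usual convention, stated here as an assumption; the helper names are hypothetical):

```python
def cache_positions(past_length: int, seq_length: int) -> list[int]:
    # Assumed layout: absolute 0-indexed positions of the tokens in the current step
    return list(range(past_length, past_length + seq_length))


def real_query_length(cache_position: list[int]) -> int:
    # Positions are 0-indexed, so the last position plus one is the length including past tokens
    return cache_position[-1] + 1


# Full forward pass without a cache: 3 new tokens, nothing cached.
print(cache_positions(0, 3), real_query_length(cache_positions(0, 3)))  # [0, 1, 2] 3

# Incremental decoding: 5 tokens already cached, 1 new token.
print(cache_positions(5, 1), real_query_length(cache_positions(5, 1)))  # [5] 6
```

In the no-cache case the result equals the input sequence length, which is why a missing `cache_position` could in principle be derived from the input shape instead of being required.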

Expected behavior

The `T5Attention` forward pass should not fail when no KV cache is used, and the `use_cache` flag should actually determine whether the cache is used, which it currently does not.
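As a rough illustration of the expected behavior, the toy sketch below only creates and returns a cache when `use_cache=True`, and derives the positions it needs from the input when neither a cache nor a `cache_position` is supplied, so the no-cache path never subscripts `None`. `ToyCache` and `toy_forward` are hypothetical plain-Python stand-ins, not transformers code.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Tuple


@dataclass
class ToyCache:
    # Hypothetical stand-in for a KV cache: one stored entry per processed token
    entries: List[float] = field(default_factory=list)


def toy_forward(
    tokens: Sequence[float],
    use_cache: bool = False,
    cache: Optional[ToyCache] = None,
    cache_position: Optional[Sequence[int]] = None,
) -> Tuple[int, Optional[ToyCache]]:
    """Returns (real query length, cache or None). Toy model, not T5."""
    past_length = len(cache.entries) if cache is not None else 0
    if cache_position is None:
        # Derive positions from the input instead of assuming the caller supplied them
        cache_position = range(past_length, past_length + len(tokens))
    real_seq_length = cache_position[-1] + 1

    if not use_cache:
        # No cache is created, updated, or returned
        return real_seq_length, None

    if cache is None:
        cache = ToyCache()
    cache.entries.extend(tokens)
    return real_seq_length, cache


length, cache = toy_forward([0.1, 0.2, 0.3])
print(length, cache)  # 3 None

length, cache = toy_forward([0.1, 0.2, 0.3], use_cache=True)
length, cache = toy_forward([0.4], use_cache=True, cache=cache)
print(length, len(cache.entries))  # 4 4
```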
