Vision Encoder-Decoder fails with LLaMA decoder due to missing cross-attention implementation #34674

Closed
@amazingvince

System Info

  • transformers version: 4.46.2
  • Platform: Linux-6.1.85+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.24.7
  • Safetensors version: 0.4.5
  • Accelerate version: 0.34.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.0+cu121 (True)
  • Tensorflow version (GPU?): 2.17.0 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.8.5 (gpu)
  • Jax version: 0.4.33
  • JaxLib version: 0.4.33
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA L4

Who can help?

Not sure for multimodal models:
text models: @ArthurZucker
vision models: @amyeroberts, @qubvel
generate: @zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Description

When using a vision encoder-decoder model with LLaMA as the decoder, there is an incompatibility. While both GPT-2 and LLaMA are decoder-only models, GPT-2 returns an output class that includes cross-attention outputs (CausalLMOutputWithCrossAttentions), whereas LLaMA's output class (CausalLMOutputWithPast) does not include this attribute. As a result, the vision encoder-decoder forward pass fails when it tries to access the cross-attention outputs.

Current Behavior

The model raises an AttributeError because LLaMA returns a CausalLMOutputWithPast, which has no cross_attentions attribute, while the vision encoder-decoder expects that attribute to be present (as it is in GPT-2's output class).

Error message:

AttributeError: 'CausalLMOutputWithPast' object has no attribute 'cross_attentions'

Technical Analysis

  1. GPT-2's decoder implementation returns CausalLMOutputWithCrossAttentions, which includes cross-attention information
  2. LLaMA's decoder implementation returns CausalLMOutputWithPast, which does not include cross-attention
  3. The vision encoder-decoder architecture assumes the presence of cross-attention in the decoder outputs
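The mismatch can be confirmed directly from the output dataclasses in transformers.modeling_outputs (a quick sanity check, not part of the original report):

import dataclasses
from transformers.modeling_outputs import (
    CausalLMOutputWithPast,             # returned by LlamaForCausalLM
    CausalLMOutputWithCrossAttentions,  # returned by GPT2LMHeadModel
)

# CausalLMOutputWithPast: loss, logits, past_key_values, hidden_states, attentions
print([f.name for f in dataclasses.fields(CausalLMOutputWithPast)])
# CausalLMOutputWithCrossAttentions has the same fields plus cross_attentions
print([f.name for f in dataclasses.fields(CausalLMOutputWithCrossAttentions)])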

Steps to Reproduce

  1. Initialize a vision encoder-decoder model with LLaMA as the decoder
  2. Attempt to run a forward pass or generate
  3. The error occurs in modeling_vision_encoder_decoder.py when trying to access decoder_outputs.cross_attentions
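A minimal script along these lines reproduces the failure (the checkpoint names are only examples; any vision encoder paired with a LLaMA-family decoder should behave the same):

import torch
from transformers import VisionEncoderDecoderModel

# Example checkpoints; any ViT encoder + LLaMA decoder pair should do.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k",
    "meta-llama/Llama-3.2-1B",
)
model.config.decoder_start_token_id = model.config.decoder.bos_token_id
model.config.pad_token_id = model.config.decoder.eos_token_id

pixel_values = torch.randn(1, 3, 224, 224)  # dummy image batch
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

# Raises: AttributeError: 'CausalLMOutputWithPast' object has no attribute 'cross_attentions'
outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)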

The failing access is in modeling_vision_encoder_decoder.py, around line 651:

decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,  # This line causes the error
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,

Workaround

Setting cross_attentions to None allows the model to run, which suggests the architecture does not strictly require this information to function.
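As a stopgap from user code, one can wrap the decoder's forward so its output carries the missing attribute (a sketch, not an endorsed API):

from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

orig_forward = model.decoder.forward

def forward_with_cross_attentions(*args, **kwargs):
    out = orig_forward(*args, **kwargs)
    # Re-wrap LLaMA's CausalLMOutputWithPast in a class that carries a
    # cross_attentions field, left as None.
    return CausalLMOutputWithCrossAttentions(
        loss=getattr(out, "loss", None),
        logits=out.logits,
        past_key_values=out.past_key_values,
        hidden_states=out.hidden_states,
        attentions=out.attentions,
        cross_attentions=None,
    )

model.decoder.forward = forward_with_cross_attentions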

Proposed Solutions

  1. Short term: Modify the vision encoder-decoder implementation to handle decoders that don't provide cross-attention outputs:
cross_attentions = getattr(decoder_outputs, 'cross_attentions', None)
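Applied in context, the return in modeling_vision_encoder_decoder.py would look roughly like this (abridged sketch of the upstream code, not a verbatim patch):

cross_attentions = getattr(decoder_outputs, "cross_attentions", None)
return Seq2SeqLMOutput(
    loss=loss,
    logits=decoder_outputs.logits,
    past_key_values=decoder_outputs.past_key_values,
    decoder_hidden_states=decoder_outputs.hidden_states,
    decoder_attentions=decoder_outputs.attentions,
    cross_attentions=cross_attentions,  # None for decoders without cross-attention
    encoder_last_hidden_state=encoder_outputs.last_hidden_state,
    encoder_hidden_states=encoder_outputs.hidden_states,
    encoder_attentions=encoder_outputs.attentions,
)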

Happy to submit a PR if this seems like an appropriate fix.

Expected behavior

modeling_vision_encoder_decoder.py should support decoder models whose output classes do not include cross-attentions, rather than requiring a GPT-2-style causal LM output class with a cross-attention field.
