Conversation

@OmarManzoor
Contributor

What does this PR do?

Towards #34809

  • Adds Flex Attention for Mistral
  • Refactors the attention mechanisms to use functions instead of classes (see the sketch below)
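
For context, a minimal sketch (an editor's illustration, not this PR's exact code) of what attention-as-a-function can look like: instead of per-backend subclasses such as MistralSdpaAttention, plain functions are registered in a dict and looked up by the config's attention implementation name. The registry name and the signature below are assumptions.

    import torch

    def eager_attention_forward(query, key, value, attention_mask, scaling, dropout=0.0):
        # query/key/value: (batch, num_heads, seq_len, head_dim)
        attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout)
        return torch.matmul(attn_weights, value), attn_weights

    def sdpa_attention_forward(query, key, value, attention_mask, scaling, dropout=0.0):
        # SDPA does not expose attention weights, hence the None
        out = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
        )
        return out, None

    # Hypothetical registry; the attention module picks a function by name.
    ALL_ATTENTION_FUNCTIONS = {"eager": eager_attention_forward, "sdpa": sdpa_attention_forward}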

Who can review?

@ArthurZucker

if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
if ("SdpaAttention" in class_name or "SdpaSelfAttention" in class_name) or (
hasattr(submodule, "_uses_attention_functions") and submodule._uses_attention_functions
):
Contributor Author

I am not exactly sure how to handle this correctly.
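
One hedged way to handle it (an editor's sketch, not code from this PR) is to make the marker attribute the primary signal and keep the class-name check only as a fallback for models that have not been refactored yet:

    def submodule_supports_sdpa(submodule) -> bool:
        # _uses_attention_functions is the marker attribute proposed in this PR;
        # the helper name and the fallback structure are illustrative assumptions.
        if getattr(submodule, "_uses_attention_functions", False):
            return True
        class_name = submodule.__class__.__name__
        return "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name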

Collaborator

@ArthurZucker left a comment

Thanks for working on this!

Comment on lines +369 to +370
if self._attn_implementation != "flash_attention_2":
cache_kwargs["cache_position"] = cache_position
Collaborator

I don't think this escape is required, no?

Contributor Author

Currently this is how it is handled in FlashAttention2:

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
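
For comparison, removing the escape would mean passing cache_position unconditionally, roughly like this (an assumption about the intended shape, not a line from the diff):

    if past_key_value is not None:
        # sin/cos are specific to RoPE models; cache_position is now passed
        # for every attention implementation, including flash_attention_2.
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs
        )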

@OmarManzoor
Contributor Author

@ArthurZucker What should we do about these failing tests? I think they are related to the sdpa tests, where we might have output_attentions equal to True.

@ydshieh
Collaborator

ydshieh commented Nov 27, 2024

Hi. Let's take a look at the failing test(s) step by step.

First, do you know why Idefics2 would have something like MistralAttention? It's very strange. (see the test_torch job log)

FAILED tests/models/idefics2/test_modeling_idefics2.py::Idefics2ForConditionalGenerationModelTest::test_retain_grad_hidden_states_attentions - AttributeError: 'MistralAttention' object has no attribute 'scaling'

@OmarManzoor
Contributor Author

> Hi. Let's take a look at the failing test(s) step by step.
>
> First, do you know why Idefics2 would have something like MistralAttention? It's very strange. (see the test_torch job log)
>
> FAILED tests/models/idefics2/test_modeling_idefics2.py::Idefics2ForConditionalGenerationModelTest::test_retain_grad_hidden_states_attentions - AttributeError: 'MistralAttention' object has no attribute 'scaling'

Idefics2 uses Mistral as the text model:

    self.text_model = AutoModel.from_config(config.text_config)

and its configuration defaults to Mistral when no text config is given:

    if isinstance(text_config, dict):
        text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "mistral"
        text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
    elif text_config is None:
        logger.info("text_config is None, using default text config")
        text_config = CONFIG_MAPPING["mistral"](
            max_position_embeddings=4096 * 8,
            rms_norm_eps=1e-5,
            # None in the original configuration_mistral, we set it to the unk_token_id
            pad_token_id=0,
            tie_word_embeddings=False,
        )
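
The missing 'scaling' attribute is presumably the softmax scale that the refactored attention functions expect to find on the module. In the refactored __init__ it would be set roughly like this (a sketch consistent with the function-based design above, not this PR's exact diff):

    import torch.nn as nn

    class MistralAttention(nn.Module):
        def __init__(self, config, layer_idx):
            super().__init__()
            self.layer_idx = layer_idx
            self.head_dim = config.hidden_size // config.num_attention_heads
            # The attribute the Idefics2 test cannot find: the 1/sqrt(head_dim)
            # softmax scale, stored once so every attention function can reuse it.
            self.scaling = self.head_dim**-0.5

Since Idefics2 instantiates its text model through AutoModel, its tests exercise MistralAttention directly, which is why the refactor surfaces there.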

@ArthurZucker
Collaborator

BTW let's make sure we rebase now that #34896 was merged!

@ArthurZucker
Collaborator

can you make sure the CIs are green? 🤗

@OmarManzoor
Contributor Author

> can you make sure the CIs are green? 🤗

Should I reset the default back to eager instead of flex, because the eager-matches-sdpa test fails for float32 when using flex? Or do we need to change the thresholds so that flex remains the default while the tests are also green?
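
For reference, the equivalence tests compare the two implementations' outputs under explicit tolerances, so "changing the thresholds" would look roughly like this (illustrative values and helper name, not the ones in transformers' test suite):

    import torch

    def assert_attention_close(out_eager: torch.Tensor, out_flex: torch.Tensor) -> None:
        # atol/rtol here are placeholders; float32 flex attention may need
        # looser bounds than the defaults used for eager-vs-sdpa.
        torch.testing.assert_close(out_flex, out_eager, atol=3e-3, rtol=1e-3)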

Collaborator

@ArthurZucker left a comment

Hey! We actually shipped this in #35235! 🤗 Super sorry for the late notice.

@OmarManzoor
Contributor Author

> Hey! We actually shipped this in #35235! 🤗 Super sorry for the late notice.

Thanks for informing me.

@OmarManzoor deleted the mistral_flex branch December 23, 2024 15:09