
Conversation

@weak-kajuma (Contributor):

What does this PR do?

This PR adds the code for DiffLlama, which is a Llama model with Differential Transformer attention. Please refer to the Differential Transformer paper. @ArthurZucker

@weak-kajuma (Contributor, Author):

I'm working on the code now, but this is my first time contributing to transformers or any other OSS project, so I may ask for some help.

@weak-kajuma (Contributor, Author):

I still have an error in modeling_diffllama.py at line 377, in apply_rotary_pos_emb: query_states should be torch.Size([2, 32, 10, 128]), but it is torch.Size([2, 64, 10, 64]). I need to change either query_states or cos & sin.
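For readers hitting the same shape error: the mismatch comes from projecting the queries into twice as many half-size heads before the rotary step, while cos/sin are still built for the full head_dim of 128. Below is a rough sketch of one way to make the shapes line up (the approach suggested later in review): keep the Llama layout through apply_rotary_pos_emb and split afterwards. The tensor names are illustrative, not this PR's actual code.

import torch

# Rough sketch (not this PR's code): keep the Llama layout
# [batch, num_heads, seq, head_dim] through apply_rotary_pos_emb, so cos/sin
# built for head_dim=128 still match, and split each head into its two
# differential halves only afterwards.
batch, num_heads, seq_len, head_dim = 2, 32, 10, 128

query_states = torch.randn(batch, num_heads, seq_len, head_dim)  # [2, 32, 10, 128]
key_states = torch.randn(batch, num_heads, seq_len, head_dim)

# ... query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) ...

# Only now split for differential attention: two [2, 32, 10, 64] halves
# instead of a single [2, 64, 10, 64] tensor hitting the rotary step.
q1, q2 = query_states.chunk(2, dim=-1)
k1, k2 = key_states.chunk(2, dim=-1)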

@ArthurZucker (Collaborator) left a comment:

Hey! I think this would be an awesome fit for modular transformers!
A bit of doc here: https://huggingface.co/docs/transformers/en/modular_transformers

This would help isolate the changes!

@weak-kajuma (Contributor, Author):

I've finished implementing the normal/eager attention, and I can run generation with AutoModelForCausalLM.generate().
Next I'll adapt it for FlashAttention2 and SDPA attention.
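For reference, here is a rough, self-contained sketch of the eager differential-attention computation from the Differential Transformer paper. The names, shapes, and the inline RMS normalization are illustrative assumptions and may not match modeling_diffllama.py exactly; in the model, lam is built per layer from learned vectors as exp(lambda_q1·lambda_k1) − exp(lambda_q2·lambda_k2) + lambda_init, and the headwise norm carries a learnable scale. Causal masking is omitted for brevity.

import math
import torch


def lambda_init_fn(layer_idx: int) -> float:
    # Depth-dependent starting value for lambda from the paper (0-based layer index here).
    return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)


def diff_attention(q1, q2, k1, k2, v, lam, lambda_init):
    # q1, q2, k1, k2: [batch, heads, seq, head_dim]; v: [batch, heads, seq, v_dim].
    scale = 1.0 / math.sqrt(q1.shape[-1])
    a1 = torch.softmax(q1 @ k1.transpose(-1, -2) * scale, dim=-1)
    a2 = torch.softmax(q2 @ k2.transpose(-1, -2) * scale, dim=-1)
    out = (a1 - lam * a2) @ v  # difference of the two attention maps

    # Headwise RMS normalization ("GroupNorm" in the paper), then the fixed
    # (1 - lambda_init) rescaling; the real module also has a learnable scale.
    out = out * torch.rsqrt(out.pow(2).mean(-1, keepdim=True) + 1e-6)
    return out * (1.0 - lambda_init)


# Tiny smoke test with made-up sizes; lam would normally come from learned vectors.
b, h, s, d = 2, 4, 10, 64
q1, q2, k1, k2 = (torch.randn(b, h, s, d) for _ in range(4))
v = torch.randn(b, h, s, 2 * d)
lambda_init = lambda_init_fn(0)
out = diff_attention(q1, q2, k1, k2, v, torch.tensor(lambda_init), lambda_init)
print(out.shape)  # torch.Size([2, 4, 10, 128])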

@weak-kajuma (Contributor, Author):

I've also adapted it to fit modular transformers.

weak-kajuma and others added 8 commits on October 20, 2024 at 11:52

You don't need to divide by 2 if we use the same number of attention heads as Llama; instead you can just split in forward.

Co-authored-by: Minho Ryu <[email protected]>

fit to changing the "num_heads // 2" placement

Co-authored-by: Minho Ryu <[email protected]>

new code is more meaningful than before

Co-authored-by: Minho Ryu <[email protected]>

new code is more meaningful than before

Co-authored-by: Minho Ryu <[email protected]>

fit to changing the "num_heads // 2" placement

Co-authored-by: Minho Ryu <[email protected]>

fix dividing twice by sqrt(self.head_dim)

Co-authored-by: Minho Ryu <[email protected]>

fix dividing twice by sqrt(self.head_dim)

Co-authored-by: Minho Ryu <[email protected]>

fit to changing the "num_heads // 2" placement, and make it more readable

Co-authored-by: Minho Ryu <[email protected]>
@bzantium (Contributor) left a comment:

Implemented flash and SDPA attention as well.
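As an illustration of how the SDPA path can cover both softmax maps with a single kernel call (the "repeating value_states" trick that shows up later in the commit history), here is a hedged sketch; the function name, the interleaved head layout, and the inline normalization are assumptions, not the exact code in this PR.

import torch
import torch.nn.functional as F


def diff_sdpa(q, k, v, lam, lambda_init, is_causal=True):
    # q, k: [batch, 2 * num_heads, seq, head_dim], with the two halves of head i
    # interleaved at indices 2*i and 2*i + 1; v: [batch, num_heads, seq, v_dim].
    v = v.repeat_interleave(2, dim=1)  # both halves of a head attend over the same V
    out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)

    # Un-interleave into the two attention outputs and take their difference.
    b, two_h, s, dv = out.shape
    o1, o2 = out.reshape(b, two_h // 2, 2, s, dv).unbind(dim=2)
    out = o1 - lam * o2

    # Headwise RMS normalization and (1 - lambda_init) rescaling, as in the eager path.
    out = out * torch.rsqrt(out.pow(2).mean(-1, keepdim=True) + 1e-6)
    return out * (1.0 - lambda_init)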

@weak-kajuma (Contributor, Author):

@bzantium I found that the attention was still mis-implemented relative to the paper as of e072544, so I'll revert to e072544 and re-implement it following your suggested code style.

@Cyrilvallez (Member):

Hey, sorry for the delay!
In order to use modular transformers, you need to create a new file, modular_diffllama.py, in which you can use inheritance from the different Llama classes. Then, to automatically create the modeling_diffllama.py file, just use our CLI: python utils/modular_model_converter.py --files_to_parse src/transformers/models/diffllama/modular_diffllama.py from the root of the transformers repo 🤗
LMK if you need more guidance on this! You can find some modular examples, e.g. here.
Basically, for any class that is similar to a Llama class, you can directly inherit from it to avoid rewriting it; e.g. if DiffLlamaRotaryEmbedding is similar to LlamaRotaryEmbedding, you can use

class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
    pass

in the modular file. In your case, you will probably need to only rewrite the attention classes 😉
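As a concrete illustration of that pattern, a slim modular_diffllama.py could look roughly like the sketch below; this is a hypothetical skeleton, not the file that was eventually merged, and the DiffLlamaAttention body is deliberately left unimplemented.

from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding


class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
    pass


class DiffLlamaAttention(LlamaAttention):
    def forward(self, *args, **kwargs):
        # The differential-attention computation (two softmax maps combined as
        # attn1 - lambda * attn2, headwise norm, (1 - lambda_init) rescaling)
        # is the only part that really needs a new body.
        raise NotImplementedError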

@effortprogrammer commented on Nov 30, 2024:

Are you still working on this PR, @weak-kajuma ?

@weak-kajuma (Contributor, Author):

@Cyrilvallez Could you review again? I made modular_diffllama.py.

@Cyrilvallez (Member) left a comment:

Hey! A great first modular! But you can still cut a lot of code: the only differences here are the attention classes, so it's perfect for modular to pick up everything else by itself!
LMK if you run into any issues.

@Cyrilvallez (Member):

You may need to rebase/merge on main, though, for modular to work perfectly, as you seem to be a bit far behind. If something does not work as expected after my comments, you should try that first 🤗

@weak-kajuma (Contributor, Author):

@Cyrilvallez Could you review again? Modular transformers is very easy to use and works well. Also, all tests pass after merging the latest changes.

@effortprogrammer:

@Cyrilvallez any plans to review this PR?

@Cyrilvallez (Member) left a comment:

Alright, very good! Final comments 🤗

Comment on lines 54 to 66:

class DiffLlamaRMSNorm(LlamaRMSNorm):
    pass


ALL_LAYERNORM_LAYERS.append(DiffLlamaRMSNorm)


class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
    pass


class DiffLlamaMLP(MistralMLP):
    pass
Member:

Should be removed!

Contributor (Author):

If I remove DiffLlamaMLP, then AttributeError: 'DiffLlamaConfig' object has no attribute 'mlp_bias' is raised. So I cannot remove it.
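For context on why that happens: in transformers, LlamaMLP builds its linear layers with bias=config.mlp_bias, while MistralMLP hard-codes bias=False, so inheriting from MistralMLP sidesteps the missing attribute. A simplified illustration with placeholder class names (not the library's exact code):

import torch.nn as nn


class LlamaStyleMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Llama's MLP reads config.mlp_bias -> AttributeError with DiffLlamaConfig.
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)


class MistralStyleMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Mistral's MLP hard-codes bias=False, so no mlp_bias attribute is needed.
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)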

Collaborator:

Good call! 🤗

@ArthurZucker (Collaborator) left a comment:

Very, very nice! The only thing missing is to update based on #35235! If you don't want to, we'll just open a PR afterwards!

self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

self.lambda_init = lambda_init_fn(layer_idx)
Collaborator:

this should go in _init_weights() AFAIK!

Contributor (Author):

Sorry, I'm not familiar with _init_weights(). How does it work?

Member:

It's not really a weight initialization, just the declaration of a parameter, so it should be OK as-is.
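For anyone else unfamiliar with it: _init_weights() is the PreTrainedModel hook that post_init() calls on every submodule to initialize weights, while declaring parameters such as lambda_init stays in __init__. A generic Llama-style sketch with a placeholder class name (not DiffLlama's exact code):

import torch.nn as nn
from transformers import PreTrainedModel


class ExamplePreTrainedModel(PreTrainedModel):
    # Generic weight-init hook; post_init() calls this once per submodule.
    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()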

@weak-kajuma (Contributor, Author):

I know the main change in #35235 is about attention, but I may not be able to adapt differential attention to match #35235. I know you are busy, but I would appreciate it if you could open that follow-up PR.

@Cyrilvallez (Member) left a comment:

Ok, then I think it should be ready to merge, we'll take it from there! Thanks a lot for the contribution! cc @ArthurZucker

@ArthurZucker (Collaborator):

Cool let's merge then! 🤗

@ArthurZucker merged commit 96bf3d6 into huggingface:main on Jan 7, 2025. 23 checks passed.
@ArthurZucker (Collaborator):

Thanks for iterating, and congrats on the model merge 🥳

AlanPonnachan pushed a commit to AlanPonnachan/transformers that referenced this pull request on Jan 7, 2025:
* first adding diffllama

* add Diff Attention and other but still with errors

* complete making attention Diff-Attention

* fix some bugs which may be caused by transformer-cli while adding model

* fix a bug caused by forgetting KV cache...

* Update src/transformers/models/diffllama/modeling_diffllama.py

You don't need to divide by 2 if we use the same number of attention heads as Llama; instead you can just split in forward.

Co-authored-by: Minho Ryu <[email protected]>

* Update src/transformers/models/diffllama/modeling_diffllama.py

fit to changing the "num_heads // 2" placement

Co-authored-by: Minho Ryu <[email protected]>

* Update src/transformers/models/diffllama/modeling_diffllama.py

new code is more meaningful than before

Co-authored-by: Minho Ryu <[email protected]>

* Update src/transformers/models/diffllama/modeling_diffllama.py

new code is more meaningful than before

Co-authored-by: Minho Ryu <[email protected]>

* Update src/transformers/models/diffllama/modeling_diffllama.py

fit to changing the "num_heads // 2" placement

Co-authored-by: Minho Ryu <[email protected]>

* Update src/transformers/models/diffllama/modeling_diffllama.py

fix dividing twice by sqrt(self.head_dim)

Co-authored-by: Minho Ryu <[email protected]>

* Update src/transformers/models/diffllama/modeling_diffllama.py

fix dividing twice by sqrt(self.head_dim)

Co-authored-by: Minho Ryu <[email protected]>

* Update src/transformers/models/diffllama/modeling_diffllama.py

fit to changing the "num_heads // 2" placement, and make it more readable

Co-authored-by: Minho Ryu <[email protected]>

* I found the attention was still mis-implemented relative to the paper as of e072544.

* re-implemented

* adding groupnorm

Co-authored-by: Minho Ryu <[email protected]>

* align with transformers code style

Co-authored-by: Minho Ryu <[email protected]>

* fix typo

Co-authored-by: Minho Ryu <[email protected]>

* adding groupnorm

Co-authored-by: Minho Ryu <[email protected]>

* change SdpaAttention to DiffSdpaAttention

Co-authored-by: Minho Ryu <[email protected]>

* fix bug

* Update src/transformers/models/diffllama/modeling_diffllama.py

resolve "not same outputs" problem

Co-authored-by: Minho Ryu <[email protected]>

* fix bugs of places of "GroupNorm with scale" and etc

* Revert "fix bugs of places of "GroupNorm with scale" and etc"

This reverts commit 26307d9.

* simplify multiple attention (matmul) operations into one by repeating value_states

Co-authored-by: Minho Ryu <[email protected]>

* simplify multiple attention (matmul) operations into one by repeating value_states

Co-authored-by: Minho Ryu <[email protected]>

* simplify multiple attention (matmul) operations into one by repeating value_states

Co-authored-by: Minho Ryu <[email protected]>

* remove missed type

* add diffllama model_doc

* apply make style/quality

* apply review comment about model

* apply review comment about test

* place diffllama alphabetically on the src/transformers/__init__.py

* fix forgot code

* Supports parameters that are not initialized with standard deviation 0 in the conventional method

* add DiffLlamaConfig to CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK on utils/check_config_docstrings.py

* remove unused property of config

* add to supported model list

* add to sdpa supported model list

* fix copyright, remove pretraining_tensor_parallel, and modify for initialization test

* remove unused import and etc.

* empty commit

* empty commit

* empty commit

* apply modular transformers but with bugs

* revert prev commit

* create src/transformers/model/diffllama/modular_diffllama.py

* run utils/modular_model_converter.py

* empty commit

* leaner modular diffllama

* remove more and more in modular_diffllama.py

* remove more and more in modular_diffllama.py

* resolve missing docstring entries

* force reset

* convert modular

---------

Co-authored-by: Minho Ryu <[email protected]>