Add Zamba2 #34517

Draft · wants to merge 73 commits into main

Conversation

@pglorio (Contributor) commented Oct 30, 2024

What does this PR do?

This PR adds support for the Zamba2 architecture created by Zyphra Technologies.

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker

@pglorio marked this pull request as draft October 30, 2024 17:57
@pglorio (Contributor, Author) commented Nov 11, 2024

Hey @Arthur,

Thank you again for your help in getting Zamba2 into transformers! The PR is now finally ready to be reviewed. I added the documentation and all unit tests pass, including slow tests.

A few remarks, mostly related to modular transformers:

  1. To generate the modeling and configuration files I used utils/modular_model_converter.py from a previous commit, because the most recent version of this script, which followed from a large refactoring, produces an error that I was not able to fix:
Converting src/transformers/models/zamba2/modular_zamba2.py to a single model single file format
Traceback (most recent call last):
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1510, in <module>
    converted_files = convert_modular_file(file_name, args.old_model_name, args.new_model_name)
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1447, in convert_modular_file
    for file, module in create_modules(cst_transformers).items():
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1387, in create_modules
    nodes_to_add, file_type, new_imports = get_class_node_and_dependencies(modular_mapper, class_name, node, files)
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1337, in get_class_node_and_dependencies
    new_node_dependencies, new_imports = check_dependencies_and_create_import_node(
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1283, in check_dependencies_and_create_import_node
    class_dependencies = {dep for dep in new_dependencies if m.matches(mapper.global_nodes[dep], m.ClassDef())}
  File "/workspace/transformers_zamba/utils/modular_model_converter.py", line 1283, in <setcomp>
    class_dependencies = {dep for dep in new_dependencies if m.matches(mapper.global_nodes[dep], m.ClassDef())}
KeyError: 'Zamba2Config'

I carefully compared Zamba2Config with the config classes of other models that also use modular (such as Gemma2Config), and they appear to have a consistent format. Relatedly, the utils/modular_model_converter.py in the current PR (path) is the version from the previous commit mentioned above.

  2. After running utils/modular_model_converter.py, the generated modeling and configuration files contained unintended code that I had to update. All these modifications are in this commit. In particular, the produced modeling file contains Zamba2DynamicCache, which is the correct cache for Zamba2, as well as HybridMambaAttentionDynamicCache, which is the cache of Zamba and is not relevant to Zamba2, so I deleted HybridMambaAttentionDynamicCache and its related references.

  3. I ran make fixup and all zamba-related tests pass, with the exception of python utils/check_modular_conversion.py. That check fails because of the modifications mentioned in the previous point.

  4. I slightly edited the Zamba2MambaMixer compared to the original Mamba2Mixer of mamba2; the main difference is that I added these lines, which was necessary to appropriately process the mamba2 cache (note that this step already existed in the torch forward in these lines). See the sketch right after this list for the general shape of that step.
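
For context, the cache-processing step mentioned in point 4 is the usual Mamba-style update of the cached convolution state during single-token decoding. The snippet below is a minimal illustrative sketch, not the code added in this PR; the function name update_conv_state and the (batch, channels, kernel_size) buffer layout are assumptions modeled on the mamba2 implementation.

import torch

def update_conv_state(conv_states, layer_idx, hidden_states, conv_weight, conv_bias=None):
    # Illustrative sketch: shift the cached convolution window left by one slot,
    # append the newest token's features, and compute the depthwise causal-conv
    # output directly from the cache (single-token decoding path).
    conv_state = conv_states[layer_idx]                        # (batch, channels, kernel_size)
    conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
    conv_state[:, :, -1] = hidden_states[:, 0, :]              # newest token, shape (batch, channels)
    conv_states[layer_idx] = conv_state
    out = torch.sum(conv_state * conv_weight[:, 0, :], dim=-1)  # depthwise conv as a weighted sum
    if conv_bias is not None:
        out = out + conv_bias
    return out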

Looking forward to your feedback. Thanks so much!

@pglorio (Contributor, Author) commented Dec 20, 2024

Thank you so much for this feedback @Cyrilvallez. I realized that the issue with the unit tests was inside the torch_forward method of the mamba2 mixer (when I ran them locally, the unit tests used the cuda_kernels method instead). I fixed that method here: 1 2 3.
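
(As background on why this only surfaced in CI: the mamba-style mixers choose between the fused CUDA kernels and the pure-PyTorch path at runtime. The sketch below is a simplified, self-contained illustration of that dispatch idea, assuming the same pattern as the upstream mamba2 mixer; pick_forward_path is a hypothetical helper, not transformers code.)

import torch

def pick_forward_path(module: torch.nn.Module, fast_kernels_available: bool) -> str:
    # The fused CUDA kernels are used only when the optional kernel packages are
    # installed and the weights live on a CUDA device; otherwise the pure-PyTorch
    # torch_forward path runs, which is why CPU-only CI exercised torch_forward
    # while a local GPU run exercised cuda_kernels_forward.
    on_cuda = any(p.device.type == "cuda" for p in module.parameters())
    return "cuda_kernels_forward" if fast_kernels_available and on_cuda else "torch_forward"

layer = torch.nn.Linear(4, 4)                                   # CPU module, e.g. what CI runs on
print(pick_forward_path(layer, fast_kernels_available=False))   # -> torch_forward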

By the way, we originally took the torch_forward from the mamba2 model, so the same issues hold there. In particular, running this:

import torch
from transformers import Mamba2Config, Mamba2ForCausalLM

config = Mamba2Config(
    num_heads=8,
    n_groups=8,
    state_size=2,
    head_dim=8,
    conv_kernel=4,
    chunk_size=8,
    vocab_size=99,
    hidden_size=32,
    num_hidden_layers=4,
    hidden_act="silu",
    hidden_dropout_prob=0.1,
    max_position_embeddings=512,
)
model = Mamba2ForCausalLM(config)

inputs = {
    "input_ids": torch.tensor(
        [[86, 6, 51, 3, 12, 15, 33, 18, 4, 92],
         [69, 66, 49, 45, 48, 44, 61, 56, 68, 85]]
    ),
    "attention_mask": torch.tensor(
        [[0, 0, 1, 1, 1, 0, 0, 0, 1, 0],
         [0, 1, 1, 0, 1, 1, 1, 0, 1, 1]]
    ),
}

outputs_cpu = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=False, output_scores=False, use_cache=True, num_beams=1, do_sample=False)
model = model.to('cuda')
inputs = {key: tensor.to(device=model.device) for key, tensor in inputs.items()}
outputs_cuda = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=False, output_scores=False, use_cache=True, num_beams=1, do_sample=False)
print(torch.all(outputs_cpu == outputs_cuda.cpu()).item())

returns False.

@pglorio mentioned this pull request Jan 7, 2025
@pglorio (Contributor, Author) commented Jan 14, 2025

Hi @Cyrilvallez and @ArthurZucker,

I updated the attention forward to the new standard of transformers here and here.

I ran all final tests, including @slow tests, and everything appears to pass!

@Cyrilvallez (Member) left a comment

Nice work for the refactor! Almost ready, left some final comments but overall quite nice! 🤗

src/transformers/models/zamba2/modular_zamba2.py: review comments on several hunks (outdated, resolved)
Comment on lines 1415 to 1417
"ZambaModelTester",
"Zamba2ModelTester",
"RwkvModelTester",
@Cyrilvallez (Member):

cc @ydshieh here to ensure this change is necessary, as I'm not familiar with this new part!

@pglorio (Contributor, Author) replied:

@ydshieh for context: when running this test, the model config is forced to have num_hidden_layers=1, but the other config parameters are not updated accordingly, so model initialization errors out because these parameters are no longer consistent. I imagine that is probably also why Zamba was added to this list. (A hypothetical illustration of the mismatch follows below.)
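
(To make the mismatch concrete, here is a hypothetical, self-contained illustration; TinyHybridConfig and layers_block_type are stand-ins, not the actual test or config code.)

class TinyHybridConfig:
    # Stand-in for a hybrid-model config whose per-layer settings are sized to the depth.
    def __init__(self, num_hidden_layers=6):
        self.num_hidden_layers = num_hidden_layers
        self.layers_block_type = (["mamba", "mamba", "hybrid"] * num_hidden_layers)[:num_hidden_layers]

config = TinyHybridConfig(num_hidden_layers=6)
config.num_hidden_layers = 1  # what the shared test forces, without touching the per-layer settings
inconsistent = len(config.layers_block_type) != config.num_hidden_layers
print(f"layers_block_type entries: {len(config.layers_block_type)}, "
      f"num_hidden_layers: {config.num_hidden_layers}, inconsistent: {inconsistent}")
# Building the model then iterates the per-layer settings and errors out on this mismatch.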

tests/models/zamba2/test_modeling_zamba2.py: additional review comments (resolved)
@pglorio (Contributor, Author) commented Jan 16, 2025

Thank you @Cyrilvallez for the review. I addressed the comments above, although there are a couple of pending points.

All zamba-related tests appear to pass.

@pglorio (Contributor, Author) commented Jan 17, 2025

Hello @Cyrilvallez, I ran all model tests on two GPUs, and after a couple of minor fixes everything appears to work now. I'm skipping this test as it gives an error related to mamba2 kernels; I verified that mamba2 indeed skips that test here.

Separately, when running utils/check_modular_conversion.py I get the following error:

Differences found between the generated code and src/transformers/models/zamba2/modeling_zamba2.py:

--- src/transformers/models/zamba2/modeling_zamba2.py_generated
+++ src/transformers/models/zamba2/modeling_zamba2.py
@@ -313,6 +313,13 @@
     return attn_output, attn_weights


+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

@@ -338,13 +345,6 @@
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)

which I was not getting before, even though this part is identical.
