Add Zamba2 #34517
Rebase zamba2
Hey @Arthur, Thank you again for your help in getting Zamba2 into A few remarks, mostly related to
I carefully compared
Looking forward to your feedback. Thanks so much!
rebase on upstream
Thank you so much for this feedback @Cyrilvallez. I realized that the issue with unit tests was inside the By the way, we originally took the

import torch
from transformers import Mamba2Config, Mamba2ForCausalLM

# Small Mamba2 config used to compare CPU and CUDA generation
config = Mamba2Config(
    num_heads=8,
    n_groups=8,
    state_size=2,
    head_dim=8,
    conv_kernel=4,
    chunk_size=8,
    vocab_size=99,
    hidden_size=32,
    num_hidden_layers=4,
    hidden_act="silu",
    hidden_dropout_prob=0.1,
    max_position_embeddings=512,
)
model = Mamba2ForCausalLM(config)
inputs = {
    "input_ids": torch.tensor([[86, 6, 51, 3, 12, 15, 33, 18, 4, 92],
                               [69, 66, 49, 45, 48, 44, 61, 56, 68, 85]]),
    "attention_mask": torch.tensor([[0, 0, 1, 1, 1, 0, 0, 0, 1, 0],
                                    [0, 1, 1, 0, 1, 1, 1, 0, 1, 1]]),
}
# Greedy decoding with identical settings on both devices
gen_kwargs = dict(max_new_tokens=5, return_dict_in_generate=False, output_scores=False,
                  use_cache=True, num_beams=1, do_sample=False)
outputs_cpu = model.generate(**inputs, **gen_kwargs)
model = model.to("cuda")
inputs = {key: tensor.to(device=model.device) for key, tensor in inputs.items()}
outputs_cuda = model.generate(**inputs, **gen_kwargs)
print(torch.all(outputs_cpu == outputs_cuda.cpu()).item())

returns
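If a CPU/CUDA comparison like the one above does not print True, it can help to see where the two generations first diverge. The following is a minimal sketch, not part of the PR: `first_mismatch` is a hypothetical helper, it assumes the outputs were converted to nested lists (for example with `.tolist()`), and the token values are illustrative only.

```python
from typing import List, Optional, Tuple


def first_mismatch(a: List[List[int]], b: List[List[int]]) -> Optional[Tuple[int, int]]:
    """Return (row, position) of the first differing token, or None if identical."""
    for row, (seq_a, seq_b) in enumerate(zip(a, b)):
        for pos, (tok_a, tok_b) in enumerate(zip(seq_a, seq_b)):
            if tok_a != tok_b:
                return row, pos
        # Same prefix but different lengths: report the first missing position
        if len(seq_a) != len(seq_b):
            return row, min(len(seq_a), len(seq_b))
    # Different number of rows: report the first row present in only one batch
    if len(a) != len(b):
        return min(len(a), len(b)), 0
    return None


# Illustrative values only, not real model output
cpu = [[86, 6, 51, 7, 7], [69, 66, 49, 3, 3]]
cuda = [[86, 6, 51, 7, 7], [69, 66, 49, 3, 9]]
print(first_mismatch(cpu, cuda))  # (1, 4)
```

Knowing the row and position of the first differing token shows whether the divergence starts at the first generated token or only after a few cached steps.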
Hi @Cyrilvallez and @ArthurZucker, I updated the attention forward to the new standard of I ran all final tests, including
Nice work on the refactor! Almost ready; I left some final comments, but overall quite nice! 🤗
"ZambaModelTester",
"Zamba2ModelTester",
"RwkvModelTester",
cc @ydshieh here to ensure this change is necessary, as I'm not familiar with this new part!
@ydshieh for context: when running this test, the model config is forced to have num_hidden_layers=1, but the other config parameters are not updated to match. The config is then inconsistent, so initializing the model raises an error. I imagine that is also why Zamba was added to this list.
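The failure mode described above can be sketched with a toy config. Everything here is hypothetical and only illustrates the pattern: `ToyHybridConfig`, its `layer_types` field, and `build_layers` are made-up names and are not claimed to match the actual Zamba2 config or the test harness.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyHybridConfig:
    """Hypothetical config where one field is derived from num_hidden_layers."""
    num_hidden_layers: int = 4
    layer_types: List[str] = field(default_factory=list)

    def __post_init__(self):
        # Derived once at construction time; never refreshed afterwards
        if not self.layer_types:
            self.layer_types = ["mamba"] * self.num_hidden_layers


def build_layers(config: ToyHybridConfig) -> List[str]:
    """Stand-in for model init: requires the derived field to be consistent."""
    if len(config.layer_types) != config.num_hidden_layers:
        raise ValueError(
            f"expected {config.num_hidden_layers} layer types, "
            f"got {len(config.layer_types)}"
        )
    return list(config.layer_types)


config = ToyHybridConfig(num_hidden_layers=4)
build_layers(config)  # consistent config: works

# Overriding only num_hidden_layers leaves layer_types stale
config.num_hidden_layers = 1
try:
    build_layers(config)
except ValueError as err:
    print(err)  # expected 1 layer types, got 4
```

In this toy setting, the options would be either to recompute the derived field whenever `num_hidden_layers` changes or to exclude the model from the test, which is what adding a tester name to the list does.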
Thank you @Cyrilvallez for the review. I addressed the comments above, although there are a couple of pending points. All zamba-related tests appear to pass.
Hello @Cyrilvallez, I ran all model tests on two GPUs, and after a couple of minor fixes everything appears to work now. I'm skipping this test because it gives an error related to the mamba2 kernels. I verified that mamba2 also skips that test here. Separately, when running
which I was not getting before, even though this part was identical.
What does this PR do?
Please include support for the Zamba2 architecture, created by Zyphra Technologies.
Who can review?
@ArthurZucker