[Kernels] Sm120 attention correctness#6209
Open
rolson24 wants to merge 8 commits into modular:main from
Conversation
BEGIN_PUBLIC [Kernel][GPU] Fix SM120 GPT-OSS attention correctness Fix ragged and decode attention behavior on SM120-class GPUs, including sink-aware softmax handling and fused RoPE correctness in the live attention path. This keeps the changes focused on kernel correctness and the associated GPU test gating. END_PUBLIC Signed-off-by: Raif Olson <[email protected]>
BEGIN_PUBLIC [Kernel][GPU] Fix sink attention test buffer lifetime Keep the sink-weight device buffer alive across the async flash attention launch and readback in the sink test, and allow the sink path to execute on SM120 so the test covers that configuration. END_PUBLIC Signed-off-by: Raif Olson <[email protected]>
BEGIN_PUBLIC [Kernel][GPU] Tighten SM120 sink test changes Restore the SM100-only assertion on the specialized prefill path and keep only the minimal sink-buffer lifetime fix in the sink attention test. END_PUBLIC Signed-off-by: Raif Olson <[email protected]>
BEGIN_PUBLIC [Kernel][GPU] Split RoPE changes from attention branch Restore the RoPE kernels to main on the SM120 attention branch so the flash attention correctness work can be reviewed independently. END_PUBLIC Signed-off-by: Raif Olson <[email protected]>
Contributor
Pull request overview
Fixes flash-attention correctness issues on SM120-class NVIDIA GPUs, focusing on sink-aware softmax normalization, ragged/decode behavior, and MMA online-softmax state handling, plus a test stabilization tweak.
Changes:
- Correct sink-weight indexing and sink normalization in the GPU softmax kernel.
- Fix online-softmax split-warp reduction by initializing and writing back rowmax/rowsum state.
- Add NVIDIA-device fallbacks to the naive MHA path for sink/ragged cases and tighten test synchronization/lifetime behavior.
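The second bullet refers to online (streaming) softmax state. As a reference point for what "initializing rowmax/rowsum state" means, here is a minimal Python sketch of the online-softmax recurrence that flash attention uses; this is an illustrative analogy, not the Mojo kernel code, and the function name is hypothetical:

```python
import math

def online_softmax_weights(scores):
    # Streaming (online) softmax: maintain a running row-max and row-sum
    # so scores can be consumed chunk by chunk without a second full pass.
    # The state must start at -inf / 0.0 -- stale or uninitialized state
    # is exactly the kind of bug the split-warp reduction fix addresses.
    row_max = float("-inf")
    row_sum = 0.0
    for s in scores:
        new_max = max(row_max, s)
        # Rescale the running sum into the new max's frame, then add
        # the current score's contribution.
        row_sum = row_sum * math.exp(row_max - new_max) + math.exp(s - new_max)
        row_max = new_max
    return [math.exp(s - row_max) / row_sum for s in scores]
```

Each warp-split partial reduction carries its own (rowmax, rowsum) pair, and merging two partials uses the same rescale-then-add step shown in the loop.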
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| max/kernels/test/gpu/nn/test_flash_attention.mojo | Adds extra synchronization and an explicit lifetime "use" to prevent early reclamation of the sink buffer during async execution. |
| max/kernels/src/nn/softmax.mojo | Fixes sink-weight addressing (per head rather than flattened row) and corrects the sink contribution to the softmax denominator; initializes online-softmax reduction state. |
| max/kernels/src/nn/mha.mojo | Introduces SM120-related correctness fallbacks to mha_gpu_naive, threads sink support through the ragged entry points, and updates sink-weight loads and softmax call typing. |
Comment on lines 549 to +580
```mojo
if not is_token_generation:
    # Correctness fallback for sink-aware prefill on affected NVIDIA
    # paths. The shared kernel is currently diverging from the
    # reference sink normalization semantics on SM120-class devices.
    comptime if (
        has_nvidia_gpu_accelerator()
        and sink
        and not is_sm90
        and not is_sm100
    ):
        mha_gpu_naive[
            ragged=ragged,
            sink=sink,
            _use_valid_length=_use_valid_length,
            _is_cache_length_accurate=_is_cache_length_accurate,
        ](
            q,
            k,
            v,
            mask_functor,
            output,
            valid_length.value(),
            scale,
            batch_size,
            max_prompt_len,
            max_cache_valid_length,
            Int(num_heads),
            Int(depth),
            Int(group),
            ctx,
            sink_weights,
        )
```
Comment on lines +787 to +818
```mojo
# Correctness fallback for ragged decoding on affected NVIDIA
# paths. The live ragged/paged kernel output diverges from both the
# explicit score reconstruction and the CPU reference on SM120.
comptime if (
    has_nvidia_gpu_accelerator()
    and (ragged or sink)
    and not is_sm90
    and not is_sm100
):
    mha_gpu_naive[
        ragged=ragged,
        sink=sink,
        _use_valid_length=_use_valid_length,
        _is_cache_length_accurate=_is_cache_length_accurate,
    ](
        q,
        k,
        v,
        mask_functor,
        output,
        valid_length.value(),
        scale,
        batch_size,
        max_prompt_len,
        max_cache_valid_length_value,
        Int(num_heads),
        Int(depth),
        Int(group),
        ctx,
        sink_weights,
    )
    return
```
BEGIN_PUBLIC [Kernel][GPU] Restrict attention fallback to SM120 Narrow the sink and ragged attention correctness fallbacks to SM120-class GPUs instead of applying them to all non-SM90 and non-SM100 NVIDIA devices. END_PUBLIC Signed-off-by: Raif Olson <[email protected]>
Author
This is on the critical path to getting GPT-OSS working correctly on SM120. See #6216
Summary
Fix ragged and decode attention behavior on SM120-class GPUs, including sink-aware softmax handling and fused RoPE correctness in the live attention path.
This keeps the changes focused on kernel correctness and the associated GPU test gating.
This fixes issue #6198
Testing
Ran:
./bazelw --batch test //max/kernels/test/gpu/nn:test_flash_attention.mojo.test --test_output=errors
Checklist
PR is small and focused — consider splitting larger changes into a sequence of smaller PRs
I ran ./bazelw run format to format my changes
Note: flash_attention[sink=True] receives a non-owning view of the sink weights (sinks_device.get_immutable()), while the actual owner is sinks_dev. Because the launch is async, the test must keep sinks_dev alive until the GPU work and the output readback are complete; otherwise the sink buffer can be reclaimed too early and the test becomes flaky/incorrect. The final ctx.synchronize() is still needed separately to ensure the D2H copy into out_ptr has completed before reading out_host. We could change the API of the flash attention kernel, but I didn't want to do that yet.
I added or updated tests to cover my changes
If AI tools assisted with this contribution, I have included an Assisted-by: trailer in my commit message or this PR description (see AI Tool Use Policy)
Assisted-by: Codex