
[Kernels] Sm120 attention correctness#6209

Open
rolson24 wants to merge 8 commits into modular:main from rolson24:sm120-attention-correctness

Conversation


@rolson24 rolson24 commented Mar 19, 2026

Summary

Fix ragged and decode attention behavior on SM120-class GPUs, including sink-aware softmax handling and fused RoPE correctness in the live attention path.

This keeps the changes focused on kernel correctness and the associated GPU test gating.

This fixes issue #6198
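For context, GPT-OSS-style attention sinks add a learned per-head logit to the softmax denominator without contributing a value row, so each row's probabilities sum to less than one. A minimal Python sketch of that reference semantics (the function name and values here are illustrative, not the kernel's actual API):

```python
import math

def sink_softmax(scores, sink):
    # The sink logit joins the max and the denominator, but there is no
    # corresponding value row, so mass "leaks" to the sink.
    m = max(max(scores), sink)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps) + math.exp(sink - m)
    return [e / denom for e in exps]

# Probabilities sum to less than 1; the remainder is absorbed by the sink.
probs = sink_softmax([1.0, 2.0, 3.0], sink=0.5)
```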

Testing

Ran:
./bazelw --batch test //max/kernels/test/gpu/nn:test_flash_attention.mojo.test --test_output=errors

Checklist

  • PR is small and focused — consider splitting larger changes into a
    sequence of smaller PRs

  • I ran ./bazelw run format to format my changes
    Note: flash_attention[sink=True] receives a non-owning view of the sink weights (sinks_device.get_immutable()), while the actual owner is sinks_dev. Because the launch is async, the test must keep sinks_dev alive until the GPU work and the output readback are complete; otherwise the sink buffer can be reclaimed too early and the test becomes flaky/incorrect. The final ctx.synchronize() is still needed separately to ensure the D2H copy into out_ptr has completed before reading out_host. We could change the API of the flash attention kernel, but I didn't want to do that yet.

  • I added or updated tests to cover my changes

  • If AI tools assisted with this contribution, I have included an
    Assisted-by: trailer in my commit message or this PR description
    (see AI Tool Use Policy)
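The buffer-lifetime hazard described in the note above is not Mojo-specific. A hedged Python analogy (the `DeviceBuffer` class, `async_kernel` helper, and timing are invented for illustration) shows why the owning buffer must outlive the async read:

```python
import threading
import time

class DeviceBuffer:
    # Hypothetical stand-in for a device allocation that is freed on release().
    def __init__(self, data):
        self.data = list(data)
        self.freed = False

    def release(self):
        self.freed = True
        self.data = None

def async_kernel(view, out, done):
    # Simulates an async launch that reads the buffer some time later.
    def work():
        time.sleep(0.01)
        out.extend(view.data)  # would fail if the owner were freed early
        done.set()
    threading.Thread(target=work).start()

owner = DeviceBuffer([1.0, 2.0, 3.0])
out, done = [], threading.Event()
async_kernel(owner, out, done)
done.wait()      # analogous to ctx.synchronize(): finish before reading
owner.release()  # only now is it safe to reclaim the sink buffer
```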

Assisted-by: Codex

BEGIN_PUBLIC
[Kernel][GPU] Fix SM120 GPT-OSS attention correctness

Fix ragged and decode attention behavior on SM120-class GPUs, including
sink-aware softmax handling and fused RoPE correctness in the live attention
path.

This keeps the changes focused on kernel correctness and the associated GPU
test gating.
END_PUBLIC

Signed-off-by: Raif Olson <[email protected]>
BEGIN_PUBLIC
[Kernel][GPU] Fix sink attention test buffer lifetime

Keep the sink-weight device buffer alive across the async flash attention
launch and readback in the sink test, and allow the sink path to execute on
SM120 so the test covers that configuration.
END_PUBLIC

Signed-off-by: Raif Olson <[email protected]>
BEGIN_PUBLIC
[Kernel][GPU] Tighten SM120 sink test changes

Restore the SM100-only assertion on the specialized prefill path and keep
only the minimal sink-buffer lifetime fix in the sink attention test.
END_PUBLIC

Signed-off-by: Raif Olson <[email protected]>
BEGIN_PUBLIC
[Kernel][GPU] Split RoPE changes from attention branch

Restore the RoPE kernels to main on the SM120 attention branch so the
flash attention correctness work can be reviewed independently.
END_PUBLIC

Signed-off-by: Raif Olson <[email protected]>
@rolson24 rolson24 requested a review from a team as a code owner March 19, 2026 14:08
Copilot AI review requested due to automatic review settings March 19, 2026 14:08
Contributor

Copilot AI left a comment


Pull request overview

Fixes flash-attention correctness issues on SM120-class NVIDIA GPUs, focusing on sink-aware softmax normalization, ragged/decode behavior, and MMA online-softmax state handling, plus a test stabilization tweak.

Changes:

  • Correct sink-weight indexing and sink normalization in the GPU softmax kernel.
  • Fix online-softmax split-warp reduction by initializing and writing back rowmax/rowsum state.
  • Add NVIDIA-device fallbacks to the naive MHA path for sink/ragged cases and tighten test synchronization/lifetime behavior.
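The rowmax/rowsum fix concerns the standard online-softmax recurrence: both accumulators must be initialized before the first chunk and rescaled and written back on every step. A minimal Python sketch of that recurrence (names are illustrative, not the kernel's):

```python
import math

def online_logsumexp(chunks):
    # rowmax/rowsum must start at (-inf, 0) and be written back after
    # each chunk; skipping either step corrupts the reduction, which is
    # the class of bug this PR fixes in the split-warp path.
    rowmax, rowsum = float("-inf"), 0.0
    for chunk in chunks:
        new_max = max(rowmax, max(chunk))
        # Rescale the running sum whenever a larger maximum appears.
        rowsum = rowsum * math.exp(rowmax - new_max) + sum(
            math.exp(s - new_max) for s in chunk
        )
        rowmax = new_max
    return rowmax + math.log(rowsum)

# The streamed result matches the single-pass result for any chunking.
flat = [0.1, 2.0, -1.0, 3.5, 0.7]
streamed = online_logsumexp([[0.1, 2.0], [-1.0], [3.5, 0.7]])
direct = online_logsumexp([flat])
```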

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

File Description
max/kernels/test/gpu/nn/test_flash_attention.mojo Adds extra synchronization and an explicit lifetime “use” to prevent sink buffer early reclamation during async execution.
max/kernels/src/nn/softmax.mojo Fixes sink-weight addressing (per head rather than flattened row) and corrects sink contribution to the softmax denominator; initializes online-softmax reduction state.
max/kernels/src/nn/mha.mojo Introduces SM120-related correctness fallbacks to mha_gpu_naive, threads sink support through ragged entry points, and updates sink-weight loads / softmax call typing.
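Per the softmax.mojo description, the addressing fix maps each flattened softmax row back to its head before indexing the sink weights. A hypothetical Python sketch of the distinction (the head-major `[head, query]` row layout and all names here are assumptions, not the file's actual code):

```python
num_heads, seq_len = 4, 8
sink_weights = [0.1 * h for h in range(num_heads)]  # one logit per head

def sink_for_row(row):
    # The sink logit is per *head*, so recover the head index from the
    # flattened row instead of indexing sink_weights with `row` directly,
    # which would read out of bounds past the first num_heads rows.
    head = (row // seq_len) % num_heads
    return sink_weights[head]
```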


Comment on lines 549 to +580
    if not is_token_generation:
        # Correctness fallback for sink-aware prefill on affected NVIDIA
        # paths. The shared kernel is currently diverging from the
        # reference sink normalization semantics on SM120-class devices.
        comptime if (
            has_nvidia_gpu_accelerator()
            and sink
            and not is_sm90
            and not is_sm100
        ):
            mha_gpu_naive[
                ragged=ragged,
                sink=sink,
                _use_valid_length=_use_valid_length,
                _is_cache_length_accurate=_is_cache_length_accurate,
            ](
                q,
                k,
                v,
                mask_functor,
                output,
                valid_length.value(),
                scale,
                batch_size,
                max_prompt_len,
                max_cache_valid_length,
                Int(num_heads),
                Int(depth),
                Int(group),
                ctx,
                sink_weights,
            )
Comment on lines +787 to +818
    # Correctness fallback for ragged decoding on affected NVIDIA
    # paths. The live ragged/paged kernel output diverges from both the
    # explicit score reconstruction and the CPU reference on SM120.
    comptime if (
        has_nvidia_gpu_accelerator()
        and (ragged or sink)
        and not is_sm90
        and not is_sm100
    ):
        mha_gpu_naive[
            ragged=ragged,
            sink=sink,
            _use_valid_length=_use_valid_length,
            _is_cache_length_accurate=_is_cache_length_accurate,
        ](
            q,
            k,
            v,
            mask_functor,
            output,
            valid_length.value(),
            scale,
            batch_size,
            max_prompt_len,
            max_cache_valid_length_value,
            Int(num_heads),
            Int(depth),
            Int(group),
            ctx,
            sink_weights,
        )
        return
BEGIN_PUBLIC
[Kernel][GPU] Restrict attention fallback to SM120

Narrow the sink and ragged attention correctness fallbacks to SM120-class
GPUs instead of applying them to all non-SM90 and non-SM100 NVIDIA devices.
END_PUBLIC

Signed-off-by: Raif Olson <[email protected]>
@rolson24 rolson24 changed the title Sm120 attention correctness [Kernels] Sm120 attention correctness Mar 19, 2026
@rolson24
Author

This is on the critical path to getting GPT-OSS working correctly on SM120. See #6216
