Fix matmul_kernel #5701
base: main
…ar into matmul-kernel-fix
I have read the CLA Document and I hereby sign the CLA. You can retrigger this bot by commenting recheck in this Pull Request. Posted by the CLA Assistant Lite bot.
Pull request overview
This PR aims to fix the matmul_kernel to correctly handle matrix dimensions that are not divisible by the tile size. The changes update the tiling logic to properly handle partial tiles when K is not a multiple of tile_size.
Key changes:
- Modified kernel logic to compute current_tile_size for partial tiles instead of using a pre-computed remainder
- Parameterized test functions to accept (M, N, K) dimensions and test multiple size combinations
- Added a new test file test_matmul_kernel.mojo to validate the kernel implementation
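To make the partial-tile idea concrete, here is a minimal CPU-side sketch in Python (not the Mojo kernel itself) of how a per-tile size can be derived when K is not a multiple of the tile size. The helper name `tile_extents` and the tile size of 32 are assumptions chosen for illustration; only K = 511 comes from the test dimensions listed in this PR.

```python
def tile_extents(k, tile_size):
    """Return (offset, current_tile_size) pairs that cover range(k)."""
    extents = []
    for offset in range(0, k, tile_size):
        # The last tile is shorter whenever k is not a multiple of tile_size.
        current_tile_size = min(tile_size, k - offset)
        extents.append((offset, current_tile_size))
    return extents


# K = 511 is taken from the PR's test sizes; tile_size = 32 is an assumption.
extents = tile_extents(511, 32)
print(len(extents), extents[0], extents[-1])
```

Every tile except the last is full, so only the final iteration needs the bounds handling discussed in the review comments.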
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| max/kernels/test/gpu/linalg/test_matmul_kernel.mojo | New test file that validates matmul_kernel with non-divisible matrix dimensions (513, 502, 511) |
| max/kernels/test/gpu/examples/test_matmul_1_sram.mojo | Updated to parameterize test with compile-time M, N, K dimensions and fixed partial tile handling |
| max/kernels/src/linalg/matmul/gpu/__init__.mojo | Fixed matmul_kernel to properly handle partial tiles by computing tile size dynamically |
| ) | ||
| else: | ||
| a_val = ( | ||
| a_shared[localRow * UInt(current_tile_size) + localCol] = ( |
Copilot
AI
Dec 22, 2025
The shared memory indexing for matrix A is incorrect. The shared memory is allocated with fixed dimensions of tile_size * tile_size, so the row stride should always be tile_size, not current_tile_size. Using current_tile_size as the stride will cause incorrect memory layout when handling partial tiles. The indexing should be a_shared[localRow * UInt(tile_size) + localCol].
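One way to see why the row stride matters is to enumerate the flat indices that a tile_size x tile_size thread block would touch. The Python sketch below is a CPU-side model, not the kernel; the helper names and the sizes (4 and 3) are assumptions chosen for illustration. It suggests that a stride of current_tile_size stays collision-free only if every write is guarded by localCol < current_tile_size, whereas a fixed stride of tile_size is collision-free for every thread in the block.

```python
def flat_indices(n_rows, stride, n_cols):
    """Flat shared-memory indices touched by an n_rows x n_cols set of threads."""
    return [row * stride + col for row in range(n_rows) for col in range(n_cols)]


def has_collision(indices):
    """True if two (row, col) pairs map to the same flat index."""
    return len(set(indices)) != len(indices)


TILE_SIZE = 4          # hypothetical tile size
CURRENT_TILE_SIZE = 3  # hypothetical partial-tile width

# Fixed stride: every thread in the block owns a distinct slot.
fixed = flat_indices(TILE_SIZE, TILE_SIZE, TILE_SIZE)

# Shrunk stride, writes limited to columns below current_tile_size.
shrunk_guarded = flat_indices(TILE_SIZE, CURRENT_TILE_SIZE, CURRENT_TILE_SIZE)

# Shrunk stride, but every column writes (for example an unguarded zero-fill).
shrunk_unguarded = flat_indices(TILE_SIZE, CURRENT_TILE_SIZE, TILE_SIZE)

print(has_collision(fixed), has_collision(shrunk_guarded), has_collision(shrunk_unguarded))
```

In this model the unguarded case aliases because row 0, column 3 and row 1, column 0 both land on flat index 3.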
| a[Int(row), offset + Int(localCol)] if ( | ||
| row < UInt(M) | ||
| ) else 0.0 | ||
| ) |
Copilot
AI
Dec 22, 2025
When handling partial tiles, if the condition offset + Int(localCol) < K is false, no value is written to the corresponding location in a_shared. However, this location may still be read during the computation loop (line 120), potentially using uninitialized or stale data. To ensure correctness, a zero value should be written when the condition is false, similar to how the full_tile case handles out-of-bounds accesses.
| ) | |
| ) | |
| else: | |
| # For out-of-bounds columns in a partial tile, write zero to avoid | |
| # reading uninitialized or stale data during the compute loop. | |
| a_shared[localRow * UInt(current_tile_size) + localCol] = 0.0 |
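The effect of the zero-fill can be modeled on the CPU with a single row-by-column dot product. This Python sketch is an illustration under assumptions, not the kernel: the function `tiled_dot` is hypothetical, and its compute loop deliberately spans the full tile (the diff in this PR loops over current_tile_size instead) to show what the zero-fill protects against if padded slots are ever read.

```python
def tiled_dot(a_row, b_col, tile_size, zero_fill):
    """Dot product computed tile by tile through reused staging buffers."""
    k = len(a_row)
    # NaN stands in for uninitialized shared memory.
    a_shared = [float("nan")] * tile_size
    b_shared = [float("nan")] * tile_size
    result = 0.0
    for offset in range(0, k, tile_size):
        for local_col in range(tile_size):
            if offset + local_col < k:
                a_shared[local_col] = a_row[offset + local_col]
                b_shared[local_col] = b_col[offset + local_col]
            elif zero_fill:
                # Out-of-bounds slots of a partial tile are cleared.
                a_shared[local_col] = 0.0
                b_shared[local_col] = 0.0
        # Assumption: the compute loop reads the whole tile, padded slots included.
        for kk in range(tile_size):
            result += a_shared[kk] * b_shared[kk]
    return result


ones = [1.0] * 5
print(tiled_dot(ones, ones, 4, zero_fill=True))   # 5.0
print(tiled_dot(ones, ones, 4, zero_fill=False))  # 8.0, stale values from the first tile leak in
```

With the zero-fill in place, the result no longer depends on whether the compute loop stops at current_tile_size or runs over the full tile.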
| for kk in range(current_tile_size): | ||
| result += ( | ||
| a_shared[localRow * UInt(tile_size) + UInt(kk)].cast[s_type]() | ||
| a_shared[localRow * UInt(current_tile_size) + UInt(kk)].cast[s_type]() |
Copilot
AI
Dec 22, 2025
The shared memory indexing for matrix A is incorrect when reading during computation. The shared memory is allocated with fixed dimensions of tile_size * tile_size, so the row stride should always be tile_size, not current_tile_size. The indexing should be a_shared[localRow * UInt(tile_size) + UInt(kk)].
| for k in range(current_tile_size): | ||
| result += a_shared.load( | ||
| localRow * UInt(tile_size) + UInt(k) | ||
| localRow * UInt(current_tile_size) + UInt(k) |
Copilot
AI
Dec 22, 2025
The shared memory indexing for matrix A is incorrect when reading during computation. The shared memory is allocated with fixed dimensions of tile_size * tile_size, so the row stride should always be tile_size, not current_tile_size. The indexing should be a_shared.load(localRow * UInt(tile_size) + UInt(k)).
| for i in range(M - 10, M): | ||
| for j in range(N - 10, N): | ||
| if c_host[i, j] != Float32(K): | ||
| print( | ||
| "Fail at index = [", | ||
| i, | ||
| ",", | ||
| j, | ||
| "] the value is", | ||
| c_host[i, j], | ||
| "the golden value is", | ||
| K, | ||
| ) | ||
| failed = True |
Copilot
AI
Dec 22, 2025
The test only validates the last 10x10 block of the result matrix (lines 82-95), which provides inadequate test coverage. Bugs in the kernel could go undetected if they only affect elements outside this region. The test should validate all elements of the result matrix to ensure the matmul kernel works correctly across the entire output.
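A whole-matrix check along the lines suggested here could look like the following Python sketch. It is a host-side model, not the Mojo test: the helper `find_mismatches`, the report limit, and the small sizes are assumptions for illustration. The expected value K per element follows the comparison against Float32(K) in the diff.

```python
def find_mismatches(c, expected, limit=10):
    """Scan every element of c; return up to `limit` (i, j, value) mismatches."""
    mismatches = []
    for i, row in enumerate(c):
        for j, value in enumerate(row):
            if value != expected:
                mismatches.append((i, j, value))
                if len(mismatches) == limit:
                    return mismatches
    return mismatches


# Hypothetical small sizes; the PR's own test uses much larger dimensions.
M, N, K = 5, 4, 7
c_host = [[float(K)] * N for _ in range(M)]
c_host[0][1] = 6.0  # an error far from the bottom-right corner

print(find_mismatches(c_host, float(K)))
```

Capping the number of reported mismatches keeps the output readable when a bug affects most of the matrix, while still scanning elements that a corner-only check would never visit.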
| a[Int(row), Int(offset + Int(localCol))] if ( | ||
| row < UInt(m) | ||
| ) else 0.0 | ||
| ) |
Copilot
AI
Dec 22, 2025
When handling partial tiles, if the condition offset + Int(localCol) < k is false, no value is written to the corresponding location in a_shared. However, this location may still be read during the computation loop (line 166), potentially using uninitialized or stale data. To ensure correctness, a zero value should be written when the condition is false, similar to how the full_tile case handles out-of-bounds accesses.
| ) | |
| ) | |
| else: | |
| a_shared[localRow * UInt(current_tile_size) + localCol] = 0.0 |
| row < UInt(m) and offset + Int(localCol) < k | ||
| ) else 0.0 | ||
| if offset + Int(localCol) < k: | ||
| a_shared[localRow * UInt(current_tile_size) + localCol] = ( |
Copilot
AI
Dec 22, 2025
The shared memory indexing for matrix A is incorrect. The shared memory is allocated with fixed dimensions of tile_size * tile_size (as seen in the allocation), so the row stride should always be tile_size, not current_tile_size. Using current_tile_size as the stride will cause incorrect memory layout when handling partial tiles. The indexing should be a_shared[localRow * UInt(tile_size) + localCol].
| a_val = ( | ||
| a[Int(row), Int(offset + Int(localCol))] if row | ||
| < UInt(m) else 0.0 | ||
| a_shared[localRow * UInt(current_tile_size) + localCol] = ( |
Copilot
AI
Dec 22, 2025
The shared memory indexing for matrix A is incorrect. The shared memory is allocated with fixed dimensions of tile_size * tile_size, so the row stride should always be tile_size, not current_tile_size. Using current_tile_size as the stride will cause incorrect memory layout when handling partial tiles. The indexing should be a_shared[localRow * UInt(tile_size) + localCol].
| row < UInt(M) and offset + Int(localCol) < K | ||
| ) else 0.0 | ||
| if offset + Int(localCol) < K: | ||
| a_shared[localRow * UInt(current_tile_size) + localCol] = ( |
Copilot
AI
Dec 22, 2025
The shared memory indexing for matrix A is incorrect. The shared memory is allocated with fixed dimensions of tile_size * tile_size (lines 54-58), so the row stride should always be tile_size, not current_tile_size. Using current_tile_size as the stride will cause incorrect memory layout when handling partial tiles. The indexing should be a_shared[localRow * UInt(tile_size) + localCol].
#5700
test_matmul_1_sram.mojo so that (M, N, K)-s of all relative sizes are tested. test_matmul_1_sram.mojo.matmul_kernel similar to that of test_matmul_1_sram.mojo.matmul_kernel.