Add fused cosine 1-NN cutlass based kernel #2125
/ok to test
cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh
cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh
cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh
6b548d3 to d796a47
fix doc issue in fused_distance_nn runtime API
20cd99e to 1417a2e
Thanks @mdoijade for this PR, it is great to have the fused cosine distance kernels to further accelerate k-means clustering!
It would be nice if we could further improve code reuse between the L2 and Cosine variants of fusedNN (to reduce duplicated boilerplate code), but I see that the two variants differ in their required parameters, so merging them could hurt readability.
Otherwise the PR looks good in general. Here are my comments:
cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h
cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h
cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh
4d0ad4f to 5384408
…ed_l2_nn_min_arg, support only float for fused_distance_nn
…ll supported distance metrics
Thanks @mdoijade for addressing the issues! The code looks good to me.
Please still have a look at the copyright start year in some of the new files.
The pylibraft test passes, but an unrelated cpp test is failing.
/rerun tests
/merge
looks good! one minor comment:
Co-authored-by: Ben Frederickson <[email protected]>
fused_distance_nn_arg_min supporting cosine & L2 distance metrics.
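
For readers unfamiliar with what a fused 1-NN arg-min computes, here is a minimal CPU sketch in plain C++. It is not the RAFT API and not the CUTLASS kernel from this PR: the names `Metric`, `dot`, and `fused_nn_reference` are hypothetical, and the choice of the expanded squared-L2 form and of `1 - cos` as the cosine distance is an assumption made for illustration. The point is only that the distance computation and the per-row minimum are done in one pass, so the full m x n distance matrix is never stored.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Hypothetical names for illustration only; this is not the RAFT API.
enum class Metric { L2Expanded, Cosine };

// Dot product of two length-k vectors.
inline float dot(const float* a, const float* b, std::size_t k) {
  float s = 0.0f;
  for (std::size_t i = 0; i < k; ++i) s += a[i] * b[i];
  return s;
}

// For each of the m rows of x (row-major, m x k), return the index and distance
// of the closest of the n rows of y (row-major, n x k).
std::vector<std::pair<int, float>> fused_nn_reference(
    const std::vector<float>& x, const std::vector<float>& y,
    std::size_t m, std::size_t n, std::size_t k, Metric metric) {
  assert(x.size() == m * k && y.size() == n * k);

  // Squared row norms, computed once and reused for every (i, j) pair.
  std::vector<float> xn(m), yn(n);
  for (std::size_t i = 0; i < m; ++i) xn[i] = dot(&x[i * k], &x[i * k], k);
  for (std::size_t j = 0; j < n; ++j) yn[j] = dot(&y[j * k], &y[j * k], k);

  std::vector<std::pair<int, float>> out(
      m, std::make_pair(-1, std::numeric_limits<float>::max()));

  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      const float ip = dot(&x[i * k], &y[j * k], k);
      // Assumed definitions: squared L2 in expanded form, cosine distance as 1 - cos.
      const float d = (metric == Metric::L2Expanded)
                          ? xn[i] + yn[j] - 2.0f * ip
                          : 1.0f - ip / (std::sqrt(xn[i]) * std::sqrt(yn[j]));
      // Fused reduction: keep only the running minimum for row i.
      if (d < out[i].second) out[i] = {static_cast<int>(j), d};
    }
  }
  return out;
}
```

Both metrics share the inner product and the precomputed norms and differ only in the final per-pair formula, which is the kind of overlap the review comment about code reuse refers to. The sketch does not guard against zero-norm rows in the cosine case.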