-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[HOTFIX] Fix distance metrics L2/cosine/correlation when X & Y are same buffer but with different shape and add unit test for such case. #1571
Conversation
…same buffer but with different shape, also add unit test support for them
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change looks good, just hoping to consolidate the conditional to make it easier to read.
Thanks for finding and fixing this issue. It's a super subtle detail.
@@ -137,7 +137,7 @@ void distance_impl(raft::resources const& handle, | |||
AccT* y_norm = workspace; | |||
AccT* sq_x_norm = workspace; | |||
AccT* sq_y_norm = workspace; | |||
if (x != y) { | |||
if ((x != y) || ((x == y) && (m != n))) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This conditional is getting a little complicated and I see it in multiple places in this PR. Can we create a helper function that can give this a name and be reused throughout the code? It would make it a lot easier to read and maintain.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
have removed this conditional.
@mdoijade do you still think we can get this fix into 23.06? It looks like there's some c++ test failures |
I believe the failure is related to fp arithmetic accuracy issues in L2 distances for X == Y input cases. I don't see a bug as such I will fix the tolerance accordingly and push it soon. |
…ol until reduction accuracy issue is resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @mdoijade for the fix, it looks good to me.
We shall return to this in follow up pr, to fix potential problems at two more distance types:
if (x != y) { raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } raft::linalg::unaryOp<DataT, decltype(unaryOp_lambda), IdxT>(
-- This is how tiled_brute_force_knn may use pairwise distance API hence assuming when X == Y the buffer has same shape is incorrect.