-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add tests for raft::matrix
#937
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for taking this on @lowener! I've done a very quick skim through the code. I'll do a deeper next week but wanted to share my initial feedback in the meantime. So far I've seen mostly minor things.
@@ -170,7 +169,7 @@ void printHost(const m_t* in, idx_t n_rows, idx_t n_cols) | |||
*/ | |||
template <typename m_t, typename idx_t = int> | |||
__global__ void slice( | |||
m_t* src_d, idx_t m, idx_t n, m_t* dst_d, idx_t x1, idx_t y1, idx_t x2, idx_t y2) | |||
const m_t* src_d, idx_t m, idx_t n, m_t* dst_d, idx_t x1, idx_t y1, idx_t x2, idx_t y2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a little concerned that changing these is going to ultimately break the existing APIs. We could cast away the constness for now (and create a github issue for it) or test this PR in cuml to make sure it doesn't break anythign.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is looking great, @lowener! The raft::matrix APIs are really coming together and I'm so happy to see tests for all these prims. Mostly very minor things again.
void binary_mult_skip_zero(const raft::handle_t& handle, | ||
raft::device_matrix_view<math_t, idx_t, layout_t> data, | ||
raft::device_vector_view<const math_t, idx_t> vec, | ||
Apply apply) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like the idea of making this an enum but I suggest we find a more descriptive name for this. Apply
just brings to mind several different possibilities, none of which immediately said "oh broadcast the vector across the rows" to me.
Maybe something like BroadcastType
? (That's just off the top of my head, im sure there are better names)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or maybe even VectorBroadcast::ALONG_ROWS
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just realized this was already defined in linalg
(prior to this PR). I think we can keep this name for now and rename this in a future PR.
template <typename math_t, typename extents, typename layout> | ||
void fill(const raft::handle_t& handle, | ||
raft::device_mdspan<math_t, extents, layout> inout, | ||
math_t scalar) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be a host scalar view as well to match the other overload.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was intended to be a much easier version to use because scalar
is not used as a "out parameter". I don't see the need for it to be host scalar view.
Otherwise we'd have to create a host scalar view that holds a pointer to zero every time we want to reset to zero a matrix?
And the in
parameter seems to be useless in the other overload as well when looking at the implementation details.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm actually thinking we can probably hold off on adding the host scalar view version altogether for now until we have a need to support the device scalar view. These are more generally useful when we have a pointer to device memory that we want to use as input but we don't want to synchronize the stream each time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks again for writingt these tests, @lowener!
@gpucibot merge |
Linking #877.
Implementation of the remaining tests for
raft::matrix