Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for raft::matrix #937

Merged
merged 18 commits into from
Oct 27, 2022
Merged

Conversation

lowener
Copy link
Contributor

@lowener lowener commented Oct 21, 2022

Linking #877.
Implementation of the remaining tests for raft::matrix

@lowener lowener requested review from a team as code owners October 21, 2022 14:58
@cjnolet cjnolet added improvement Improvement / enhancement to an existing function non-breaking Non-breaking change labels Oct 21, 2022
Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking this on @lowener! I've done a very quick skim through the code. I'll do a deeper next week but wanted to share my initial feedback in the meantime. So far I've seen mostly minor things.

cpp/include/raft/linalg/matrix_vector.cuh Show resolved Hide resolved
cpp/include/raft/matrix/argmax.cuh Outdated Show resolved Hide resolved
@@ -170,7 +169,7 @@ void printHost(const m_t* in, idx_t n_rows, idx_t n_cols)
*/
template <typename m_t, typename idx_t = int>
__global__ void slice(
m_t* src_d, idx_t m, idx_t n, m_t* dst_d, idx_t x1, idx_t y1, idx_t x2, idx_t y2)
const m_t* src_d, idx_t m, idx_t n, m_t* dst_d, idx_t x1, idx_t y1, idx_t x2, idx_t y2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little concerned that changing these is going to ultimately break the existing APIs. We could cast away the constness for now (and create a github issue for it) or test this PR in cuml to make sure it doesn't break anythign.

cpp/include/raft/matrix/norm.cuh Show resolved Hide resolved
Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is looking great, @lowener! The raft::matrix APIs are really coming together and I'm so happy to see tests for all these prims. Mostly very minor things again.

void binary_mult_skip_zero(const raft::handle_t& handle,
raft::device_matrix_view<math_t, idx_t, layout_t> data,
raft::device_vector_view<const math_t, idx_t> vec,
Apply apply)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the idea of making this an enum but I suggest we find a more descriptive name for this. Apply just brings to mind several different possibilities, none of which immediately said "oh broadcast the vector across the rows" to me.

Maybe something like BroadcastType? (That's just off the top of my head, im sure there are better names)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe even VectorBroadcast::ALONG_ROWS?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just realized this was already defined in linalg (prior to this PR). I think we can keep this name for now and rename this in a future PR.

cpp/include/raft/matrix/reverse.cuh Outdated Show resolved Hide resolved
cpp/include/raft/matrix/slice.cuh Outdated Show resolved Hide resolved
template <typename math_t, typename extents, typename layout>
void fill(const raft::handle_t& handle,
raft::device_mdspan<math_t, extents, layout> inout,
math_t scalar)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a host scalar view as well to match the other overload.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was intended to be a much easier version to use because scalar is not used as a "out parameter". I don't see the need for it to be host scalar view.
Otherwise we'd have to create a host scalar view that holds a pointer to zero every time we want to reset to zero a matrix?
And the in parameter seems to be useless in the other overload as well when looking at the implementation details.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm actually thinking we can probably hold off on adding the host scalar view version altogether for now until we have a need to support the device scalar view. These are more generally useful when we have a pointer to device memory that we want to use as input but we don't want to synchronize the stream each time.

cpp/include/raft/matrix/detail/matrix.cuh Show resolved Hide resolved
Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks again for writingt these tests, @lowener!

@cjnolet
Copy link
Member

cjnolet commented Oct 27, 2022

@gpucibot merge

@rapids-bot rapids-bot bot merged commit b0a7064 into rapidsai:branch-22.12 Oct 27, 2022
@lowener lowener deleted the 22.12-matrix-test branch October 28, 2022 16:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CMake cpp improvement Improvement / enhancement to an existing function non-breaking Non-breaking change
Projects
Development

Successfully merging this pull request may close these issues.

2 participants