[FEA] Add support for SDDMM by wrapping the cusparseSDDMM (#2067) #2067
/ok to test
- Add support for SDDMM by wrapping the `cusparseSDDMM` - This PR also moved some APIs shared with `SpMM` to the `utils.cuh` file. Authors: - James Rong (https://github.com/rhdong) Approvers: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet) URL: #2067
This looks great - thanks for the PR!
13b6a45
to
61d9558
Compare
lgtm!
typename NZType,
typename LayoutPolicyA,
typename LayoutPolicyB>
void sddmm(raft::resources const& handle,
For consistency with the RAFT API functions, the order of parameters should be:
- Handle
- Input view
- Output view
- Extra parameters (`alpha`, `beta`, `trans_a`, `trans_b`)
(Even though `spmm` currently doesn't do that.)
Also, using `raft::host_scalar_view` instead of a raw pointer for `alpha` and `beta` would be a good idea.
That makes sense! But to minimize the learning cost for users, may I suggest keeping the params in a sequence similar to the original cuSPARSE API?
The RAFT APIs don't assume the user is familiar with cusparse, or even knows that it is being used under the hood, so it's best to follow the conventions established by RAFT so that the user has a consistent experience.
@lowener is correct to point this out. We had a lot of discussions about this when the convention was established, and we strive to use the same conventions everywhere:
(Resources, param structs, in, out, params)
Our APIs are also intentionally based on mdspan, so we shouldn't accept pointers anywhere. All of the pointer-based APIs that are exposed publicly are deprecated.
If alpha and beta can be specified on device or host, we should capture this with mdspans. This also makes the APIs self-documenting.
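The convention discussed here (resources first, then inputs, then outputs, then extra parameters, with scalars wrapped in views rather than raw pointers) can be sketched in plain C++. Everything below is a hypothetical stand-in written for illustration: `resources`, `matrix_view`, `scalar_view`, and `scaled_add` are invented names, not the RAFT classes or any RAFT function.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-ins for illustration only; these are NOT the RAFT types.
struct resources {};  // would carry streams, library handles, etc.

template <typename T>
struct matrix_view {  // non-owning and cheap to copy
  T* data;
  std::size_t rows;
  std::size_t cols;
  T& operator()(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

template <typename T>
struct scalar_view {  // non-owning view of a single value
  const T* data;
  const T& operator*() const { return *data; }
};

// Parameter order: handle, input view, output view, extra parameters.
// Computes out = alpha * in + beta * out element-wise.
template <typename T>
void scaled_add(resources const& /*handle*/,
                matrix_view<const T> in,
                matrix_view<T> out,
                scalar_view<T> alpha,
                scalar_view<T> beta)
{
  assert(in.rows == out.rows && in.cols == out.cols);
  for (std::size_t i = 0; i < in.rows; ++i) {
    for (std::size_t j = 0; j < in.cols; ++j) {
      out(i, j) = *alpha * in(i, j) + *beta * out(i, j);
    }
  }
}
```

Because the scalars arrive as typed views instead of `const T*`, the signature itself says what each argument is, which is the "self-documenting" property mentioned in the review.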
Fixed it.
typename IndexType,
typename LayoutPolicyA,
typename LayoutPolicyB>
bool is_row_major(raft::device_matrix_view<ValueTypeA, IndexType, LayoutPolicyA>& a,
This function can be `constexpr`. Can it also reuse `raft::util::is_row_major`?
Accepted. I removed it because the new implementation doesn't need this API anymore.
523f4f7
to
cdae2b5
Compare
@@ -32,4 +32,11 @@ enum class Apply { ALONG_ROWS, ALONG_COLUMNS };
 */
enum class FillMode { UPPER, LOWER };

/**
 * @brief Enum for this type indicates which operations is applied to the related input (e.g. sparse
operation
 */
template <typename ValueType, typename IndexType, typename LayoutPolicy>
cusparseDnMatDescr_t create_descriptor(
raft::device_matrix_view<ValueType, IndexType, LayoutPolicy>& dense_view, const bool is_row_major)
Views/mdspans are usually passed by value and not by reference because they are lightweight.
`is_row_major` is not necessary here because this information can be inferred from the layout of the matrix view. Call `raft::is_row_major()` inside this function.
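Both review points on this signature (take the view by value, and derive the storage order from the view's layout type instead of a separate flag) can be sketched without RAFT. The layout tags, `matrix_view`, `dense_descr`, and both functions below are hypothetical stand-ins for illustration; they are not the RAFT, mdspan, or cuSPARSE types.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Hypothetical stand-ins for illustration only; NOT the RAFT/mdspan types.
struct row_major_tag {};
struct col_major_tag {};

template <typename T, typename Layout>
struct matrix_view {  // a pointer plus two extents: cheap to copy
  T* data;
  std::size_t rows;
  std::size_t cols;
};

// Compile-time layout query, so no runtime flag has to be threaded through.
template <typename T, typename Layout>
constexpr bool is_row_major(matrix_view<T, Layout> /*view*/)
{
  return std::is_same_v<Layout, row_major_tag>;
}

// Stand-in for a dense-matrix descriptor: shape, leading dimension, order.
struct dense_descr {
  std::size_t rows;
  std::size_t cols;
  std::size_t ld;
  bool row_major;
};

// The view is passed by value and the order is inferred from its layout type.
template <typename T, typename Layout>
dense_descr create_descriptor(matrix_view<T, Layout> view)
{
  assert(view.rows > 0 && view.cols > 0);
  const bool row_major = is_row_major(view);
  // Leading dimension is the stride between consecutive rows (row-major)
  // or consecutive columns (column-major) of a densely packed matrix.
  const std::size_t ld = row_major ? view.cols : view.rows;
  return dense_descr{view.rows, view.cols, ld, row_major};
}
```

With this shape the caller cannot pass a flag that contradicts the view's actual layout, which is one reason to prefer inferring it.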
Accepted; this was just inherited from the spmm.
 * @brief convert the operation to cusparseOperation_t type
 * @tparam OpVal type of operation
 */
static inline cusparseOperation_t convert_operation(const raft::linalg::Operation& op)
Does this function need to be static? The reference isn't needed either; this can be passed by value.
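A minimal sketch of what the suggested signature could look like, using stand-in enums so that it compiles on its own. The enumerator names of `Operation` are an assumption about `raft::linalg::Operation`, and `backend_op` is a hypothetical placeholder for `cusparseOperation_t`.

```cpp
#include <cassert>

// Stand-in enums for illustration only. `Operation` mirrors what
// raft::linalg::Operation is assumed to look like; `backend_op` is a
// hypothetical placeholder for cusparseOperation_t.
enum class Operation { NON_TRANSPOSE, TRANSPOSE };
enum class backend_op { non_transpose, transpose };

// `inline` alone is enough for a function defined in a header, and a small
// enum is passed by value rather than by const reference.
inline backend_op convert_operation(Operation op)
{
  switch (op) {
    case Operation::NON_TRANSPOSE: return backend_op::non_transpose;
    case Operation::TRANSPOSE: return backend_op::transpose;
  }
  assert(false && "unknown operation");
  return backend_op::non_transpose;
}
```

Dropping `static` avoids giving every translation unit that includes the header its own private copy of the function.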
Accept!

/**
 * @brief convert the operation to cusparseOperation_t type
 * @tparam OpVal type of operation
Update this comment: param[in] op type of operation
Fixed
@rhdong Why so many force pushes to the branch? You should be able to merge upstream into your branch cleanly. The commits are squashed automatically upon merging the PR so there's no reason to rewrite history.
OK, got it~ (I just wanted to keep the PR as one commit and keep the history clean.)
This looks great so far! I mostly have minor things (in addition to automating the expected test data and contributing the benchmarks that you've put a lot of effort into).
&bufferSize,
resource::get_cuda_stream(handle)));

raft::interruptible::synchronize(resource::get_cuda_stream(handle));
We're trying to centralize the interruptible calls instead of calling them directly; please use `resource::sync_stream()` instead.
@@ -0,0 +1,103 @@
/*
Just to avoid confusion, can we rename this file to `cusparse_utils.hpp`? (Please note this shouldn't be a `cuh` because it's not creating any device functions.)
Accept
@@ -0,0 +1,83 @@
/*
Since this file doesn't create or use any device functions, please rename it to `sddmm.hpp`. This is a great signal to users that it only requires the CUDA runtime APIs, math libs, and nothing else.
@@ -19,6 +19,7 @@
#pragma once

This file should probably be `hpp` as well. Up to you whether you want to rename it in this PR (since it's already quite big).
cpp/test/sparse/sddmm.cu
Outdated
@@ -0,0 +1,425 @@
/*
 * Copyright (c) 2023-2024, NVIDIA CORPORATION.
New file should only have current year
@@ -0,0 +1,83 @@
/*
 * Copyright (c) 2023-2024, NVIDIA CORPORATION.
New files should only have current year.
@@ -0,0 +1,103 @@
/*
 * Copyright (c) 2023-2024, NVIDIA CORPORATION.
New files should only have current year
@@ -0,0 +1,99 @@
/*
 * Copyright (c) 2023-2024, NVIDIA CORPORATION.
New files should only have current year
cpp/test/sparse/sddmm.cu
Outdated
4,
4,
3,
1.0,
There are a few places where we do this because it's really hard to generate the outputs for automating the comparisons. The problem with hardcoding these things is that it's non-trivial to add new test cases, so when a user comes to us with an issue where they are seeing strange results with some sparsities or specific inputs, we can't quickly reproduce it in a test once we fix the issue.
I think it's easy enough to automate the expected outputs, though. Since you are already creating a mask, just copy out the code from your micro benchmarks that creates a random mask, generate two input dense arrays (using raft::random::make_blobs), compute the pairwise distances between A and B, and then copy the rows/cols of the results into your "expected output" sparse structure.
Given how much work went into benchmarking, I'd also highly suggest we commit your benchmarking code with these changes. Since it was already written, I can't express enough how convenient it is to be able to load up a simple microbenchmark when someone finds a specific case to test. I understand this version is a lightweight wrapper around the SDDMM from cusparse, but that's likely not always going to be the case, so automated tests and benchmarks help us evolve the code (and back-ends) over time.
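The suggestion above can be sketched as a host-side reference: build a random mask, compute the dense product only at the masked positions, and store the results as the expected sparse values. This is a sketch under stated assumptions, not the PR's test code. It uses inner products as the dense reference (the quantity SDDMM samples), `std::mt19937` in place of `raft::random::make_blobs`, and the `csr` struct and both function names are hypothetical.

```cpp
#include <cassert>
#include <random>
#include <vector>

// Hypothetical CSR container for this sketch (not a RAFT type).
struct csr {
  std::vector<int> indptr;    // size rows + 1
  std::vector<int> indices;   // column index of each stored entry
  std::vector<float> values;  // one value per stored entry
};

// Random sparsity mask: each (i, j) is kept with probability `density`.
inline csr random_mask(int rows, int cols, float density, unsigned seed)
{
  assert(rows > 0 && cols > 0);
  std::mt19937 gen(seed);
  std::uniform_real_distribution<float> keep(0.0f, 1.0f);
  csr mask;
  mask.indptr.push_back(0);
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      if (keep(gen) < density) { mask.indices.push_back(j); }
    }
    mask.indptr.push_back(static_cast<int>(mask.indices.size()));
  }
  mask.values.assign(mask.indices.size(), 0.0f);
  return mask;
}

// Expected values: alpha * (A * B) sampled at the mask + beta * mask.values.
// A is m x k and B is k x n, both row-major.
inline std::vector<float> expected_sddmm(const std::vector<float>& a,
                                         const std::vector<float>& b,
                                         const csr& mask,
                                         int k,
                                         int n,
                                         float alpha,
                                         float beta)
{
  assert(k > 0 && n > 0);
  const int rows = static_cast<int>(mask.indptr.size()) - 1;
  std::vector<float> out(mask.values.size());
  for (int i = 0; i < rows; ++i) {
    for (int p = mask.indptr[i]; p < mask.indptr[i + 1]; ++p) {
      const int j = mask.indices[p];
      float dot   = 0.0f;
      for (int l = 0; l < k; ++l) {
        dot += a[i * k + l] * b[l * n + j];
      }
      out[p] = alpha * dot + beta * mask.values[p];
    }
  }
  return out;
}
```

A test case then reduces to a tuple of sizes, a density, and a seed, so reproducing a user-reported configuration means adding one more tuple rather than hand-computing outputs.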
Makes sense!
ValueType(1.0f),
uint64_t(2024));

raft::copy(a_data_h.data(), blobs_a_b.data_handle(), a_size, stream);
The new tests look great, thanks!
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023, NVIDIA CORPORATION.
+ * Copyright (c) 2023-2024, NVIDIA CORPORATION.
Since this is a publicly facing header, we should add a new file called `spmm.cuh` and have it import this header (with an include guard, of course), so that it doesn't break users downstream.
Makes sense.
LGTM! Thanks @rhdong!
/merge