Updating raft::linalg APIs to use mdspan #809

Merged


divyegala
Member

No description provided.

@divyegala divyegala requested a review from a team as a code owner September 7, 2022 17:02
@github-actions github-actions bot added the cpp label Sep 7, 2022
@divyegala divyegala added the feature request, non-breaking, and 2 - In Progress labels Sep 7, 2022
Member

@cjnolet cjnolet left a comment

I've done an initial pass at what you have so far. I haven't looked through things exhaustively, such as making sure each new function has an associated test, or even that all the functions have been mdspan-ified. I'll do a couple more passes but wanted to provide some immediate thoughts/feedback in the meantime to help accelerate the process.

cpp/include/raft/linalg/axpy.cuh (outdated, resolved)
cpp/include/raft/linalg/axpy.cuh (outdated, resolved)
cpp/include/raft/linalg/binary_op.cuh (outdated, resolved)
cpp/include/raft/linalg/cholesky_r1_update.cuh (outdated, resolved)
cpp/include/raft/linalg/cholesky_r1_update.cuh (outdated, resolved)
cpp/include/raft/linalg/sqrt.cuh (outdated, resolved)
*/

/**
* @brief Computes the sum-reduction of matrix columns for each given key
Member

I guess I hadn't noticed how confusing this was before. If I were a user looking at a function called reduce_cols_by_key, I'd expect to pass in my own reduction function. If I invoked it without one (with another overload provided that accepts a lambda), I'd expect it to perform a sum by default. It doesn't look like we provide such an option, though. What makes things even more confusing to me is that the reduce_rows_by_key function performs a weighted sum.

I'm not saying we need to rename or remove these functions, but it would be useful if we had a GitHub issue (and a todo here in this file that references it) to allow generic reduction functions. That would make this much more useful and also allow it to be extended for generalized row-norm computations.
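For illustration, the overload described above might look like the following host-side sketch. It is not RAFT code: the name reduce_cols_by_key_generic, the std::vector parameters, and the row-major nrows x nkeys output shape are assumptions made only for this example. The reduction defaults to a sum when no function is passed.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical host-side sketch, not RAFT's API.
// `data` is a row-major n_rows x n_cols matrix, `keys[c]` in [0, n_keys)
// labels column c, and the result is a row-major n_rows x n_keys matrix.
template <typename T, typename KeyT, typename ReduceOp = std::plus<T>>
std::vector<T> reduce_cols_by_key_generic(const std::vector<T>& data,
                                          const std::vector<KeyT>& keys,
                                          std::size_t n_rows,
                                          std::size_t n_keys,
                                          T init      = T{},
                                          ReduceOp op = ReduceOp{})
{
  const std::size_t n_cols = keys.size();
  assert(data.size() == n_rows * n_cols);
  std::vector<T> out(n_rows * n_keys, init);
  for (std::size_t r = 0; r < n_rows; ++r) {
    for (std::size_t c = 0; c < n_cols; ++c) {
      const auto k = static_cast<std::size_t>(keys[c]);
      assert(k < n_keys);
      // Fold column c of row r into the slot that belongs to its key.
      out[r * n_keys + k] = op(out[r * n_keys + k], data[r * n_cols + c]);
    }
  }
  return out;
}
```

Calling it with only the data, keys, and shapes gives the sum-by-default behavior; passing an init value and a lambda (for example a max) selects a different reduction.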

Member

Still waiting on todo / github issue here

*/

/**
* @brief Computes the weighted reduction of matrix rows for each given key
Member

Same as above: can you put a todo in this file and reference a GitHub issue to add an overload with generic reduction functions? I believe the weighted reduction here should be using a fused multiply-add intrinsic, but a generalized reduction would make this primitive more widely applicable to different problems.
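As a rough host-side sketch of the fused multiply-add idea, the weighted accumulation can be written with std::fma. The function name, the std::vector parameters, and the nkeys x ncols output shape are assumptions for this example, not RAFT's kernel; device code would use the corresponding intrinsic instead.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical host-side sketch, not RAFT's kernel.
// `data` is row-major n_rows x n_cols, `keys[r]` in [0, n_keys) labels row r,
// `weights[r]` scales row r, and the result is row-major n_keys x n_cols.
inline std::vector<double> weighted_reduce_rows_by_key(const std::vector<double>& data,
                                                       const std::vector<int>& keys,
                                                       const std::vector<double>& weights,
                                                       std::size_t n_cols,
                                                       std::size_t n_keys)
{
  const std::size_t n_rows = keys.size();
  assert(weights.size() == n_rows && data.size() == n_rows * n_cols);
  std::vector<double> out(n_keys * n_cols, 0.0);
  for (std::size_t r = 0; r < n_rows; ++r) {
    const auto k = static_cast<std::size_t>(keys[r]);
    assert(k < n_keys);
    for (std::size_t c = 0; c < n_cols; ++c) {
      double& acc = out[k * n_cols + c];
      // acc = weights[r] * data(r, c) + acc, computed with a single rounding.
      acc = std::fma(weights[r], data[r * n_cols + c], acc);
    }
  }
  return out;
}
```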

cpp/include/raft/linalg/strided_reduction.cuh (resolved)
cpp/include/raft/linalg/strided_reduction.cuh (resolved)
Contributor

@mhoemmen mhoemmen left a comment

Thanks for letting me review! :-) Some general comments:

  1. A read-only mdspan is mdspan<const T, ...>, not const mdspan<T, ...>. I can't think of a reason why one would want to use const mdspan as a function parameter type.
  2. mdspan::element_type is the template argument that the user gave to mdspan. mdspan::value_type is that with any cv-qualifiers removed.
  3. Please see notes about how to check whether an mdspan is row-major or column-major.
  4. Please see notes about how to check whether 32-bit or 64-bit indices are needed for computations.
  5. Would you consider separating workspace queries from actual computations by using different function names?
  6. Would you consider making tolerances a function of the value type?

cpp/include/raft/linalg/add.cuh (outdated, resolved)
cpp/include/raft/linalg/add.cuh (outdated, resolved)
cpp/include/raft/linalg/add.cuh (outdated, resolved)
cpp/include/raft/linalg/add.cuh (outdated, resolved)
RAFT_EXPECTS(out.size() == in1.size() && in1.size() == in2.size(),
"Size mismatch between Output and Inputs");

if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
Contributor

@mhoemmen mhoemmen Sep 12, 2022

It looks like the goal here is to use 32-bit indices if possible, when inverting the layout mapping to use a 1-D loop index. This can be done, but there are two correctness issues with your approach.

  1. The right quantity to test here is out.required_span_size(), not out.size(). The layout mapping maps the input multidimensional index to the half-open interval of offsets [0, out.required_span_size()).

  2. The layout_{left, right, stride}::mapping constructors generally have as a precondition that the required span size of the input extents (and strides, if applicable) be representable as a value of type index_type.

Here is an approach that would address these issues.

template<class T>
constexpr bool is_32_bit_integral_v = std::is_integral_v<T> && sizeof(T) == sizeof(std::uint32_t);
template<class T>
constexpr bool is_greater_than_32_bit_integral_v = std::is_integral_v<T> && sizeof(T) > sizeof(std::uint32_t);

if constexpr (is_32_bit_integral_v<typename OutType::index_type>) {
  // ... always call 32-bit version ...
} else if constexpr (is_greater_than_32_bit_integral_v<typename OutType::index_type>) {
  // ... test the value of `required_span_size()`; dispatch to 32-bit or index_type (64 or more bits) as needed ...
} else {
  // ... always use index_type, which is 16 bits or less here ...
}

You'll also want to check the index_type and required_span_size() of the other mdspan. The above approach has the advantage that it only compiles an inner kernel for index types that you actually use.
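A compilable version of that dispatch might look like the sketch below. The type fake_mapping and the function launch_with_index are hypothetical stand-ins for a layout mapping and a kernel launch; the traits compare sizeof(T) against sizeof(std::uint32_t). The stand-in launch returns the loop index width in bits so the chosen path is observable.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <type_traits>

template <class T>
constexpr bool is_32_bit_integral_v =
  std::is_integral_v<T> && sizeof(T) == sizeof(std::uint32_t);
template <class T>
constexpr bool is_greater_than_32_bit_integral_v =
  std::is_integral_v<T> && sizeof(T) > sizeof(std::uint32_t);

// Hypothetical stand-in for the part of a layout mapping the dispatch needs.
template <class IndexType>
struct fake_mapping {
  using index_type = IndexType;
  IndexType span;
  IndexType required_span_size() const { return span; }
};

// Hypothetical stand-in for a kernel launch; returns the loop index width in bits.
template <class LoopIndexType>
int launch_with_index()
{
  return static_cast<int>(8 * sizeof(LoopIndexType));
}

template <class Mapping>
int dispatch(const Mapping& m)
{
  using index_type = typename Mapping::index_type;
  if constexpr (is_32_bit_integral_v<index_type>) {
    return launch_with_index<index_type>();  // index_type is already 32 bits
  } else if constexpr (is_greater_than_32_bit_integral_v<index_type>) {
    // Runtime test on the required span size, not on size().
    if (static_cast<std::uint64_t>(m.required_span_size()) <=
        std::numeric_limits<std::uint32_t>::max()) {
      return launch_with_index<std::uint32_t>();
    }
    return launch_with_index<index_type>();
  } else {
    return launch_with_index<index_type>();  // 16 bits or less
  }
}

inline void dispatch_example()
{
  assert(dispatch(fake_mapping<std::int64_t>{100}) == 32);
  assert(dispatch(fake_mapping<std::int64_t>{std::int64_t{1} << 33}) == 64);
}
```

Only the branches selected by if constexpr are instantiated, so a 32-bit index_type never compiles the wider kernel.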

Member Author

In point 2, what happens in extreme cases? Consider index_type=uint32_t with extents {2^32, 2}. In this case, will required_span_size() be representable by index_type or will it cause an overflow?

Contributor

@divyegala required_span_size() is not representable by index_type in this case. For layout_left and layout_right, required_span_size() and size() are mathematically the same; the only difference is the return type (index_type and size_t, respectively). For layout_stride, though, required_span_size() can be greater than size(). For other layouts (e.g., the "matrix of a single value" layout that maps all multidimensional indices to the offset zero), required_span_size() can be less than size().

Note that while it's UB for users to violate preconditions, implementations aren't required to check preconditions. The reference implementation of layout_left does not currently check preconditions, as you can see here, for instance. This means two things.

  1. If someone gives you a layout_{left,right,stride}::mapping instance (e.g., in an mdspan), then you can assume that the precondition is satisfied.

  2. If you are constructing a layout_{left,right,stride}::mapping instance (e.g., by constructing an mdspan with a pointer and extents), then you are responsible for ensuring that the precondition is satisfied.
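To make the difference between size() and required_span_size() concrete, here is a small sketch that computes both quantities from plain arrays rather than from a real mapping. The function names are made up for this example. For a strided mapping with nonzero extents, the required span size is one plus the largest reachable offset, 1 + sum_i (extent_i - 1) * stride_i. The zero-stride case at the end models the "single value" custom layout mentioned above, not layout_stride.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Number of multidimensional indices: the product of the extents.
template <std::size_t Rank>
std::size_t index_space_size(const std::array<std::size_t, Rank>& extents)
{
  std::size_t n = 1;
  for (std::size_t e : extents) { n *= e; }
  return n;
}

// One past the largest offset produced by a strided mapping,
// or 0 if the index space is empty.
template <std::size_t Rank>
std::size_t strided_required_span_size(const std::array<std::size_t, Rank>& extents,
                                       const std::array<std::size_t, Rank>& strides)
{
  for (std::size_t e : extents) {
    if (e == 0) { return 0; }
  }
  std::size_t span = 1;
  for (std::size_t i = 0; i < Rank; ++i) { span += (extents[i] - 1) * strides[i]; }
  return span;
}

inline void span_vs_size_example()
{
  const std::array<std::size_t, 2> extents{3, 4};
  const std::array<std::size_t, 2> contiguous{4, 1};  // row-major, no padding
  const std::array<std::size_t, 2> padded{8, 1};      // rows padded to 8 elements
  const std::array<std::size_t, 2> broadcast{0, 0};   // every index maps to offset 0
  assert(index_space_size(extents) == 12);
  assert(strided_required_span_size(extents, contiguous) == 12);  // equal
  assert(strided_required_span_size(extents, padded) == 20);      // greater than size
  assert(strided_required_span_size(extents, broadcast) == 1);    // less than size
}
```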

Contributor

@divyegala wrote:

In this case, will required_span_size() be representable by index_type or will it cause an overflow?

Those are two separate questions, actually! :-)

  1. required_span_size() is not representable by index_type in this case.
  2. Giving this extents object to layout_{left,right,stride}::mapping's constructor violates the constructor's precondition. It could overflow, or it could open a portal to the Awesome Dimension and let loose a swarm of nasal demons who search out precondition violators and boop them gently on the nose.

Member Author

@mhoemmen thanks for the explanations! How do we really represent such edge cases and safely obtain the product of the extents? Sounds like size() is the safe way to obtain the product without violating any pre-conditions since it's representable by size_t?

Contributor

@divyegala Gladly! :-)

How do we really represent such edge cases and safely obtain the product of the extents?

By the time the user has created a layout mapping, it's already too late. What I mean by that is that if required_span_size() doesn't fit index_type, then the user will likely get the wrong answer when they try to index into the mdspan.

In what follows in my comment, I'll distinguish between "the Preconditions in the spec" and "what the reference implementation does." The reference implementation currently does not check this precondition in the layout mapping. This means that it's possible for users to construct extents for which the mapping's required_span_size() can overflow.

We can prevent this by wrapping mdspan creation to check the extents object for potential overflow, before it goes into a layout mapping's constructor. It's not UB to construct, e.g., dextents<uint16_t, 2>(2^15, 2^15). We just need to intercept that naughty extents value before it goes into a layout mapping's constructor. Otherwise, the layout mapping has the freedom to do whatever it likes, including calling abort().

Our mdarray implementation's conversion to mdspan can also check, but again, we're probably better off making the wrapper explicit and not part of the mdarray proposal. WG21 likes Preconditions and wants violating them to be UB. If we want some specified behavior (e.g., throwing a particular exception, or calling terminate() after printing a helpful error message), then we'll have to implement that ourselves.
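A wrapper of that kind could start from a check like the one sketched here. The names extents_product_fits and expect_extents_fit are hypothetical, the extents are passed as plain 64-bit values rather than as an extents object, and throwing std::length_error is just one of the "specified behaviors" one might choose. The check covers contiguous layouts, where the required span size equals the product of the extents; a strided layout would need the strides as well.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>
#include <limits>
#include <stdexcept>
#include <type_traits>

// Returns true if the product of `extents` is representable as IndexType.
// The multiplication is guarded so that it cannot itself overflow.
template <class IndexType>
bool extents_product_fits(std::initializer_list<std::uint64_t> extents)
{
  static_assert(std::is_integral_v<IndexType>, "IndexType must be an integer type");
  for (std::uint64_t e : extents) {
    if (e == 0) { return true; }  // empty index space, the product is 0
  }
  const auto max_index = static_cast<std::uint64_t>(std::numeric_limits<IndexType>::max());
  std::uint64_t product = 1;
  for (std::uint64_t e : extents) {
    if (product > max_index / e) { return false; }  // product * e would exceed max_index
    product *= e;
  }
  return true;
}

// Hypothetical guard to call before handing extents to a layout mapping.
template <class IndexType>
void expect_extents_fit(std::initializer_list<std::uint64_t> extents)
{
  if (!extents_product_fits<IndexType>(extents)) {
    throw std::length_error("product of extents is not representable as index_type");
  }
}

inline void extents_check_example()
{
  // The uint16_t example from the comment above: 2^15 * 2^15 = 2^30 does not fit.
  const bool fits = extents_product_fits<std::uint16_t>({1u << 15, 1u << 15});
  assert(!fits);
}
```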

Member Author

Looks like out.required_span_size() does not work. How do I access this from the layout?

cpp/include/raft/linalg/matrix_vector_op.cuh (outdated, resolved)
cpp/include/raft/linalg/matrix_vector_op.cuh (outdated, resolved)
cpp/include/raft/linalg/power.cuh (outdated, resolved)
cpp/include/raft/linalg/power.cuh (outdated, resolved)
cpp/include/raft/linalg/power.cuh (outdated, resolved)
Member

@cjnolet cjnolet left a comment

Things are coming together great! I just did another deeper review. I'll try and do another before EOD.

cpp/include/raft/linalg/cholesky_r1_update.cuh (outdated, resolved)
cpp/include/raft/linalg/axpy.cuh (outdated, resolved)
cpp/include/raft/linalg/axpy.cuh (resolved)
cpp/include/raft/linalg/cholesky_r1_update.cuh (outdated, resolved)
cpp/include/raft/linalg/eig.cuh (resolved)
* @param[in] main_op fused elementwise operation to apply before reduction
* @param[in] reduce_op fused binary reduction operation
* @param[in] final_op fused elementwise operation to apply before storing results
* @param[in] inplace reduction result added inplace or overwrites old values?
Member

It's unclear to me from reading this doc 1) whether or not the dots output is needed when inplace==true and 2) if data == dots when inplace==true. I felt the same while looking over the API for reduce. This might just require some additional documentation in these functions, though I wonder if it's an indication that an update to the API might be needed (I'm just really not sure what that would be as of yet).

Member Author

It's unclear to me as well, and I mentioned it in reply to another of @mhoemmen's comments. Reading the impl did not make it any clearer for me.

Member

Maybe we should do a separate overload for the inplace version like we do in raft::matrix. That makes it much more straightforward to the user that it's in place (and which array(s) end up being overwritten).

cpp/include/raft/linalg/rsvd.cuh (outdated, resolved)
cpp/include/raft/linalg/rsvd.cuh (resolved)
cpp/include/raft/linalg/unary_op.cuh (outdated, resolved)
cpp/include/raft/linalg/unary_op.cuh (resolved)
Member

@cjnolet cjnolet left a comment

Looks like this is coming across the final bend. Really minor things at this point.

cpp/include/raft/linalg/add.cuh (resolved)
cpp/include/raft/linalg/binary_op.cuh (resolved)
raft::device_matrix_view<const InValueType, IndexType, LayoutPolicy> data,
raft::device_vector_view<OutValueType, IndexType> dots,
OutValueType init,
bool inplace = false,
Member

One thing I realized as I was going through the raft::matrix APIs is that they have an overload for the inplace version, and I think that looks cleaner because the arguments kind of speak for themselves (e.g., no output argument and the docs say inout).

Member Author

Should I do the separate overload in this PR or wait for the move to raft::matrix?

Member Author

I also really don't think inplace means what we think it means here. Looking at the kernel, it seems like all inplace=true does is run the reduction_op one more time over the output pointer.
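If that reading of the kernel is right, the two behaviors could be split into separately named functions instead of a bool flag, roughly as in this host-side sketch. Everything here is hypothetical: the names reduce_rows and reduce_rows_accumulate, the std::vector parameters, and the row-wise reduction direction are assumptions for illustration, not RAFT's API or its kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch. Overwrites `out`: out[r] = reduction of row r from `init`.
// `data` is row-major with out.size() rows and n_cols columns.
template <typename T, typename ReduceOp>
void reduce_rows(const std::vector<T>& data,
                 std::size_t n_cols,
                 std::vector<T>& out,
                 T init,
                 ReduceOp op)
{
  assert(n_cols > 0 && data.size() == out.size() * n_cols);
  for (std::size_t r = 0; r < out.size(); ++r) {
    T acc = init;
    for (std::size_t c = 0; c < n_cols; ++c) { acc = op(acc, data[r * n_cols + c]); }
    out[r] = acc;
  }
}

// Hypothetical sketch. Combines with existing values, which models the reading
// above of inplace == true: inout[r] = op(inout[r], reduction of row r).
template <typename T, typename ReduceOp>
void reduce_rows_accumulate(const std::vector<T>& data,
                            std::size_t n_cols,
                            std::vector<T>& inout,
                            T init,
                            ReduceOp op)
{
  std::vector<T> partial(inout.size());
  reduce_rows(data, n_cols, partial, init, op);
  for (std::size_t r = 0; r < inout.size(); ++r) { inout[r] = op(inout[r], partial[r]); }
}
```

With two names, the parameter list itself says whether the output's previous contents matter.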

cpp/include/raft/linalg/gemm.cuh (resolved)
* @param[in] main_op fused elementwise operation to apply before reduction
* @param[in] reduce_op fused binary reduction operation
* @param[in] final_op fused elementwise operation to apply before storing results
* @param[in] inplace reduction result added inplace or overwrites old values?
Member

Maybe we should do a separate overload for the inplace version like we do in raft::matrix. That makes it much more straightforward to the user that it's in place (and which array(s) end up being overwritten).

cpp/include/raft/linalg/strided_reduction.cuh (outdated, resolved)
cpp/include/raft/linalg/subtract.cuh (resolved)
cpp/include/raft/linalg/subtract.cuh (outdated, resolved)
cpp/include/raft/linalg/svd.cuh (outdated, resolved)
Member

cjnolet commented Sep 30, 2022

@gpucibot merge

@rapids-bot rapids-bot bot merged commit 7adf15e into rapidsai:branch-22.10 Sep 30, 2022
3 participants