Updating raft::linalg APIs to use mdspan #809
I've done an initial pass at what you have so far. I haven't looked things through exhaustively, such as making sure each new function has an associated test, or even that all the functions have been mdspan-ified. I'll do a couple more passes but wanted to provide some immediate thoughts/feedback in the meantime to help accelerate the process.
 */

/**
 * @brief Computes the sum-reduction of matrix columns for each given key
I guess I hadn't noticed how confusing this was before. If I were a user looking at a function called `reduce_cols_by_key`, I'd expect to pass my own reduction function in. However, if I were invoking such a function without a reduction function (with another overload provided that accepts a lambda), I'd expect it to perform a sum by default. It doesn't look like we provide such an option, though. What makes things even more confusing to me is that the `reduce_rows_by_key` function performs a weighted sum.
I'm not saying we need to rename or remove these functions, but it would be useful if we had a GitHub issue (and a TODO here in this file that references it) to allow generic reduction functions. That would make this much more useful and also allow it to be extended for generalized row-norm computations.
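To make the suggestion concrete, here is a minimal host-side sketch of what an overload with a generic reduction could look like, with sum as the default. The name `reduce_cols_by_key_ref`, its signature, and the row-major layout are assumptions for illustration only; this is not the RAFT API or its device implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical host-side reference (not the RAFT API). Reduces the columns of
// a row-major (n_rows x n_cols) matrix into n_keys output columns, where
// keys[c] names the output column that input column c contributes to.
// The reduction defaults to a sum; callers may pass any binary operation.
template <class T, class ReduceOp = std::plus<T>>
std::vector<T> reduce_cols_by_key_ref(const std::vector<T>& data,
                                      const std::vector<int>& keys,
                                      std::size_t n_rows,
                                      std::size_t n_keys,
                                      T init      = T{},
                                      ReduceOp op = ReduceOp{})
{
  const std::size_t n_cols = keys.size();
  assert(data.size() == n_rows * n_cols);

  // Output is row-major (n_rows x n_keys), seeded with the identity `init`.
  std::vector<T> out(n_rows * n_keys, init);
  for (std::size_t r = 0; r < n_rows; ++r) {
    for (std::size_t c = 0; c < n_cols; ++c) {
      const std::size_t k = static_cast<std::size_t>(keys[c]);
      assert(k < n_keys);
      T& acc = out[r * n_keys + k];
      acc    = op(acc, data[r * n_cols + c]);
    }
  }
  return out;
}
```

Calling it without `init` and `op` gives the current sum behavior, while passing, say, a max lambda with a suitable identity value gives a per-key column maximum.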
Still waiting on todo / github issue here
 */

/**
 * @brief Computes the weighted reduction of matrix rows for each given key
Same as above: can you put a TODO in this file that references a GitHub issue to add an overload with generic reduction functions? I believe the weighted reduction here should be using a fused multiply-add intrinsic, but a generalized reduction would make this primitive more widely applicable to different problems.
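For reference, the weighted reduction with a fused multiply-add can be sketched on the host as follows. The function name, argument order, and row-major layout are hypothetical, and `std::fma` stands in for whatever device intrinsic the kernel would use.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference (not the RAFT API). For a row-major
// (n_rows x n_cols) matrix, accumulates weights[r] * data[r, c] into output
// row keys[r], using one fused multiply-add per element.
inline std::vector<double> weighted_reduce_rows_by_key_ref(
  const std::vector<double>& data,
  const std::vector<int>& keys,
  const std::vector<double>& weights,
  std::size_t n_cols,
  std::size_t n_keys)
{
  const std::size_t n_rows = keys.size();
  assert(weights.size() == n_rows);
  assert(data.size() == n_rows * n_cols);

  std::vector<double> out(n_keys * n_cols, 0.0);
  for (std::size_t r = 0; r < n_rows; ++r) {
    const std::size_t k = static_cast<std::size_t>(keys[r]);
    assert(k < n_keys);
    for (std::size_t c = 0; c < n_cols; ++c) {
      double& acc = out[k * n_cols + c];
      // acc = weights[r] * data[r, c] + acc, rounded once.
      acc = std::fma(weights[r], data[r * n_cols + c], acc);
    }
  }
  return out;
}
```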
Thanks for letting me review! :-) Some general comments:
- A read-only mdspan is `mdspan<const T, ...>`, not `const mdspan<T, ...>`. I can't think of a reason why one would want to use `const mdspan` as a function parameter type.
- `mdspan::element_type` is the template argument that the user gave to mdspan. `mdspan::value_type` is that with any cv-qualifiers removed.
- Please see notes about how to check whether an mdspan is row-major or column-major.
- Please see notes about how to check whether 32-bit or 64-bit indices are needed for computations.
- Would you consider separating workspace queries from actual computations by using different function names?
- Would you consider making tolerances a function of the value type?
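The first two points can be illustrated without depending on an mdspan implementation. The sketch below uses a hypothetical `tiny_view` type that mimics only the relevant parts of mdspan (a const element accessor that returns `T&`, plus the `element_type` and `value_type` aliases); it is a stand-in for illustration, not mdspan itself.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Hypothetical 1-D stand-in for mdspan; only the constness rules matter here.
template <class T>
struct tiny_view {
  using element_type = T;                     // exactly what the user wrote
  using value_type   = std::remove_cv_t<T>;   // cv-qualifiers stripped

  T* data;
  std::size_t size;

  // Like mdspan, element access is a const member that returns T&.
  T& operator[](std::size_t i) const
  {
    assert(i < size);
    return data[i];
  }
};

// Read-only parameter: the constness lives in the element type.
inline double sum(tiny_view<const double> v)
{
  double total = 0.0;
  for (std::size_t i = 0; i < v.size; ++i) { total += v[i]; }
  return total;
}

// A const view of non-const elements can still write through the view,
// so `const tiny_view<double>` does not make the data read-only.
inline void fill(const tiny_view<double> v, double x)
{
  for (std::size_t i = 0; i < v.size; ++i) { v[i] = x; }
}

static_assert(std::is_same<tiny_view<const double>::element_type, const double>::value,
              "element_type keeps the const qualifier");
static_assert(std::is_same<tiny_view<const double>::value_type, double>::value,
              "value_type drops the const qualifier");
```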
RAFT_EXPECTS(out.size() == in1.size() && in1.size() == in2.size(),
             "Size mismatch between Output and Inputs");

if (out.size() <= std::numeric_limits<std::uint32_t>::max()) {
It looks like the goal here is to use 32-bit indices if possible, when inverting the layout mapping to use a 1-D loop index. This can be done, but there are two correctness issues with your approach.
1. The right quantity to test here is `out.required_span_size()`, not `out.size()`. The layout mapping maps the input multidimensional index to the half-open interval of offsets [0, `out.required_span_size()`).
2. The `layout_{left, right, stride}::mapping` constructors generally have as a precondition that the required span size of the input extents (and strides, if applicable) be representable as a value of type `index_type`.
Here is an approach that would address these issues.
template<class T>
constexpr bool is_32_bit_integral_v = std::is_integral_v<T> && sizeof(T) == sizeof(std::uint32_t);
template<class T>
constexpr bool is_greater_than_32_bit_integral_v = std::is_integral_v<T> && sizeof(T) > sizeof(std::uint32_t);
if constexpr (is_32_bit_integral_v<typename OutType::index_type>) {
  // ... always call 32-bit version ...
} else if constexpr (is_greater_than_32_bit_integral_v<typename OutType::index_type>) {
  // ... test the value of `required_span_size()`; dispatch to 32-bit or index_type (64 or more bits) as needed ...
} else {
  // ... always use index_type, which is 16 bits or less here ...
}
You'll also want to check the `index_type` and `required_span_size()` of the other mdspan. The above approach has the advantage that it only compiles an inner kernel for index types that you actually use.
In point 2, what happens in extreme cases? Consider `index_type=uint32_t` with extents `{2^32, 2}`. In this case, will `required_span_size()` be representable by `index_type` or will it cause an overflow?
@divyegala `required_span_size()` is not representable by `index_type` in this case. For `layout_left` and `layout_right`, `required_span_size()` and `size()` are the same mathematically. The only difference is the return type (`index_type` resp. `size_t`). For `layout_stride`, though, `required_span_size()` can be greater than the `size()`. For other layouts (e.g., the "matrix of a single value" layout that maps all multidimensional indices to the offset zero), `required_span_size()` can be less than `size()`.
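The difference between the two quantities for a strided layout can be checked with a small host-side calculation. The helper names below are hypothetical. The strided formula used is one plus the sum over dimensions of (extent - 1) times stride, or zero if any extent is zero, which is how the mdspan specification defines the required span size for `layout_stride`.

```cpp
#include <array>
#include <cstddef>

// Number of elements: the product of the extents (what size() reports).
template <std::size_t Rank>
constexpr std::size_t element_count(const std::array<std::size_t, Rank>& extents)
{
  std::size_t n = 1;
  for (std::size_t r = 0; r < Rank; ++r) { n *= extents[r]; }
  return n;
}

// One past the largest offset a strided mapping can produce:
// 1 + sum_r (extents[r] - 1) * strides[r], or 0 if any extent is 0.
template <std::size_t Rank>
constexpr std::size_t strided_required_span_size(
  const std::array<std::size_t, Rank>& extents,
  const std::array<std::size_t, Rank>& strides)
{
  std::size_t span = 1;
  for (std::size_t r = 0; r < Rank; ++r) {
    if (extents[r] == 0) { return 0; }
    span += (extents[r] - 1) * strides[r];
  }
  return span;
}
```

For a 3 x 4 matrix with compact row-major strides {4, 1} both functions give 12. Padding each row to 8 elements (strides {8, 1}) keeps the element count at 12 but raises the required span size to 20.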
Note that while it's UB for users to violate preconditions, implementations aren't required to check preconditions. The reference implementation of `layout_left` does not currently check preconditions, as you can see here, for instance. This means two things.
1. If someone gives you a `layout_{left,right,stride}::mapping` instance (e.g., in an `mdspan`), then you can assume that the precondition is satisfied.
2. If you are constructing a `layout_{left,right,stride}::mapping` instance (e.g., by constructing an `mdspan` with a pointer and extents), then you are responsible for ensuring that the precondition is satisfied.
@divyegala wrote:
> In this case, will `required_span_size()` be representable by `index_type` or will it cause an overflow?

Those are two separate questions, actually! :-)
- `required_span_size()` is not representable by `index_type` in this case.
- Giving this `extents` object to `layout_{left,right,stride}::mapping`'s constructor violates the constructor's precondition. It could overflow, or it could open a portal to the Awesome Dimension and let loose a swarm of nasal demons who search out precondition violators and boop them gently on the nose.
@mhoemmen thanks for the explanations! How do we really represent such edge cases and safely obtain the product of the extents? Sounds like `size()` is the safe way to obtain the product without violating any preconditions, since it's representable by `size_t`?
@divyegala Gladly! :-)
> How do we really represent such edge cases and safely obtain the product of the extents?

By the time the user has created a layout mapping, it's already too late. What I mean by that is that if `required_span_size()` doesn't fit `index_type`, then the user will likely get the wrong answer when they try to index into the mdspan.
In what follows in my comment, I'll distinguish between "the Preconditions in the spec" and "what the reference implementation does." The reference implementation currently does not check this precondition in the layout mapping. This means that it's possible for users to construct extents for which the mapping's `required_span_size()` can overflow.
We can prevent this by wrapping mdspan creation to check the extents object for potential overflow, before it goes into a layout mapping's constructor. It's not UB to construct, e.g., `dextents<uint16_t, 2>(`$2^{15}$, $2^{15}$`)`. We just need to intercept that naughty `extents` value before it goes into a layout mapping's constructor. Otherwise, the layout mapping has the freedom to do whatever it likes, including calling `abort()`.
Our mdarray implementation's conversion to mdspan can also check, but again, we're probably better off making the wrapper explicit and not part of the mdarray proposal. WG21 likes Preconditions and wants violating them to be UB. If we want some specified behavior (e.g., throwing a particular exception, or calling `terminate()` after printing a helpful error message), then we'll have to implement that ourselves.
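The check such a wrapper would perform can be sketched as an overflow-safe product test. The function name `extents_product_fits` is hypothetical, and the extents are passed as a plain `std::array` instead of an `extents` object so that the sketch does not depend on an mdspan implementation. It covers the `layout_left` and `layout_right` case, where the required span size equals the product of the extents.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <limits>

// Hypothetical helper: returns true when every extent and the product of all
// extents are representable as IndexType, i.e. when it is safe to hand these
// extents to a layout_left or layout_right mapping constructor.
template <class IndexType, std::size_t Rank>
constexpr bool extents_product_fits(const std::array<std::uintmax_t, Rank>& extents)
{
  const auto max_index = static_cast<std::uintmax_t>(std::numeric_limits<IndexType>::max());

  // Each individual extent must itself be representable.
  for (std::size_t r = 0; r < Rank; ++r) {
    if (extents[r] > max_index) { return false; }
  }

  std::uintmax_t product = 1;
  for (std::size_t r = 0; r < Rank; ++r) {
    if (extents[r] == 0) { return true; }  // empty: the product is zero
    // Divide instead of multiplying so the test itself cannot overflow.
    if (product > max_index / extents[r]) { return false; }
    product *= extents[r];
  }
  return true;
}
```

A wrapper could call this before constructing the mdspan and throw or terminate with a helpful message when it returns false. For the example above, extents of 2^15 by 2^15 are each representable as `uint16_t`, but their product is not.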
Looks like `out.required_span_size()` does not work. How do I access this from the layout?
Things are coming together great! I just did another deeper review. I'll try and do another before EOD.
 * @param[in] main_op fused elementwise operation to apply before reduction
 * @param[in] reduce_op fused binary reduction operation
 * @param[in] final_op fused elementwise operation to apply before storing results
 * @param[in] inplace reduction result added inplace or overwrites old values?
It's unclear to me from reading this doc 1) whether or not the `dots` output is needed when `inplace==true` and 2) if `data == dots` when `inplace==true`. I felt the same while looking over the API for `reduce`. This might just require some additional documentation in these functions, though I wonder if it's an indication that an update to the API might be needed (I'm just really not sure what that would be as of yet).
It's unclear to me as well, and I mentioned it in reply to another of @mhoemmen's comments. Reading the impl did not make it any clearer for me.
Maybe we should do a separate overload for the inplace version like we do in `raft::matrix`. That makes it much more straightforward to the user that it's in place (and which array(s) end up being overwritten).
Looks like this is coming across the final bend. Really minor things at this point.
raft::device_matrix_view<const InValueType, IndexType, LayoutPolicy> data,
raft::device_vector_view<OutValueType, IndexType> dots,
OutValueType init,
bool inplace = false,
One thing I realized as I was going through the `raft::matrix` APIs is that they have an overload for the inplace version, and I think that looks cleaner because the arguments kind of speak for themselves (e.g., no output argument and the docs say `inout`).
Should I do the separate overload in this PR or wait for the move to `raft::matrix`?
I also really don't think `inplace` means what we think it means here. Looking at the kernel, it seems like all `inplace=true` does is run the `reduction_op` one more time over the output pointer.
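One way to make the two behaviors explicit is to split them into separately named functions. The host-side sketch below uses hypothetical names and signatures and is not the RAFT API. The accumulate semantics (combining the fresh result with the existing output through the reduction operation) are an assumption based on the reading of the kernel described above, not something confirmed here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side sketch (not the RAFT API).

// Out-of-place: `dots` is purely an output; old contents are discarded.
template <class T, class ReduceOp>
void reduce_rows(const std::vector<T>& data,
                 std::size_t n_cols,
                 std::vector<T>& dots,
                 T init,
                 ReduceOp reduce_op)
{
  assert(n_cols > 0 && data.size() % n_cols == 0);
  const std::size_t n_rows = data.size() / n_cols;
  dots.assign(n_rows, init);
  for (std::size_t r = 0; r < n_rows; ++r) {
    for (std::size_t c = 0; c < n_cols; ++c) {
      dots[r] = reduce_op(dots[r], data[r * n_cols + c]);
    }
  }
}

// Accumulating: `dots_inout` is both read and written. The fresh per-row
// result is combined with the existing value via one more reduce_op call
// (assumed semantics of inplace=true).
template <class T, class ReduceOp>
void reduce_rows_accumulate(const std::vector<T>& data,
                            std::size_t n_cols,
                            std::vector<T>& dots_inout,
                            T init,
                            ReduceOp reduce_op)
{
  std::vector<T> fresh;
  reduce_rows(data, n_cols, fresh, init, reduce_op);
  assert(dots_inout.size() == fresh.size());
  for (std::size_t r = 0; r < fresh.size(); ++r) {
    dots_inout[r] = reduce_op(dots_inout[r], fresh[r]);
  }
}
```

With distinct names and an `inout` parameter, the documentation no longer has to explain a boolean flag, and it is clear which array is overwritten in each case.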
@gpucibot merge