Expose linalg::dot in public API #968
Thanks again for this PR! We are going to want to think about whether these functions (axpy, dot, etc.) should accept a general mdspan or whether we should constrain them to vectors up front.
It would also be nice to see the current vector factory functions made more flexible to enable strided layouts rather than adding new functions.
These types of examples (using the existing device vector factory functions to create a strided vector) would be great to have in the quick start as well.
 */
template <typename ElementType, typename IndexType = int, typename LayoutPolicy = layout_stride>
auto make_strided_device_vector_view(ElementType* ptr, IndexType n, IndexType stride)
Rather than adding another factory function for a strided vector, why not just allow a strided layout to be configured in the make_device_vector_view and make_host_vector_view?
Right now the make_*_vector_view automatically configures a row-major layout but the layout should really be configurable (and potentially strided, or col major if desired).
I've updated make_device_vector_view to allow strided input here - let me know what you think.
cpp/include/raft/linalg/dot.cuh
Outdated
template <typename InputType1,
          typename InputType2,
          typename OutputType,
          typename = raft::enable_if_input_device_mdspan<InputType1>,
I brought this up with the axpy as well, but it seems weird to accept a general mdspan for this when what we are really looking for is a 1d vector. Do you see value in accepting a matrix or dense tensor with 3+ dimensional extents? If not, we should just accept the vector_view directly (which is aliased to be any mdspan with 1d extents).
If we accepted a device_vector_view directly, we wouldn't need the enable_if statements at all. I think we should go ahead and do the same for the axpy to keep things consistent.
Agreed - made the changes here so that both axpy and dot take device_vector_views.
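To make the trade-off above concrete, here is a minimal host-only sketch. It is not RAFT code: `toy_vector_view` and `toy_dot` are hypothetical stand-ins for `device_vector_view` and `dot`. The only point it illustrates is that once the signature names a 1-D view type, no `enable_if` constraint is needed to reject matrices or higher-rank tensors.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for a 1-D view: pointer, length, element stride.
template <typename T>
struct toy_vector_view {
  const T* data;
  std::size_t size;
  std::size_t stride;  // distance between consecutive elements, in elements

  const T& operator()(std::size_t i) const { return data[i * stride]; }
};

// The parameter type is already a 1-D view, so no SFINAE is required:
// passing any other type simply fails template argument deduction.
template <typename T>
T toy_dot(toy_vector_view<T> x, toy_vector_view<T> y)
{
  assert(x.size == y.size);  // both vectors must have the same length
  T acc{};
  for (std::size_t i = 0; i < x.size; ++i) { acc += x(i) * y(i); }
  return acc;
}
```

Because both parameters share the same `T`, mismatched element types are also rejected at the call site without any extra trait checks.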
cpp/include/raft/linalg/dot.cuh
Outdated
// Right now the inputs and outputs need to all have the same value_type (float/double etc).
// Try to output a meaningful compiler error if mismatched types are passed here.
// Note: In the future we could remove this restriction using the cublasDotEx function
Should we just go ahead and wrap the cublasEx functions?
I created an issue so we can discuss this further: #977.
Reading the docs a little closer, it looks like even with cublasDotEx, different dtypes for the inputs and outputs aren't currently supported: https://docs.nvidia.com/cuda/cublas/index.html#cublas-dotEx - so it won't have much value for the dot API (though I could see a use for it myself with the gemm API w/ implicit and the mixed precision work I was talking about last week).
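For readers wondering what the "meaningful compiler error" in the code comment could look like, one common approach is a `static_assert` on the element types. The sketch below is host-only and uses hypothetical names (`toy_span`, `checked_dot`); it is an assumption about the general technique, not the code in this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Hypothetical minimal view type exposing a value_type alias.
template <typename T>
struct toy_span {
  using value_type = T;
  const T* data;
  std::size_t size;
};

template <typename X, typename Y, typename Out>
void checked_dot(X x, Y y, Out* out)
{
  // Fail early with a readable message instead of a deep template error.
  static_assert(std::is_same<typename X::value_type, typename Y::value_type>::value,
                "dot: x and y must have the same value_type");
  static_assert(std::is_same<typename X::value_type, Out>::value,
                "dot: output must have the same value_type as the inputs");
  assert(x.size == y.size);  // runtime check: lengths must match
  Out acc{};
  for (std::size_t i = 0; i < x.size; ++i) { acc += x.data[i] * y.data[i]; }
  *out = acc;
}
```

Calling `checked_dot` with a `toy_span<float>` and a `toy_span<double>` stops compilation at the first assertion with the quoted message.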
Changes are looking great! Remaining things are very minor.
 * @return raft::device_vector_view
 */
template <typename ElementType,
          typename IndexType    = std::uint32_t,
          typename LayoutPolicy = layout_c_contiguous>
-auto make_device_vector_view(ElementType* ptr, IndexType n)
+auto make_device_vector_view(ElementType* ptr, IndexType n, IndexType stride = 1)
This is a little awkward. We accept a layout policy as a template argument, but then we also accept a function argument for a stride, which essentially overrides the layout from the template.
Would it achieve the same goal if a user were to just set a strided layout on the template argument directly? Perhaps we could provide a factory function to make said strided layout and provide the user with something like a statically sized object (e.g. std::array) to set the strides for each dimension?
And of course, this is one of those things (the new strided factory function) that I think should have a usage example in the doxygen and perhaps even a subsection in the mdspan tutorial markdown of the docs.
If I'm understanding you correctly - you're thinking we can just pass the layout mapping to the make_device_vector_view function directly, and add a new factory function for creating this layout mapping?
I took a stab at that in the last commit. Unfortunately, I couldn't get a single make_device_vector_view function to compile when it was passed both an IndexType with the number of elements and the Mapping with the strided layout (I was getting compile errors in various other raft functions that I hadn't updated). However, I could get it to work by adding an overload, which is what's in the last commit. Do you have any suggestions on how to clean this up =) ?
I'll add something to the tutorial / docs once we're happy with the API.
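The overload approach described above can be sketched in plain host C++. Everything here is hypothetical (`toy_strided_mapping`, `make_toy_strided_layout`, `toy_view`, `make_toy_vector_view`) and only mirrors the shape of the idea: one overload takes an element count and assumes a contiguous layout, the other takes a mapping built by a separate factory function.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical 1-D layout mapping: number of elements and element stride.
struct toy_strided_mapping {
  std::size_t extent;
  std::size_t stride;
};

// Hypothetical factory for the mapping, kept separate from the view factory.
inline toy_strided_mapping make_toy_strided_layout(std::size_t n, std::size_t stride)
{
  return {n, stride};
}

template <typename T>
struct toy_view {
  T* data;
  toy_strided_mapping mapping;

  std::size_t size() const { return mapping.extent; }
  T& operator()(std::size_t i) const
  {
    assert(i < mapping.extent);  // bounds check for the sketch
    return data[i * mapping.stride];
  }
};

// Overload 1: element count only, contiguous layout (stride 1).
template <typename T>
toy_view<T> make_toy_vector_view(T* ptr, std::size_t n)
{
  return {ptr, {n, 1}};
}

// Overload 2: explicit mapping, e.g. one column of a row-major matrix.
template <typename T>
toy_view<T> make_toy_vector_view(T* ptr, toy_strided_mapping mapping)
{
  return {ptr, mapping};
}
```

With a 2x3 row-major buffer `m`, `make_toy_vector_view(m + 1, make_toy_strided_layout(2, 3))` views the second column, while `make_toy_vector_view(m, std::size_t{3})` views the first row.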
* Remove default types
* Try to fix up factory functions for creating strided vector views
* Add dot function that takes host scalar / host_scalar_view
cpp/include/raft/linalg/dot.cuh
Outdated
void dot(const raft::handle_t& handle,
         raft::device_vector_view<const ElementType, IndexType, LayoutPolicy1> x,
         raft::device_vector_view<const ElementType, IndexType, LayoutPolicy2> y,
         ElementType* out)
I think for the host output, we probably should drop this overload. Sorry for being confusing here. I think it makes more sense to accept a host scalar by value for functions like axpy where the scalar is an input. For output on host, I think we should stick to the mdspan scalar wrappers.
removed in latest commit
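The convention agreed on in this thread (input scalars by value, output scalars through a scalar view wrapper rather than a raw pointer) can be illustrated with a small host-only sketch. The names `toy_scalar_view`, `toy_axpy`, and `toy_dot_out` are hypothetical and do not correspond to RAFT's actual signatures.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical wrapper around a single output value.
template <typename T>
struct toy_scalar_view {
  T* ptr;
  T& operator()() const { return *ptr; }
};

// Input scalar: taken by value, since it is only read.
template <typename T>
void toy_axpy(T alpha, const T* x, T* y, std::size_t n)
{
  for (std::size_t i = 0; i < n; ++i) { y[i] += alpha * x[i]; }
}

// Output scalar: written through the view wrapper, not a raw pointer.
template <typename T>
void toy_dot_out(const T* x, const T* y, std::size_t n, toy_scalar_view<T> out)
{
  assert(out.ptr != nullptr);  // the view must refer to valid storage
  T acc{};
  for (std::size_t i = 0; i < n; ++i) { acc += x[i] * y[i]; }
  out() = acc;
}
```

Keeping the output behind a view type makes it explicit where the result lives, which is the motivation given above for dropping the raw-pointer overload.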
Looks great, thanks again @benfred!
@gpucibot merge
Closes #805