[FEA] Implement matrix transpose with mdspan. #739
trivialfis
commented
Jul 12, 2022
- Implement a transpose function that works on both column-major and row-major matrices.
- Sub-matrices are supported as well.
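To make the two bullet points above concrete, here is a minimal host-only sketch of a layout-agnostic, out-of-place transpose. It is not the implementation in this PR: `strided_view`, `make_row_major`, `make_col_major`, and this `transpose` are hypothetical names that stand in for an mdspan with a strided layout, and the loop runs on the CPU. A leading dimension larger than the logical extent is what lets the same code address a sub-matrix.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical minimal 2-D strided view standing in for an mdspan.
// This is an illustration only, not the RAFT or std::mdspan API.
template <typename T>
struct strided_view {
  T* data;
  std::size_t rows, cols;
  std::size_t row_stride, col_stride;  // strides in elements, not bytes

  T& operator()(std::size_t i, std::size_t j) const
  {
    return data[i * row_stride + j * col_stride];
  }
};

// Row-major view: elements of one row are contiguous.
// ld is the leading dimension; ld > cols selects a sub-matrix of a wider parent.
template <typename T>
strided_view<T> make_row_major(T* p, std::size_t rows, std::size_t cols, std::size_t ld)
{
  assert(ld >= cols);
  return {p, rows, cols, ld, 1};
}

// Column-major view: elements of one column are contiguous.
template <typename T>
strided_view<T> make_col_major(T* p, std::size_t rows, std::size_t cols, std::size_t ld)
{
  assert(ld >= rows);
  return {p, rows, cols, 1, ld};
}

// Out-of-place transpose: out(j, i) = in(i, j), independent of either layout.
template <typename T, typename U>
void transpose(strided_view<T> in, strided_view<U> out)
{
  assert(out.rows == in.cols && out.cols == in.rows);
  for (std::size_t i = 0; i < in.rows; ++i) {
    for (std::size_t j = 0; j < in.cols; ++j) {
      out(j, i) = in(i, j);
    }
  }
}
```

Because indexing goes through the strides, the input and output may use different layouts, and a view over the top-left block of a larger buffer needs no copy before it is transposed.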
Looks perfect to me. I just have a question for @cjnolet, as we had discussed accepting the standard mdspan without the host/device wrappers in our APIs. Is that still the case?
That's correct @divyegala. In preparation for our being able to support dispatch to multiple execution environments (e.g. CPU/GPU) in the future, we are coding our public APIs directly to
@trivialfis with Corey's confirmation on this, can we use the vanilla mdspan API? Also we should leave out the overload that returns our implementation of
Thank you for the review, I hid the functions into tests and changed the input parameters to vanilla mdspan.
@@ -79,6 +80,55 @@ void transpose(math_t* inout, int n, cudaStream_t stream)
  });
}

template <typename T, typename LayoutPolicy>
void transpose_row_major_impl(handle_t const& handle,
                              device_matrix_view<T, LayoutPolicy> in,
Should these also be vanilla mdspan? I am surprised we are able to construct a device_matrix_view with mdspan.
I'm not sure if this should be a normal span. It's a safeguard that prevents users from passing the wrong pointer.
It won't be a breaking change when we support host mdspan, since we will just be relaxing the constraint.
I think we might be okay on the impl side. Because this isn't directly exposed to users, we have the freedom to change this as needed.
device_matrix_view doesn't do any check on the pointer's underlying memory though, right? We can construct it even with a host pointer/mdspan.
I think that's why we should just accept vanilla mdspan all the way through
It's a compile-time guard. Given a host mdspan, it won't compile.
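As an illustration of what a compile-time guard of this kind can look like, here is a minimal sketch that encodes the memory space in the view's type. All names (`host_tag`, `device_tag`, `tagged_view`, `device_only_size`, `accepts_device_only`) are hypothetical and are not RAFT's actual types; the sketch only shows that a function declared for device views is rejected at compile time when handed a host view, without any runtime check on the pointer.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>

// Hypothetical memory-space tags. The names are illustrative, not RAFT's.
struct host_tag {};
struct device_tag {};

// A view whose type records where its memory is supposed to live.
// Nothing here inspects the pointer: the guard is purely in the type system.
template <typename T, typename MemorySpace>
struct tagged_view {
  using memory_space = MemorySpace;
  T* data;
  std::size_t rows, cols;
};

template <typename T>
using host_view = tagged_view<T, host_tag>;
template <typename T>
using device_view = tagged_view<T, device_tag>;

// Accepts only device views. Passing a host_view<T> fails template
// deduction, so the call does not compile.
template <typename T>
std::size_t device_only_size(device_view<T> v)
{
  return v.rows * v.cols;
}

// Detection trait: true when device_only_size(View) is a valid call.
template <typename View, typename = void>
struct accepts_device_only : std::false_type {};

template <typename View>
struct accepts_device_only<View,
                           std::void_t<decltype(device_only_size(std::declval<View>()))>>
  : std::true_type {};

static_assert(accepts_device_only<device_view<float>>::value,
              "device views are accepted");
static_assert(!accepts_device_only<host_view<float>>::value,
              "host views are rejected at compile time");
```

The trade-off raised in this thread still applies to the sketch: the tag is only as trustworthy as the code that constructed the view, since a host pointer can still be wrapped in a `device_view` by mistake.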
@trivialfis can we accept vanilla mdspan on our public API, and either use pointers or vanilla mdspan all the way after that? In the future, we will discuss the design of an "execution environment" that will allow us to handle memory guards better (compile time/run time). Until then, it will be easier to use the standard and then refactor if needed
Got it, will make the changes. Could you please loop me into the discussion of "execution environment"? I think that might benefit XGBoost development.
Changed to using vanilla mdspan.
rerun tests
@gpucibot merge