Skip to content

Commit 6981a08

Browse files
authored
Add ability to allocate with RMM to the c-api and rust api (#56)
Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Divye Gala (https://github.com/divyegala) - Corey J. Nolet (https://github.com/cjnolet) - Ray Douglass (https://github.com/raydouglass) URL: #56
1 parent e5d5e3a commit 6981a08

File tree

9 files changed

+126
-28
lines changed

9 files changed

+126
-28
lines changed

ci/build_rust.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,4 @@ rapids-mamba-retry install \
3535
libcuvs \
3636
libraft
3737

38-
export EXTRA_CMAKE_ARGS=""
3938
bash ./build.sh rust

cpp/include/cuvs/core/c_api.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,50 @@ cuvsError_t cuvsResourcesDestroy(cuvsResources_t res);
8383
*/
8484
cuvsError_t cuvsStreamSet(cuvsResources_t res, cudaStream_t stream);
8585

86+
/**
87+
* @brief Get the cudaStream_t from a cuvsResources_t t
88+
*
89+
* @param[in] res cuvsResources_t opaque C handle
90+
* @param[out] stream cudaStream_t stream to queue CUDA kernels
91+
* @return cuvsError_t
92+
*/
93+
cuvsError_t cuvsStreamGet(cuvsResources_t res, cudaStream_t* stream);
94+
95+
/**
96+
* @brief Syncs the current CUDA stream on the resources object
97+
*
98+
* @param[in] res cuvsResources_t opaque C handle
99+
* @return cuvsError_t
100+
*/
101+
cuvsError_t cuvsStreamSync(cuvsResources_t res);
102+
/** @} */
103+
104+
/**
105+
* @defgroup memory_c cuVS Memory Allocation
106+
* @{
107+
*/
108+
109+
/**
110+
* @brief Allocates device memory using RMM
111+
*
112+
*
113+
* @param[in] res cuvsResources_t opaque C handle
114+
* @param[out] ptr Pointer to allocated device memory
115+
* @param[in] bytes Size in bytes to allocate
116+
* @return cuvsError_t
117+
*/
118+
cuvsError_t cuvsRMMAlloc(cuvsResources_t res, void** ptr, size_t bytes);
119+
120+
/**
121+
* @brief Deallocates device memory using RMM
122+
*
123+
* @param[in] res cuvsResources_t opaque C handle
124+
* @param[in] ptr Pointer to allocated device memory to free
125+
* @param[in] bytes Size in bytes to allocate
126+
* @return cuvsError_t
127+
*/
128+
cuvsError_t cuvsRMMFree(cuvsResources_t res, void* ptr, size_t bytes);
129+
86130
/** @} */
87131

88132
#ifdef __cplusplus

cpp/src/core/c_api.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <raft/core/resource/cuda_stream.hpp>
2222
#include <raft/core/resources.hpp>
2323
#include <rmm/cuda_stream_view.hpp>
24+
#include <rmm/mr/device/per_device_resource.hpp>
2425
#include <thread>
2526

2627
extern "C" cuvsError_t cuvsResourcesCreate(cuvsResources_t* res)
@@ -47,6 +48,40 @@ extern "C" cuvsError_t cuvsStreamSet(cuvsResources_t res, cudaStream_t stream)
4748
});
4849
}
4950

51+
extern "C" cuvsError_t cuvsStreamGet(cuvsResources_t res, cudaStream_t* stream)
52+
{
53+
return cuvs::core::translate_exceptions([=] {
54+
auto res_ptr = reinterpret_cast<raft::resources*>(res);
55+
*stream = raft::resource::get_cuda_stream(*res_ptr);
56+
});
57+
}
58+
59+
extern "C" cuvsError_t cuvsStreamSync(cuvsResources_t res)
60+
{
61+
return cuvs::core::translate_exceptions([=] {
62+
auto res_ptr = reinterpret_cast<raft::resources*>(res);
63+
raft::resource::sync_stream(*res_ptr);
64+
});
65+
}
66+
67+
extern "C" cuvsError_t cuvsRMMAlloc(cuvsResources_t res, void** ptr, size_t bytes)
68+
{
69+
return cuvs::core::translate_exceptions([=] {
70+
auto res_ptr = reinterpret_cast<raft::resources*>(res);
71+
auto mr = rmm::mr::get_current_device_resource();
72+
*ptr = mr->allocate(bytes, raft::resource::get_cuda_stream(*res_ptr));
73+
});
74+
}
75+
76+
extern "C" cuvsError_t cuvsRMMFree(cuvsResources_t res, void* ptr, size_t bytes)
77+
{
78+
return cuvs::core::translate_exceptions([=] {
79+
auto res_ptr = reinterpret_cast<raft::resources*>(res);
80+
auto mr = rmm::mr::get_current_device_resource();
81+
mr->deallocate(ptr, bytes, raft::resource::get_cuda_stream(*res_ptr));
82+
});
83+
}
84+
5085
thread_local std::string last_error_text = "";
5186

5287
extern "C" const char* cuvsGetLastErrorText()

rust/cuvs-sys/build.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,8 @@ fn main() {
101101
.allowlist_type("(cuvs|cagra|DL).*")
102102
.allowlist_function("(cuvs|cagra).*")
103103
.rustified_enum("(cuvs|cagra|DL).*")
104-
// also need some basic cuda mem functions
105-
// (TODO: should we be adding in RMM support instead here?)
106-
.allowlist_function("(cudaMalloc|cudaFree|cudaMemcpy)")
104+
// also need some basic cuda mem functions for copying data
105+
.allowlist_function("(cudaMemcpyAsync|cudaMemcpy)")
107106
.rustified_enum("cudaError")
108107
.generate()
109108
.expect("Unable to generate cagra_c bindings")

rust/cuvs/src/cagra/index.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ impl Index {
5454
/// Creates a new empty index
5555
pub fn new() -> Result<Index> {
5656
unsafe {
57-
let mut index = core::mem::MaybeUninit::<ffi::cuvsCagraIndex_t>::uninit();
57+
let mut index = std::mem::MaybeUninit::<ffi::cuvsCagraIndex_t>::uninit();
5858
check_cuvs(ffi::cuvsCagraIndexCreate(index.as_mut_ptr()))?;
5959
Ok(Index(index.assume_init()))
6060
}

rust/cuvs/src/cagra/index_params.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ impl IndexParams {
2727
/// Returns a new IndexParams
2828
pub fn new() -> Result<IndexParams> {
2929
unsafe {
30-
let mut params = core::mem::MaybeUninit::<ffi::cuvsCagraIndexParams_t>::uninit();
30+
let mut params = std::mem::MaybeUninit::<ffi::cuvsCagraIndexParams_t>::uninit();
3131
check_cuvs(ffi::cuvsCagraIndexParamsCreate(params.as_mut_ptr()))?;
3232
Ok(IndexParams(params.assume_init()))
3333
}

rust/cuvs/src/cagra/search_params.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ impl SearchParams {
2828
/// Returns a new SearchParams object
2929
pub fn new() -> Result<SearchParams> {
3030
unsafe {
31-
let mut params = core::mem::MaybeUninit::<ffi::cuvsCagraSearchParams_t>::uninit();
31+
let mut params = std::mem::MaybeUninit::<ffi::cuvsCagraSearchParams_t>::uninit();
3232
check_cuvs(ffi::cuvsCagraSearchParamsCreate(params.as_mut_ptr()))?;
3333
Ok(SearchParams(params.assume_init()))
3434
}

rust/cuvs/src/dlpack.rs

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
use std::convert::From;
1818

19-
use crate::error::{check_cuda, Result};
19+
use crate::error::{check_cuda, check_cuvs, Result};
2020
use crate::resources::Resources;
2121

2222
/// ManagedTensor is a wrapper around a dlpack DLManagedTensor object.
@@ -33,36 +33,27 @@ impl ManagedTensor {
3333
&self.0 as *const _ as *mut _
3434
}
3535

36-
fn bytes(&self) -> usize {
37-
// figure out how many bytes to allocate
38-
let mut bytes: usize = 1;
39-
for x in 0..self.0.dl_tensor.ndim {
40-
bytes *= unsafe { (*self.0.dl_tensor.shape.add(x as usize)) as usize };
41-
}
42-
bytes *= (self.0.dl_tensor.dtype.bits / 8) as usize;
43-
bytes
44-
}
45-
4636
/// Creates a new ManagedTensor on the current GPU device, and copies
4737
/// the data into it.
48-
pub fn to_device(&self, _res: &Resources) -> Result<ManagedTensor> {
38+
pub fn to_device(&self, res: &Resources) -> Result<ManagedTensor> {
4939
unsafe {
50-
let bytes = self.bytes();
40+
let bytes = dl_tensor_bytes(&self.0.dl_tensor);
5141
let mut device_data: *mut std::ffi::c_void = std::ptr::null_mut();
5242

5343
// allocate storage, copy over
54-
check_cuda(ffi::cudaMalloc(&mut device_data as *mut _, bytes))?;
55-
check_cuda(ffi::cudaMemcpy(
44+
check_cuvs(ffi::cuvsRMMAlloc(res.0, &mut device_data as *mut _, bytes))?;
45+
46+
check_cuda(ffi::cudaMemcpyAsync(
5647
device_data,
5748
self.0.dl_tensor.data,
5849
bytes,
5950
ffi::cudaMemcpyKind_cudaMemcpyDefault,
51+
res.get_cuda_stream()?,
6052
))?;
6153

6254
let mut ret = self.0.clone();
6355
ret.dl_tensor.data = device_data;
64-
// call cudaFree automatically to clean up data
65-
ret.deleter = Some(cuda_free_tensor);
56+
ret.deleter = Some(rmm_free_tensor);
6657
ret.dl_tensor.device.device_type = ffi::DLDeviceType::kDLCUDA;
6758

6859
Ok(ManagedTensor(ret))
@@ -80,21 +71,32 @@ impl ManagedTensor {
8071
arr: &mut ndarray::ArrayBase<S, D>,
8172
) -> Result<()> {
8273
unsafe {
83-
let bytes = self.bytes();
74+
let bytes = dl_tensor_bytes(&self.0.dl_tensor);
8475
check_cuda(ffi::cudaMemcpy(
8576
arr.as_mut_ptr() as *mut std::ffi::c_void,
8677
self.0.dl_tensor.data,
8778
bytes,
8879
ffi::cudaMemcpyKind_cudaMemcpyDefault,
8980
))?;
90-
9181
Ok(())
9282
}
9383
}
9484
}
9585

96-
unsafe extern "C" fn cuda_free_tensor(self_: *mut ffi::DLManagedTensor) {
97-
let _ = ffi::cudaFree((*self_).dl_tensor.data);
86+
/// Figures out how many bytes are in a DLTensor
87+
fn dl_tensor_bytes(tensor: &ffi::DLTensor) -> usize {
88+
let mut bytes: usize = 1;
89+
for dim in 0..tensor.ndim {
90+
bytes *= unsafe { (*tensor.shape.add(dim as usize)) as usize };
91+
}
92+
bytes *= (tensor.dtype.bits / 8) as usize;
93+
bytes
94+
}
95+
96+
unsafe extern "C" fn rmm_free_tensor(self_: *mut ffi::DLManagedTensor) {
97+
let bytes = dl_tensor_bytes(&(*self_).dl_tensor);
98+
let res = Resources::new().unwrap();
99+
let _ = ffi::cuvsRMMFree(res.0, (*self_).dl_tensor.data as *mut _, bytes);
98100
}
99101

100102
/// Create a non-owning view of a Tensor from a ndarray

rust/cuvs/src/resources.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,25 @@ impl Resources {
3232
}
3333
Ok(Resources(res))
3434
}
35+
36+
/// Sets the current cuda stream
37+
pub fn set_cuda_stream(&self, stream: ffi::cudaStream_t) -> Result<()> {
38+
unsafe { check_cuvs(ffi::cuvsStreamSet(self.0, stream)) }
39+
}
40+
41+
/// Gets the current cuda stream
42+
pub fn get_cuda_stream(&self) -> Result<ffi::cudaStream_t> {
43+
unsafe {
44+
let mut stream = std::mem::MaybeUninit::<ffi::cudaStream_t>::uninit();
45+
check_cuvs(ffi::cuvsStreamGet(self.0, stream.as_mut_ptr()))?;
46+
Ok(stream.assume_init())
47+
}
48+
}
49+
50+
/// Syncs the current cuda stream
51+
pub fn sync_stream(&self) -> Result<()> {
52+
unsafe { check_cuvs(ffi::cuvsStreamSync(self.0)) }
53+
}
3554
}
3655

3756
impl Drop for Resources {

0 commit comments

Comments
 (0)