Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Rust bindings for CAGRA #34

Merged
merged 41 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
a4ec121
Rust bindings for cuvs
benfred Feb 6, 2024
cae7cd3
share common workspace metadata
benfred Feb 6, 2024
160b363
support for building cagra index
benfred Feb 7, 2024
4e277e5
working search unittest
benfred Feb 12, 2024
0027115
cudaFree
benfred Feb 12, 2024
e11e4e0
functioning unittest
benfred Feb 13, 2024
15aceda
add cagra example program
benfred Feb 20, 2024
2e4b477
Add resources to to_host/to_device functions
benfred Feb 20, 2024
83e202e
add From trait for converting ndarray to ManagedTensor
benfred Feb 21, 2024
1990d02
basic CI
benfred Feb 21, 2024
5c7aafa
.
benfred Feb 21, 2024
b44b81c
.
benfred Feb 26, 2024
13b1848
.
benfred Feb 26, 2024
bd8c432
.
benfred Feb 26, 2024
d411803
.
benfred Feb 26, 2024
a2e8009
remove export
benfred Feb 26, 2024
b4dc043
.
benfred Feb 26, 2024
af4b0ba
.
benfred Feb 26, 2024
e7a47de
.
benfred Feb 26, 2024
a757060
one more time
benfred Feb 27, 2024
6c54aa4
don't download python artifacts
benfred Feb 27, 2024
79560c3
one more time
benfred Feb 27, 2024
023795d
.
benfred Feb 27, 2024
d3ceacf
add libclang
benfred Feb 27, 2024
a60edc8
add clang
benfred Feb 27, 2024
db00e6f
set LIBCLANG_PATH
benfred Feb 27, 2024
82ccb25
correct libclang_path
benfred Feb 27, 2024
1ffdfd7
check in cuvs_bindings.rs
benfred Feb 27, 2024
c091196
Revert "check in cuvs_bindings.rs"
benfred Feb 28, 2024
0005294
try harder with libclang
benfred Feb 28, 2024
fc413fb
.
benfred Feb 28, 2024
8397651
.
benfred Feb 28, 2024
4c26cd7
:facepalm:
benfred Feb 29, 2024
eddab35
Merge branch 'branch-24.04' into rust_bindings
benfred Feb 29, 2024
0433e07
.
benfred Mar 4, 2024
b3a9624
Merge branch 'branch-24.04' into rust_bindings
benfred Mar 4, 2024
3eec2c3
update to latest 24.04 cagra api
benfred Mar 4, 2024
6cd44b1
Merge branch 'rust_bindings' of https://github.com/benfred/cuvs into …
benfred Mar 4, 2024
2dd4ea0
add example to README
benfred Mar 4, 2024
a781042
.
benfred Mar 4, 2024
e836117
update dependencies.yaml
benfred Mar 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
working search unittest
  • Loading branch information
benfred committed Feb 27, 2024
commit 4e277e5595a68539b9cf80ef3bcdb56f9c4e4482
5 changes: 5 additions & 0 deletions rust/cuvs-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ fn main() {
cuvs_build.display()
);
println!("cargo:rustc-link-lib=dylib=cuvs_c");
println!("cargo:rustc-link-lib=dylib=cudart");

// we need some extra flags both to link against cuvs, and also to run bindgen
// specifically we need to:
Expand Down Expand Up @@ -100,6 +101,10 @@ fn main() {
.allowlist_type("(cuvs|cagra|DL).*")
.allowlist_function("(cuvs|cagra).*")
.rustified_enum("(cuvs|cagra|DL).*")
// also need some basic cuda mem functions
// (TODO: should we be adding in RMM support instead here?)
.allowlist_function("(cudaMalloc|cudaFree|cudaMemcpy)")
.rustified_enum("cudaError")
.generate()
.expect("Unable to generate cagra_c bindings")
.write_to_file(out_path.join("cuvs_bindings.rs"))
Expand Down
3 changes: 3 additions & 0 deletions rust/cuvs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ license.workspace = true
[dependencies]
ffi = { package = "cuvs-sys", path = "../cuvs-sys" }
ndarray = "0.15"

[dev-dependencies]
ndarray-rand = "*"
62 changes: 53 additions & 9 deletions rust/cuvs/src/cagra/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

use std::io::{stderr, Write};

use crate::cagra::IndexParams;
use crate::cagra::{IndexParams, SearchParams};
use crate::dlpack::ManagedTensor;
use crate::error::{check_cuvs, Result};
use crate::resources::Resources;
Expand All @@ -28,12 +28,12 @@ pub struct Index {

impl Index {
/// Builds a new index
pub fn build(res: Resources, params: IndexParams, dataset: ManagedTensor) -> Result<Index> {
pub fn build(res: &Resources, params: &IndexParams, dataset: ManagedTensor) -> Result<Index> {
let index = Index::new()?;
unsafe {
check_cuvs(ffi::cagraBuild(
res.res,
params.params,
res.0,
params.0,
dataset.as_ptr(),
index.index,
))?;
Expand All @@ -51,6 +51,26 @@ impl Index {
})
}
}

/// Searches this CAGRA index by forwarding to the `cagraSearch` C API.
///
/// The index is borrowed (`&self`) rather than consumed, so one built index
/// can be searched repeatedly. Taking `self` by value would move the index
/// into this call and drop it on return, allowing only a single search.
///
/// # Arguments
///
/// * `res` - resources handle passed through to `cagraSearch`
/// * `params` - search parameters passed through to `cagraSearch`
/// * `queries` - tensor of query vectors
/// * `neighbors` - tensor handed to `cagraSearch` for the neighbor output
///   (the unit test passes an `(n_queries, k)` `u32` device tensor)
/// * `distances` - tensor handed to `cagraSearch` for the distance output
///   (the unit test passes an `(n_queries, k)` `f32` device tensor)
///
/// NOTE(review): the three tensors are taken by value, so the caller gives
/// up ownership of `neighbors` and `distances` and cannot read them after
/// this call returns. Consider accepting references in a follow-up.
///
/// # Errors
///
/// Returns an error when `check_cuvs` rejects the status code returned by
/// `cagraSearch`.
pub fn search(
    &self,
    res: &Resources,
    params: &SearchParams,
    queries: ManagedTensor,
    neighbors: ManagedTensor,
    distances: ManagedTensor,
) -> Result<()> {
    // SAFETY: the handles are read from the wrapper objects borrowed for the
    // duration of this call, and the tensor pointers come from tensors owned
    // by this function, which stay alive until after `cagraSearch` returns.
    // Assumes each wrapper holds a handle produced by its own constructor.
    let status = unsafe {
        ffi::cagraSearch(
            res.0,
            params.0,
            self.index,
            queries.as_ptr(),
            neighbors.as_ptr(),
            distances.as_ptr(),
        )
    };
    check_cuvs(status)
}
}

impl Drop for Index {
Expand All @@ -65,21 +85,45 @@ impl Drop for Index {
#[cfg(test)]
mod tests {
use super::*;
use ndarray::s;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;

#[test]
fn test_create_empty_index() {
Index::new().unwrap();
}

#[test]
fn test_build() {
fn test_index() {
let res = Resources::new().unwrap();
let params = IndexParams::new().unwrap();

// TODO: test a more exciting dataset
let arr = ndarray::Array::<f32, _>::zeros((128, 16));
let dataset = ManagedTensor::from_ndarray(arr);
let n_features = 16;
let dataset = ndarray::Array::<f32, _>::random((256, n_features), Uniform::new(0., 1.0));
let index = Index::build(&res, &params, ManagedTensor::from_ndarray(&dataset))
.expect("failed to create cagra index");

// use the first 4 points from the dataset as queries: this will test that we get them
// back as their own nearest neighbor
let n_queries = 4;
let queries = dataset.slice(s![0..n_queries, ..]);
let queries = ManagedTensor::from_ndarray(&queries).to_device().unwrap();

let k = 10;
let neighbors =
ManagedTensor::from_ndarray(&ndarray::Array::<u32, _>::zeros((n_queries, k)))
.to_device()
.unwrap();
let distances =
ManagedTensor::from_ndarray(&ndarray::Array::<f32, _>::zeros((n_queries, k)))
.to_device()
.unwrap();

let search_params = SearchParams::new().unwrap();

let index = Index::build(res, params, dataset).expect("failed to create cagra index");
index
.search(&res, &search_params, queries, neighbors, distances)
.unwrap();
}
}
29 changes: 15 additions & 14 deletions rust/cuvs/src/cagra/index_params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,49 +21,45 @@ use std::io::{stderr, Write};
pub type BuildAlgo = ffi::cagraGraphBuildAlgo;

/// Supplemental parameters to build CAGRA Index
pub struct IndexParams {
pub params: ffi::cuvsCagraIndexParams_t,
}
pub struct IndexParams(pub ffi::cuvsCagraIndexParams_t);

impl IndexParams {
pub fn new() -> Result<IndexParams> {
unsafe {
let mut params = core::mem::MaybeUninit::<ffi::cuvsCagraIndexParams_t>::uninit();
check_cuvs(ffi::cuvsCagraIndexParamsCreate(params.as_mut_ptr()))?;
Ok(IndexParams {
params: params.assume_init(),
})
Ok(IndexParams(params.assume_init()))
}
}

/// Degree of input graph for pruning
pub fn set_intermediate_graph_degree(self, intermediate_graph_degree: usize) -> IndexParams {
unsafe {
(*self.params).intermediate_graph_degree = intermediate_graph_degree;
(*self.0).intermediate_graph_degree = intermediate_graph_degree;
}
self
}

/// Degree of output graph
pub fn set_graph_degree(self, graph_degree: usize) -> IndexParams {
unsafe {
(*self.params).graph_degree = graph_degree;
(*self.0).graph_degree = graph_degree;
}
self
}

/// ANN algorithm to build knn graph
pub fn set_build_algo(self, build_algo: BuildAlgo) -> IndexParams {
unsafe {
(*self.params).build_algo = build_algo;
(*self.0).build_algo = build_algo;
}
self
}

/// Number of iterations to run if building with NN_DESCENT
pub fn set_nn_descent_niter(self, nn_descent_niter: usize) -> IndexParams {
unsafe {
(*self.params).nn_descent_niter = nn_descent_niter;
(*self.0).nn_descent_niter = nn_descent_niter;
}
self
}
Expand All @@ -73,13 +69,13 @@ impl fmt::Debug for IndexParams {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// custom debug trait here, default value will show the pointer address
// for the inner params object which isn't that useful.
write!(f, "IndexParams {{ params: {:?} }}", unsafe { *self.params })
write!(f, "IndexParams {{ params: {:?} }}", unsafe { *self.0 })
}
}

impl Drop for IndexParams {
fn drop(&mut self) {
if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraIndexParamsDestroy(self.params) }) {
if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraIndexParamsDestroy(self.0) }) {
write!(
stderr(),
"failed to call cuvsCagraIndexParamsDestroy {:?}",
Expand All @@ -103,7 +99,12 @@ mod tests {
.set_build_algo(BuildAlgo::NN_DESCENT)
.set_nn_descent_niter(10);

// make sure the setters actually updated internal representation
assert_eq!(format!("{:?}", params), "IndexParams { params: cagraIndexParams { intermediate_graph_degree: 128, graph_degree: 16, build_algo: NN_DESCENT, nn_descent_niter: 10 } }");
// make sure the setters actually updated internal representation on the c-struct
unsafe {
assert_eq!((*params.0).graph_degree, 16);
assert_eq!((*params.0).intermediate_graph_degree, 128);
assert_eq!((*params.0).build_algo, BuildAlgo::NN_DESCENT);
assert_eq!((*params.0).nn_descent_niter, 10);
}
}
}
44 changes: 21 additions & 23 deletions rust/cuvs/src/cagra/search_params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,21 @@ pub type SearchAlgo = ffi::cagraSearchAlgo;
pub type HashMode = ffi::cagraHashMode;

/// Supplemental parameters to search CAGRA index
pub struct SearchParams {
pub params: ffi::cuvsCagraSearchParams_t,
}
pub struct SearchParams(pub ffi::cuvsCagraSearchParams_t);

impl SearchParams {
pub fn new() -> Result<SearchParams> {
unsafe {
let mut params = core::mem::MaybeUninit::<ffi::cuvsCagraSearchParams_t>::uninit();
check_cuvs(ffi::cuvsCagraSearchParamsCreate(params.as_mut_ptr()))?;
Ok(SearchParams {
params: params.assume_init(),
})
Ok(SearchParams(params.assume_init()))
}
}

/// Maximum number of queries to search at the same time (batch size). Auto select when 0
pub fn set_max_queries(self, max_queries: usize) -> SearchParams {
unsafe {
(*self.params).max_queries = max_queries;
(*self.0).max_queries = max_queries;
}
self
}
Expand All @@ -50,87 +46,87 @@ impl SearchParams {
/// Higher values improve the search accuracy
pub fn set_itopk_size(self, itopk_size: usize) -> SearchParams {
unsafe {
(*self.params).itopk_size = itopk_size;
(*self.0).itopk_size = itopk_size;
}
self
}

/// Upper limit of search iterations. Auto select when 0.
pub fn set_max_iterations(self, max_iterations: usize) -> SearchParams {
unsafe {
(*self.params).max_iterations = max_iterations;
(*self.0).max_iterations = max_iterations;
}
self
}

/// Which search implementation to use.
pub fn set_algo(self, algo: SearchAlgo) -> SearchParams {
unsafe {
(*self.params).algo = algo;
(*self.0).algo = algo;
}
self
}

/// Number of threads used to calculate a single distance. 4, 8, 16, or 32.
pub fn set_team_size(self, team_size: usize) -> SearchParams {
unsafe {
(*self.params).team_size = team_size;
(*self.0).team_size = team_size;
}
self
}

/// Lower limit of search iterations.
pub fn set_min_iterations(self, min_iterations: usize) -> SearchParams {
unsafe {
(*self.params).min_iterations = min_iterations;
(*self.0).min_iterations = min_iterations;
}
self
}

/// Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0.
pub fn set_thread_block_size(self, thread_block_size: usize) -> SearchParams {
unsafe {
(*self.params).thread_block_size = thread_block_size;
(*self.0).thread_block_size = thread_block_size;
}
self
}

/// Hashmap type. Auto selection when AUTO.
pub fn set_hashmap_mode(self, hashmap_mode: HashMode) -> SearchParams {
unsafe {
(*self.params).hashmap_mode = hashmap_mode;
(*self.0).hashmap_mode = hashmap_mode;
}
self
}

/// Lower limit of hashmap bit length. More than 8.
pub fn set_hashmap_min_bitlen(self, hashmap_min_bitlen: usize) -> SearchParams {
unsafe {
(*self.params).hashmap_min_bitlen = hashmap_min_bitlen;
(*self.0).hashmap_min_bitlen = hashmap_min_bitlen;
}
self
}

/// Upper limit of hashmap fill rate. More than 0.1, less than 0.9.
pub fn set_hashmap_max_fill_rate(self, hashmap_max_fill_rate: f32) -> SearchParams {
unsafe {
(*self.params).hashmap_max_fill_rate = hashmap_max_fill_rate;
(*self.0).hashmap_max_fill_rate = hashmap_max_fill_rate;
}
self
}

/// Number of iterations of initial random seed node selection. 1 or more.
pub fn set_num_random_samplings(self, num_random_samplings: u32) -> SearchParams {
unsafe {
(*self.params).num_random_samplings = num_random_samplings;
(*self.0).num_random_samplings = num_random_samplings;
}
self
}

/// Bit mask used for initial random seed node selection.
pub fn set_rand_xor_mask(self, rand_xor_mask: u64) -> SearchParams {
unsafe {
(*self.params).rand_xor_mask = rand_xor_mask;
(*self.0).rand_xor_mask = rand_xor_mask;
}
self
}
Expand All @@ -140,15 +136,13 @@ impl fmt::Debug for SearchParams {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// custom debug trait here, default value will show the pointer address
// for the inner params object which isn't that useful.
write!(f, "SearchParams {{ params: {:?} }}", unsafe {
*self.params
})
write!(f, "SearchParams {{ params: {:?} }}", unsafe { *self.0 })
}
}

impl Drop for SearchParams {
fn drop(&mut self) {
if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraSearchParamsDestroy(self.params) }) {
if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraSearchParamsDestroy(self.0) }) {
write!(
stderr(),
"failed to call cuvsCagraSearchParamsDestroy {:?}",
Expand All @@ -165,6 +159,10 @@ mod tests {

#[test]
fn test_search_params() {
let params = SearchParams::new().unwrap();
let params = SearchParams::new().unwrap().set_itopk_size(128);

unsafe {
assert_eq!((*params.0).itopk_size, 128);
}
}
}
Loading