Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add IVF-Flat C++ example #1828

Merged
merged 13 commits into from
Sep 26, 2023
Merged
Prev Previous commit
Next Next commit
Fix dataset indices, Sync and print results
  • Loading branch information
tfeher committed Sep 18, 2023
commit c44d751feb6e9f6f1ce997add67f3ec836a4edce
67 changes: 51 additions & 16 deletions cpp/template/src/ivf_flat_example.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,38 @@
#include <raft/neighbors/ivf_flat.cuh>
#include <raft/random/make_blobs.cuh>
#include <raft/random/sample_without_replacement.cuh>
#include <raft/util/cudart_utils.hpp>

#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>

// Copy the results to host and print a few samples
void print_results(raft::device_resources const& dev_resources,
tfeher marked this conversation as resolved.
Show resolved Hide resolved
raft::device_matrix_view<int64_t, int64_t> neighbors,
raft::device_matrix_view<float, int64_t> distances)
{
int64_t topk = neighbors.extent(1);
auto neighbors_host = raft::make_host_matrix<int64_t, int64_t>(neighbors.extent(0), topk);
auto distances_host = raft::make_host_matrix<float, int64_t>(distances.extent(0), topk);

cudaStream_t stream = raft::resource::get_cuda_stream(dev_resources);

raft::copy(neighbors_host.data_handle(), neighbors.data_handle(), neighbors.size(), stream);
raft::copy(distances_host.data_handle(), distances.data_handle(), distances.size(), stream);

// The calls to ivf_flat::search and raft::copy is asyncronous.
// We need to sync the stream before accessing the data.
raft::resource::sync_stream(dev_resources, stream);

for (int query_id = 0; query_id < 2; query_id++) {
std::cout << "Query " << query_id << " neighbor indices: ";
raft::print_host_vector("", &neighbors_host(query_id, 0), topk, std::cout);
std::cout << "Query " << query_id << " neighbor distances: ";
raft::print_host_vector("", &distances_host(query_id, 0), topk, std::cout);
}
}

void ivf_flat_build_search_simple(raft::device_resources const& dev_resources,
raft::device_matrix_view<const float, int64_t> dataset,
raft::device_matrix_view<const float, int64_t> queries)
Expand All @@ -46,7 +73,7 @@ void ivf_flat_build_search_simple(raft::device_resources const& dev_resources,
<< index.size() << std::endl;

// Create output arrays.
int64_t topk = 12;
int64_t topk = 10;
int64_t n_queries = queries.extent(0);
auto neighbors = raft::make_device_matrix<int64_t>(dev_resources, n_queries, topk);
auto distances = raft::make_device_matrix<float>(dev_resources, n_queries, topk);
Expand All @@ -58,12 +85,18 @@ void ivf_flat_build_search_simple(raft::device_resources const& dev_resources,
// Search K nearest neighbors for each of the queries.
ivf_flat::search(
dev_resources, search_params, index, queries, neighbors.view(), distances.view());

// The call to ivf_flat::search is asyncronous. Before accessing the data, sync by calling
// raft::resource::sync_stream(dev_resources);

print_results(dev_resources, neighbors.view(), distances.view());
}

/** Subsample the dataset to create a training set*/
raft::device_matrix<float, int64_t> subsample(
tfeher marked this conversation as resolved.
Show resolved Hide resolved
raft::device_resources const& dev_resources,
raft::device_matrix_view<const float, int64_t> dataset,
raft::device_vector_view<const int64_t, int64_t> data_indices,
float fraction)
{
int64_t n_samples = dataset.extent(0);
Expand All @@ -73,19 +106,10 @@ raft::device_matrix<float, int64_t> subsample(

int seed = 137;
raft::random::RngState rng(seed);
auto data_indices = raft::make_device_vector<int64_t>(dev_resources, n_samples);
auto train_indices = raft::make_device_vector<int64_t>(dev_resources, n_train);

thrust::counting_iterator<int64_t> first(0);
thrust::device_ptr<int64_t> ptr(data_indices.data_handle());
thrust::copy(raft::resource::get_thrust_policy(dev_resources), first, first + n_samples, ptr);

raft::random::sample_without_replacement(dev_resources,
rng,
raft::make_const_mdspan(data_indices.view()),
std::nullopt,
train_indices.view(),
std::nullopt);
raft::random::sample_without_replacement(
dev_resources, rng, data_indices, std::nullopt, train_indices.view(), std::nullopt);

raft::matrix::copy_rows(
dev_resources, dataset, trainset.view(), raft::make_const_mdspan(train_indices.view()));
Expand All @@ -99,8 +123,16 @@ void ivf_flat_build_extend_search(raft::device_resources const& dev_resources,
{
using namespace raft::neighbors;

// Define dataset indices.
auto data_indices = raft::make_device_vector<int64_t, int64_t>(dev_resources, dataset.extent(0));
thrust::counting_iterator<int64_t> first(0);
thrust::device_ptr<int64_t> ptr(data_indices.data_handle());
tfeher marked this conversation as resolved.
Show resolved Hide resolved
thrust::copy(
raft::resource::get_thrust_policy(dev_resources), first, first + dataset.extent(0), ptr);

// Sub-sample the dataset to create a training set.
auto trainset = subsample(dev_resources, dataset, 0.1);
auto trainset =
subsample(dev_resources, dataset, raft::make_const_mdspan(data_indices.view()), 0.1);

ivf_flat::index_params index_params;
index_params.n_lists = 100;
Expand All @@ -115,9 +147,7 @@ void ivf_flat_build_extend_search(raft::device_resources const& dev_resources,
<< index.size() << std::endl;

std::cout << "Filling index with the dataset vectors" << std::endl;

auto data_indices = raft::make_device_vector<int64_t, int64_t>(dev_resources, dataset.extent(1));
index = ivf_flat::extend(dev_resources,
index = ivf_flat::extend(dev_resources,
dataset,
std::make_optional(raft::make_const_mdspan(data_indices.view())),
index);
Expand All @@ -137,6 +167,11 @@ void ivf_flat_build_extend_search(raft::device_resources const& dev_resources,
// Search K nearest neighbors for each queries.
ivf_flat::search(
dev_resources, search_params, index, queries, neighbors.view(), distances.view());

// The call to ivf_flat::search is asyncronous. Before accessing the data, sync using:
// raft::resource::sync_stream(dev_resources);

print_results(dev_resources, neighbors.view(), distances.view());
}

int main()
Expand Down
Loading