Skip to content

Commit

Permalink
Allow topk larger than 1024 in CAGRA (rapidsai#2097)
Browse files Browse the repository at this point in the history
This change allows CAGRA search to have an arbitrarily large top-k, instead of being limited to 1024 like in the previous code.

This works by routing such searches through the multi-kernel search path and, when k exceeds 1024, using matrix::select_k in place of _cuann_find_topk, since select_k can handle large k values. For k of 1024 or less, _cuann_find_topk is still used.

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: rapidsai#2097
  • Loading branch information
benfred authored Jan 23, 2024
1 parent d89ab1b commit 0586fc3
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 36 deletions.
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ void search_main(raft::resources const& res,
factory<T, internal_IdxT, DistanceT, CagraSampleFilterT_s>::create(
res, params, index.dim(), index.graph_degree(), topk);

plan->check(neighbors.extent(1));
plan->check(topk);

RAFT_LOG_DEBUG("Cagra search");
const uint32_t max_queries = plan->max_queries;
Expand Down
167 changes: 135 additions & 32 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -37,6 +37,7 @@
#include "topk_for_cagra/topk_core.cuh" //todo replace with raft kernel
#include "utils.hpp"
#include <raft/core/logger.hpp>
#include <raft/matrix/select_k.cuh>
#include <raft/util/cuda_rt_essentials.hpp>
#include <raft/util/cudart_utils.hpp> // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp

Expand Down Expand Up @@ -653,6 +654,12 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {
rmm::device_scalar<uint32_t> terminate_flag; // dev_terminate_flag, host_terminate_flag.;
rmm::device_uvector<uint32_t> topk_workspace;

// temporary storage for _find_topk
rmm::device_uvector<float> input_keys_storage;
rmm::device_uvector<float> output_keys_storage;
rmm::device_uvector<INDEX_T> input_values_storage;
rmm::device_uvector<INDEX_T> output_values_storage;

search(raft::resources const& res,
search_params params,
int64_t dim,
Expand All @@ -665,7 +672,11 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {
parent_node_list(0, resource::get_cuda_stream(res)),
topk_hint(0, resource::get_cuda_stream(res)),
topk_workspace(0, resource::get_cuda_stream(res)),
terminate_flag(resource::get_cuda_stream(res))
terminate_flag(resource::get_cuda_stream(res)),
input_keys_storage(0, resource::get_cuda_stream(res)),
output_keys_storage(0, resource::get_cuda_stream(res)),
input_values_storage(0, resource::get_cuda_stream(res)),
output_values_storage(0, resource::get_cuda_stream(res))
{
set_params(res);
}
Expand Down Expand Up @@ -695,6 +706,98 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {

// Empty destructor body: no explicit cleanup is performed here.
~search() {}

/**
 * @brief Batched top-k selection that dispatches on the size of `topK`.
 *
 * For `topK <= 1024` the call is forwarded unchanged to `_cuann_find_topk`.
 * For larger `topK`, `raft::matrix::select_k` is used (selecting the smallest
 * keys). Inputs whose leading dimension exceeds `numElements` are first copied
 * into contiguous member buffers, and outputs whose leading dimension exceeds
 * `topK` are produced in contiguous member buffers and copied back afterwards.
 *
 * @param handle      raft resources; its CUDA stream is used for all work issued here
 * @param topK        number of entries to select per row
 * @param sizeBatch   number of rows
 * @param numElements number of candidate entries per row
 * @param inputKeys   input keys, [sizeBatch, ldIK]
 * @param ldIK        leading dimension of `inputKeys` (ldIK >= numElements)
 * @param inputVals   input values, [sizeBatch, ldIV]
 * @param ldIV        leading dimension of `inputVals` (ldIV >= numElements)
 * @param outputKeys  output keys, [sizeBatch, ldOK]
 * @param ldOK        leading dimension of `outputKeys` (ldOK >= topK)
 * @param outputVals  output values, [sizeBatch, ldOV]
 * @param ldOV        leading dimension of `outputVals` (ldOV >= topK)
 * @param workspace   forwarded to `_cuann_find_topk`; unused when topK > 1024
 * @param sort        forwarded to the selected implementation as its sort flag
 * @param hints       forwarded to `_cuann_find_topk`; unused when topK > 1024
 */
inline void _find_topk(raft::resources const& handle,
                       uint32_t topK,
                       uint32_t sizeBatch,
                       uint32_t numElements,
                       const float* inputKeys,    // [sizeBatch, ldIK,]
                       uint32_t ldIK,             // (*) ldIK >= numElements
                       const INDEX_T* inputVals,  // [sizeBatch, ldIV,]
                       uint32_t ldIV,             // (*) ldIV >= numElements
                       float* outputKeys,         // [sizeBatch, ldOK,]
                       uint32_t ldOK,             // (*) ldOK >= topK
                       INDEX_T* outputVals,       // [sizeBatch, ldOV,]
                       uint32_t ldOV,             // (*) ldOV >= topK
                       void* workspace,
                       bool sort,
                       uint32_t* hints)
{
  auto stream = resource::get_cuda_stream(handle);

  // _cuann_find_topk right now is limited to a max-k of 1024.
  // RAFT has a matrix::select_k function - which handles arbitrary sized values of k,
  // but doesn't accept strided inputs unlike _cuann_find_topk
  // The multi-kernel search path requires strided access - since it cleverly allocates memory
  // (layout described in the search_plan_impl function below), such that both the
  // neighbors and the internal_topk are adjacent - in a double buffered format.
  // Since this layout doesn't work with the matrix::select_k code - we have to copy
  // over to a contiguous (non-strided) access to handle topk larger than 1024, and
  // potentially also copy back to a strided layout afterwards
  if (topK <= 1024) {
    return _cuann_find_topk(topK,
                            sizeBatch,
                            numElements,
                            inputKeys,
                            ldIK,
                            inputVals,
                            ldIV,
                            outputKeys,
                            ldOK,
                            outputVals,
                            ldOV,
                            workspace,
                            sort,
                            hints,
                            stream);
  }

  // Element counts are computed in 64 bits: the product of two uint32_t values
  // can wrap in 32-bit arithmetic, which would under-allocate the staging buffers.
  const uint64_t num_input_elements  = static_cast<uint64_t>(sizeBatch) * numElements;
  const uint64_t num_output_elements = static_cast<uint64_t>(sizeBatch) * topK;

  // An output needs staging whenever its leading dimension is wider than topK.
  const bool stage_output_keys = ldOK > topK;
  const bool stage_output_vals = ldOV > topK;

  if (ldIK > numElements) {
    // Compact the strided input keys into a contiguous buffer for select_k.
    if (input_keys_storage.size() != num_input_elements) {
      input_keys_storage.resize(num_input_elements, stream);
    }
    batched_memcpy(
      input_keys_storage.data(), numElements, inputKeys, ldIK, numElements, sizeBatch, stream);
    inputKeys = input_keys_storage.data();
  }

  if (ldIV > numElements) {
    // Compact the strided input values into a contiguous buffer for select_k.
    if (input_values_storage.size() != num_input_elements) {
      input_values_storage.resize(num_input_elements, stream);
    }

    batched_memcpy(
      input_values_storage.data(), numElements, inputVals, ldIV, numElements, sizeBatch, stream);
    inputVals = input_values_storage.data();
  }

  if (stage_output_keys && (output_keys_storage.size() != num_output_elements)) {
    output_keys_storage.resize(num_output_elements, stream);
  }

  if (stage_output_vals && (output_values_storage.size() != num_output_elements)) {
    output_values_storage.resize(num_output_elements, stream);
  }

  raft::matrix::select_k<float, INDEX_T>(
    handle,
    raft::make_device_matrix_view<const float, int64_t>(inputKeys, sizeBatch, numElements),
    raft::make_device_matrix_view<const INDEX_T, int64_t>(inputVals, sizeBatch, numElements),
    raft::make_device_matrix_view<float, int64_t>(
      stage_output_keys ? output_keys_storage.data() : outputKeys, sizeBatch, topK),
    raft::make_device_matrix_view<INDEX_T, int64_t>(
      stage_output_vals ? output_values_storage.data() : outputVals, sizeBatch, topK),
    true,  // select_min
    sort);

  // Scatter the contiguous results back into the caller's strided layout.
  if (stage_output_keys) {
    batched_memcpy(outputKeys, ldOK, output_keys_storage.data(), topK, topK, sizeBatch, stream);
  }

  if (stage_output_vals) {
    batched_memcpy(outputVals, ldOV, output_values_storage.data(), topK, topK, sizeBatch, stream);
  }
}

void operator()(raft::resources const& res,
raft::device_matrix_view<const DATA_T, int64_t, layout_stride> dataset,
raft::device_matrix_view<const INDEX_T, int64_t, row_major> graph,
Expand Down Expand Up @@ -746,21 +849,21 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {
unsigned iter = 0;
while (1) {
// Make an index list of internal top-k nodes
_cuann_find_topk(itopk_size,
num_queries,
result_buffer_size,
result_distances.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
result_indices.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size,
result_buffer_allocation_size,
result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size,
result_buffer_allocation_size,
topk_workspace.data(),
true,
top_hint_ptr,
stream);
_find_topk(res,
itopk_size,
num_queries,
result_buffer_size,
result_distances.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
result_indices.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size,
result_buffer_allocation_size,
result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size,
result_buffer_allocation_size,
topk_workspace.data(),
true,
top_hint_ptr);

// termination (1)
if ((iter + 1 == max_iterations)) {
Expand Down Expand Up @@ -841,21 +944,21 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {

result_indices_ptr = result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size;
result_distances_ptr = result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size;
_cuann_find_topk(itopk_size,
num_queries,
result_buffer_size,
result_distances.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
result_indices.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
result_distances_ptr,
result_buffer_allocation_size,
result_indices_ptr,
result_buffer_allocation_size,
topk_workspace.data(),
true,
top_hint_ptr,
stream);
_find_topk(res,
itopk_size,
num_queries,
result_buffer_size,
result_distances.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
result_indices.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
result_distances_ptr,
result_buffer_allocation_size,
result_indices_ptr,
result_buffer_allocation_size,
topk_workspace.data(),
true,
top_hint_ptr);
} else {
// Remove parent bit in search results
remove_parent_bit(
Expand Down
10 changes: 7 additions & 3 deletions cpp/include/raft/neighbors/detail/cagra/search_plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@ struct search_plan_impl_base : public search_params {
if (itopk_size <= 512 && search_params::max_queries >= num_sm * 2lu) {
algo = search_algo::SINGLE_CTA;
RAFT_LOG_DEBUG("Auto strategy: selecting single-cta");
} else {
} else if (topk <= 1024) {
algo = search_algo::MULTI_CTA;
RAFT_LOG_DEBUG("Auto strategy: selecting multi-cta");
} else {
algo = search_algo::MULTI_KERNEL;
RAFT_LOG_DEBUG("Auto strategy: selecting multi kernel");
}
}
}
Expand Down Expand Up @@ -255,15 +258,16 @@ struct search_plan_impl : public search_plan_impl_base {
/**
 * @brief Validate the requested number of results against the internal top-k size.
 *
 * Fails via RAFT_EXPECTS when `topk` exceeds `itopk_size`; equality is allowed.
 *
 * @param topk number of results requested per query
 */
virtual void check(const uint32_t topk)
{
  // For single-CTA and multi kernel
  RAFT_EXPECTS(topk <= itopk_size,
               "topk = %u must be smaller than or equal to itopk_size = %lu",
               topk,
               itopk_size);
}

inline void check_params()
{
std::string error_message = "";

if (itopk_size > 1024) {
if (algo == search_algo::MULTI_CTA) {
if ((algo == search_algo::MULTI_CTA) || (algo == search_algo::MULTI_KERNEL)) {
} else {
error_message += std::string("- `internal_topk` (" + std::to_string(itopk_size) +
") must be smaller or equal to 1024");
Expand Down
20 changes: 20 additions & 0 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
search_params.algo = ps.algo;
search_params.max_queries = ps.max_queries;
search_params.team_size = ps.team_size;
search_params.itopk_size = ps.itopk_size;

auto database_view = raft::make_device_matrix_view<const DataT, int64_t>(
(const DataT*)database.data(), ps.n_rows, ps.dim);
Expand Down Expand Up @@ -496,6 +497,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
search_params.algo = ps.algo;
search_params.max_queries = ps.max_queries;
search_params.team_size = ps.team_size;
search_params.itopk_size = ps.itopk_size;
search_params.hashmap_mode = cagra::hash_mode::HASH;

auto database_view = raft::make_device_matrix_view<const DataT, int64_t>(
Expand Down Expand Up @@ -611,6 +613,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam<AnnCagraInputs> {
search_params.algo = ps.algo;
search_params.max_queries = ps.max_queries;
search_params.team_size = ps.team_size;
search_params.itopk_size = ps.itopk_size;
search_params.hashmap_mode = cagra::hash_mode::HASH;

auto database_view = raft::make_device_matrix_view<const DataT, int64_t>(
Expand Down Expand Up @@ -818,6 +821,23 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

inputs2 =
raft::util::itertools::product<AnnCagraInputs>({100},
{20000},
{32},
{2048}, // k
{graph_build_algo::NN_DESCENT},
{search_algo::AUTO},
{10},
{0},
{4096}, // itopk_size
{1},
{raft::distance::DistanceType::L2Expanded},
{false},
{false},
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

return inputs;
}

Expand Down

0 comments on commit 0586fc3

Please sign in to comment.