Skip to content

Commit

Permalink
[REVIEW] Enable distance return for NN Descent (#2345)
Browse files Browse the repository at this point in the history
- Enable NN Descent to return the distances array as well (previously it returned only the indices array)
  - Added a `return_distances` flag in `index_params`. When set to 1 (true), allocates a distance array to return.
- Test for checking distances recall compared to naive knn

Authors:
  - Jinsol Park (https://github.com/jinsolp)

Approvers:
  - Divye Gala (https://github.com/divyegala)

URL: #2345
  • Loading branch information
jinsolp authored Jun 14, 2024
1 parent 074055f commit 877644a
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 14 deletions.
45 changes: 40 additions & 5 deletions cpp/include/raft/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include <omp.h>

#include <limits>
#include <optional>
#include <queue>
#include <random>

Expand Down Expand Up @@ -217,6 +218,7 @@ struct BuildConfig {
// If internal_node_degree == 0, the value of node_degree will be assigned to it
size_t max_iterations{50};
float termination_threshold{0.0001};
size_t output_graph_degree{32};
};

template <typename Index_t>
Expand Down Expand Up @@ -345,7 +347,11 @@ class GNND {
GNND(const GNND&) = delete;
GNND& operator=(const GNND&) = delete;

void build(Data_t* data, const Index_t nrow, Index_t* output_graph);
void build(Data_t* data,
const Index_t nrow,
Index_t* output_graph,
bool return_distances,
DistData_t* output_distances);
~GNND() = default;
using ID_t = InternalID_t<Index_t>;

Expand Down Expand Up @@ -1212,7 +1218,11 @@ void GNND<Data_t, Index_t>::local_join(cudaStream_t stream)
}

template <typename Data_t, typename Index_t>
void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* output_graph)
void GNND<Data_t, Index_t>::build(Data_t* data,
const Index_t nrow,
Index_t* output_graph,
bool return_distances,
DistData_t* output_distances)
{
using input_t = typename std::remove_const<Data_t>::type;

Expand Down Expand Up @@ -1338,6 +1348,16 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out
// Reuse graph_.h_dists as the buffer for shrinking the lists in the graph
static_assert(sizeof(decltype(*(graph_.h_dists.data_handle()))) >= sizeof(Index_t));
if (return_distances) {
for (size_t i = 0; i < (size_t)nrow_; i++) {
raft::copy(output_distances + i * build_config_.output_graph_degree,
graph_.h_dists.data_handle() + i * build_config_.node_degree,
build_config_.output_graph_degree,
raft::resource::get_cuda_stream(res));
}
}
Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists.data_handle();
#pragma omp parallel for
Expand Down Expand Up @@ -1410,10 +1430,24 @@ void build(raft::resources const& res,
.node_degree = extended_graph_degree,
.internal_node_degree = extended_intermediate_degree,
.max_iterations = params.max_iterations,
.termination_threshold = params.termination_threshold};
.termination_threshold = params.termination_threshold,
.output_graph_degree = params.graph_degree};
GNND<const T, int> nnd(res, build_config);
nnd.build(dataset.data_handle(), dataset.extent(0), int_graph.data_handle());
if (idx.distances().has_value() || !params.return_distances) {
nnd.build(dataset.data_handle(),
dataset.extent(0),
int_graph.data_handle(),
params.return_distances,
idx.distances()
.value_or(raft::make_device_matrix<float, int64_t>(res, 0, 0).view())
.data_handle());
} else {
RAFT_EXPECTS(!params.return_distances,
"Distance view not allocated. Using return_distances set to true requires "
"distance view to be allocated.");
}
#pragma omp parallel for
for (size_t i = 0; i < static_cast<size_t>(dataset.extent(0)); i++) {
Expand Down Expand Up @@ -1444,7 +1478,8 @@ index<IdxT> build(raft::resources const& res,
graph_degree = intermediate_degree;
}
index<IdxT> idx{res, dataset.extent(0), static_cast<int64_t>(graph_degree)};
index<IdxT> idx{
res, dataset.extent(0), static_cast<int64_t>(graph_degree), params.return_distances};
build(res, params, dataset, idx);
Expand Down
38 changes: 34 additions & 4 deletions cpp/include/raft/neighbors/nn_descent_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@

#include "ann_types.hpp"

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>

#include <optional>

namespace raft::neighbors::experimental::nn_descent {
/**
* @ingroup nn-descent
Expand All @@ -51,6 +55,7 @@ struct index_params : ann::index_params {
size_t intermediate_graph_degree = 128; // Degree of input graph for pruning.
size_t max_iterations = 20; // Number of nn-descent iterations.
float termination_threshold = 0.0001; // Termination threshold of nn-descent.
bool return_distances = false; // return distances if true
};

/**
Expand Down Expand Up @@ -79,14 +84,20 @@ struct index : ann::index {
* @param res raft::resources is an object managing resources
* @param n_rows number of rows in knn-graph
* @param n_cols number of cols in knn-graph
* @param return_distances whether to allocate and get distances information
*/
index(raft::resources const& res, int64_t n_rows, int64_t n_cols)
index(raft::resources const& res, int64_t n_rows, int64_t n_cols, bool return_distances = false)
: ann::index(),
res_{res},
metric_{raft::distance::DistanceType::L2Expanded},
graph_{raft::make_host_matrix<IdxT, int64_t, row_major>(n_rows, n_cols)},
graph_view_{graph_.view()}
graph_view_{graph_.view()},
return_distances_(return_distances)
{
if (return_distances) {
distances_ = raft::make_device_matrix<float, int64_t>(res_, n_rows, n_cols);
distances_view_ = distances_.value().view();
}
}

/**
Expand All @@ -98,14 +109,23 @@ struct index : ann::index {
*
* @param res raft::resources is an object managing resources
* @param graph_view raft::host_matrix_view<IdxT, int64_t, raft::row_major> for storing knn-graph
* @param distances_view std::optional<raft::device_matrix_view<float, int64_t, row_major>> for
* storing knn-graph distances
* @param return_distances whether to allocate and get distances information
*/
index(raft::resources const& res,
raft::host_matrix_view<IdxT, int64_t, raft::row_major> graph_view)
raft::host_matrix_view<IdxT, int64_t, raft::row_major> graph_view,
std::optional<raft::device_matrix_view<float, int64_t, row_major>> distances_view =
std::nullopt,
bool return_distances = false)
: ann::index(),
res_{res},
metric_{raft::distance::DistanceType::L2Expanded},
graph_{raft::make_host_matrix<IdxT, int64_t, row_major>(0, 0)},
graph_view_{graph_view}
distances_{raft::make_device_matrix<float, int64_t>(res_, 0, 0)},
graph_view_{graph_view},
distances_view_(distances_view),
return_distances_(return_distances)
{
}

Expand Down Expand Up @@ -133,6 +153,13 @@ struct index : ann::index {
return graph_view_;
}

/** neighborhood graph distances [size, graph-degree] */
[[nodiscard]] inline auto distances() noexcept
-> std::optional<device_matrix_view<float, int64_t, row_major>>
{
return distances_view_;
}

// Don't allow copying the index for performance reasons (try avoiding copying data)
index(const index&) = delete;
index(index&&) = default;
Expand All @@ -144,8 +171,11 @@ struct index : ann::index {
raft::resources const& res_;
raft::distance::DistanceType metric_;
raft::host_matrix<IdxT, int64_t, row_major> graph_; // graph to return for non-int IdxT
std::optional<raft::device_matrix<float, int64_t, row_major>> distances_;
raft::host_matrix_view<IdxT, int64_t, row_major>
graph_view_; // view of graph for user provided matrix
std::optional<raft::device_matrix_view<float, int64_t, row_major>> distances_view_;
bool return_distances_;
};

/** @} */
Expand Down
33 changes: 28 additions & 5 deletions cpp/test/neighbors/ann_nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
*/
#pragma once

#include "../test_utils.cuh"
#include "ann_utils.cuh"

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/neighbors/nn_descent.cuh>
#include <raft/util/cudart_utils.hpp>
#include <raft/util/itertools.hpp>

#include <raft_internal/neighbors/naive_knn.cuh>
Expand Down Expand Up @@ -65,7 +65,9 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
{
size_t queries_size = ps.n_rows * ps.graph_degree;
std::vector<IdxT> indices_NNDescent(queries_size);
std::vector<DistanceT> distances_NNDescent(queries_size);
std::vector<IdxT> indices_naive(queries_size);
std::vector<DistanceT> distances_naive(queries_size);

{
rmm::device_uvector<DistanceT> distances_naive_dev(queries_size, stream_);
Expand All @@ -81,6 +83,7 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
ps.graph_degree,
ps.metric);
update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_);
update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_);
resource::sync_stream(handle_);
}

Expand All @@ -91,6 +94,7 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
index_params.graph_degree = ps.graph_degree;
index_params.intermediate_graph_degree = 2 * ps.graph_degree;
index_params.max_iterations = 100;
index_params.return_distances = true;

auto database_view = raft::make_device_matrix_view<const DataT, int64_t>(
(const DataT*)database.data(), ps.n_rows, ps.dim);
Expand All @@ -102,20 +106,39 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
auto database_host_view = raft::make_host_matrix_view<const DataT, int64_t>(
(const DataT*)database_host.data_handle(), ps.n_rows, ps.dim);
auto index = nn_descent::build<DataT, IdxT>(handle_, index_params, database_host_view);
update_host(
raft::copy(
indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_);
if (index.distances().has_value()) {
raft::copy(distances_NNDescent.data(),
index.distances().value().data_handle(),
queries_size,
stream_);
}

} else {
auto index = nn_descent::build<DataT, IdxT>(handle_, index_params, database_view);
update_host(
raft::copy(
indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_);
if (index.distances().has_value()) {
raft::copy(distances_NNDescent.data(),
index.distances().value().data_handle(),
queries_size,
stream_);
}
};
}
resource::sync_stream(handle_);
}

double min_recall = ps.min_recall;
EXPECT_TRUE(eval_recall(
indices_naive, indices_NNDescent, ps.n_rows, ps.graph_degree, 0.001, min_recall));
EXPECT_TRUE(eval_neighbours(indices_naive,
indices_NNDescent,
distances_naive,
distances_NNDescent,
ps.n_rows,
ps.graph_degree,
0.001,
min_recall));
}
}

Expand Down

0 comments on commit 877644a

Please sign in to comment.