Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable distance return for NN Descent #2345

Merged
merged 31 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
242d725
enable nn descent dist return
jinsolp May 29, 2024
0977947
change bool to int
jinsolp May 29, 2024
ea30245
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp May 29, 2024
a7749ac
add test for distance
jinsolp May 30, 2024
a753e45
test for indices and distances with one func
jinsolp May 30, 2024
1d96e17
fix styling
jinsolp May 30, 2024
66a7678
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp May 30, 2024
11379f7
change return_distances to bool
jinsolp May 31, 2024
e993d7c
change distances to optional
jinsolp May 31, 2024
4578dc0
fix styling:
jinsolp May 31, 2024
c9048a5
handle bad access error
jinsolp Jun 5, 2024
bdb74a5
remove unnecessary dist allocation
jinsolp Jun 5, 2024
c71e97a
change to device matrix
jinsolp Jun 5, 2024
0327ce5
change template param for index
jinsolp Jun 5, 2024
3f49752
update test
jinsolp Jun 6, 2024
50e2cf8
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp Jun 6, 2024
a6a8ad2
remove comment
jinsolp Jun 6, 2024
2cfda3c
fix styling
jinsolp Jun 6, 2024
2efae1a
fix header
jinsolp Jun 6, 2024
931158c
add documentation for return_distances
jinsolp Jun 6, 2024
9f17b5c
add documentation
jinsolp Jun 7, 2024
e33019c
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp Jun 7, 2024
e28c2f9
add tparam doc
jinsolp Jun 10, 2024
2d74f89
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp Jun 11, 2024
fd65442
remove redundancy
jinsolp Jun 12, 2024
3ce8cdf
return optional for distances()
jinsolp Jun 13, 2024
4ec3ab0
fix styling
jinsolp Jun 13, 2024
b0347a7
remove raft_expects
jinsolp Jun 13, 2024
84abd96
fix type template and bring back raft_expects
jinsolp Jun 13, 2024
75b82df
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp Jun 13, 2024
ec44b4d
fix raft_expects
jinsolp Jun 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 40 additions & 5 deletions cpp/include/raft/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include <omp.h>

#include <limits>
#include <optional>
#include <queue>
#include <random>

Expand Down Expand Up @@ -217,6 +218,7 @@ struct BuildConfig {
// If internal_node_degree == 0, the value of node_degree will be assigned to it
size_t max_iterations{50};
float termination_threshold{0.0001};
size_t output_graph_degree{32};
};

template <typename Index_t>
Expand Down Expand Up @@ -345,7 +347,11 @@ class GNND {
GNND(const GNND&) = delete;
GNND& operator=(const GNND&) = delete;

void build(Data_t* data, const Index_t nrow, Index_t* output_graph);
void build(Data_t* data,
const Index_t nrow,
Index_t* output_graph,
bool return_distances,
DistData_t* output_distances);
~GNND() = default;
using ID_t = InternalID_t<Index_t>;

Expand Down Expand Up @@ -1212,7 +1218,11 @@ void GNND<Data_t, Index_t>::local_join(cudaStream_t stream)
}

template <typename Data_t, typename Index_t>
void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* output_graph)
void GNND<Data_t, Index_t>::build(Data_t* data,
const Index_t nrow,
Index_t* output_graph,
bool return_distances,
DistData_t* output_distances)
{
using input_t = typename std::remove_const<Data_t>::type;

Expand Down Expand Up @@ -1338,6 +1348,16 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out

// Reuse graph_.h_dists as the buffer for shrink the lists in graph
static_assert(sizeof(decltype(*(graph_.h_dists.data_handle()))) >= sizeof(Index_t));

if (return_distances) {
for (size_t i = 0; i < (size_t)nrow_; i++) {
raft::copy(output_distances + i * build_config_.output_graph_degree,
graph_.h_dists.data_handle() + i * build_config_.node_degree,
build_config_.output_graph_degree,
raft::resource::get_cuda_stream(res));
}
}

Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists.data_handle();

#pragma omp parallel for
Expand Down Expand Up @@ -1410,10 +1430,24 @@ void build(raft::resources const& res,
.node_degree = extended_graph_degree,
.internal_node_degree = extended_intermediate_degree,
.max_iterations = params.max_iterations,
.termination_threshold = params.termination_threshold};
.termination_threshold = params.termination_threshold,
.output_graph_degree = params.graph_degree};

GNND<const T, int> nnd(res, build_config);
nnd.build(dataset.data_handle(), dataset.extent(0), int_graph.data_handle());

if (idx.distances().has_value() || !params.return_distances) {
nnd.build(dataset.data_handle(),
dataset.extent(0),
int_graph.data_handle(),
params.return_distances,
idx.distances()
.value_or(raft::make_device_matrix<float, int64_t>(res, 0, 0).view())
.data_handle());
} else {
RAFT_EXPECTS(!params.return_distances,
"Distance view not allocated. Using return_distances set to true requires "
"distance view to be allocated.");
}

#pragma omp parallel for
for (size_t i = 0; i < static_cast<size_t>(dataset.extent(0)); i++) {
Expand Down Expand Up @@ -1444,7 +1478,8 @@ index<IdxT> build(raft::resources const& res,
graph_degree = intermediate_degree;
}

index<IdxT> idx{res, dataset.extent(0), static_cast<int64_t>(graph_degree)};
index<IdxT> idx{
res, dataset.extent(0), static_cast<int64_t>(graph_degree), params.return_distances};

build(res, params, dataset, idx);

Expand Down
38 changes: 34 additions & 4 deletions cpp/include/raft/neighbors/nn_descent_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@

#include "ann_types.hpp"

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>

#include <optional>

namespace raft::neighbors::experimental::nn_descent {
/**
* @ingroup nn-descent
Expand All @@ -51,6 +55,7 @@ struct index_params : ann::index_params {
size_t intermediate_graph_degree = 128; // Degree of input graph for pruning.
size_t max_iterations = 20; // Number of nn-descent iterations.
float termination_threshold = 0.0001; // Termination threshold of nn-descent.
bool return_distances = false; // return distances if true
};

/**
Expand Down Expand Up @@ -79,14 +84,20 @@ struct index : ann::index {
* @param res raft::resources is an object managing resources
* @param n_rows number of rows in knn-graph
* @param n_cols number of cols in knn-graph
* @param return_distances whether to allocate and get distances information
*/
index(raft::resources const& res, int64_t n_rows, int64_t n_cols)
index(raft::resources const& res, int64_t n_rows, int64_t n_cols, bool return_distances = false)
: ann::index(),
res_{res},
metric_{raft::distance::DistanceType::L2Expanded},
graph_{raft::make_host_matrix<IdxT, int64_t, row_major>(n_rows, n_cols)},
graph_view_{graph_.view()}
graph_view_{graph_.view()},
return_distances_(return_distances)
{
if (return_distances) {
distances_ = raft::make_device_matrix<float, int64_t>(res_, n_rows, n_cols);
distances_view_ = distances_.value().view();
}
}

/**
Expand All @@ -98,14 +109,23 @@ struct index : ann::index {
*
* @param res raft::resources is an object managing resources
* @param graph_view raft::host_matrix_view<IdxT, int64_t, raft::row_major> for storing knn-graph
* @param distances_view std::optional<raft::device_matrix_view<T, int64_t, row_major>> for
* storing knn-graph distances
* @param return_distances whether to allocate and get distances information
*/
index(raft::resources const& res,
raft::host_matrix_view<IdxT, int64_t, raft::row_major> graph_view)
raft::host_matrix_view<IdxT, int64_t, raft::row_major> graph_view,
std::optional<raft::device_matrix_view<float, int64_t, row_major>> distances_view =
std::nullopt,
bool return_distances = false)
: ann::index(),
res_{res},
metric_{raft::distance::DistanceType::L2Expanded},
graph_{raft::make_host_matrix<IdxT, int64_t, row_major>(0, 0)},
graph_view_{graph_view}
distances_{raft::make_device_matrix<float, int64_t>(res_, 0, 0)},
graph_view_{graph_view},
distances_view_(distances_view),
return_distances_(return_distances)
{
}

Expand Down Expand Up @@ -133,6 +153,13 @@ struct index : ann::index {
return graph_view_;
}

/** neighborhood graph distances [size, graph-degree] */
[[nodiscard]] inline auto distances() noexcept
-> std::optional<device_matrix_view<float, int64_t, row_major>>
{
return distances_view_;
}

// Don't allow copying the index for performance reasons (try avoiding copying data)
index(const index&) = delete;
index(index&&) = default;
Expand All @@ -144,8 +171,11 @@ struct index : ann::index {
raft::resources const& res_;
raft::distance::DistanceType metric_;
raft::host_matrix<IdxT, int64_t, row_major> graph_; // graph to return for non-int IdxT
std::optional<raft::device_matrix<float, int64_t, row_major>> distances_;
raft::host_matrix_view<IdxT, int64_t, row_major>
graph_view_; // view of graph for user provided matrix
std::optional<raft::device_matrix_view<float, int64_t, row_major>> distances_view_;
bool return_distances_;
};

/** @} */
Expand Down
33 changes: 28 additions & 5 deletions cpp/test/neighbors/ann_nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
*/
#pragma once

#include "../test_utils.cuh"
#include "ann_utils.cuh"

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/neighbors/nn_descent.cuh>
#include <raft/util/cudart_utils.hpp>
#include <raft/util/itertools.hpp>

#include <raft_internal/neighbors/naive_knn.cuh>
Expand Down Expand Up @@ -65,7 +65,9 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
{
size_t queries_size = ps.n_rows * ps.graph_degree;
std::vector<IdxT> indices_NNDescent(queries_size);
std::vector<DistanceT> distances_NNDescent(queries_size);
std::vector<IdxT> indices_naive(queries_size);
std::vector<DistanceT> distances_naive(queries_size);

{
rmm::device_uvector<DistanceT> distances_naive_dev(queries_size, stream_);
Expand All @@ -81,6 +83,7 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
ps.graph_degree,
ps.metric);
update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_);
update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_);
resource::sync_stream(handle_);
}

Expand All @@ -91,6 +94,7 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
index_params.graph_degree = ps.graph_degree;
index_params.intermediate_graph_degree = 2 * ps.graph_degree;
index_params.max_iterations = 100;
index_params.return_distances = true;

auto database_view = raft::make_device_matrix_view<const DataT, int64_t>(
(const DataT*)database.data(), ps.n_rows, ps.dim);
Expand All @@ -102,20 +106,39 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
auto database_host_view = raft::make_host_matrix_view<const DataT, int64_t>(
(const DataT*)database_host.data_handle(), ps.n_rows, ps.dim);
auto index = nn_descent::build<DataT, IdxT>(handle_, index_params, database_host_view);
update_host(
raft::copy(
indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_);
if (index.distances().has_value()) {
raft::copy(distances_NNDescent.data(),
index.distances().value().data_handle(),
queries_size,
stream_);
}

} else {
auto index = nn_descent::build<DataT, IdxT>(handle_, index_params, database_view);
update_host(
raft::copy(
indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_);
if (index.distances().has_value()) {
raft::copy(distances_NNDescent.data(),
index.distances().value().data_handle(),
queries_size,
stream_);
}
};
}
resource::sync_stream(handle_);
}

double min_recall = ps.min_recall;
EXPECT_TRUE(eval_recall(
indices_naive, indices_NNDescent, ps.n_rows, ps.graph_degree, 0.001, min_recall));
EXPECT_TRUE(eval_neighbours(indices_naive,
indices_NNDescent,
distances_naive,
distances_NNDescent,
ps.n_rows,
ps.graph_degree,
0.001,
min_recall));
}
}

Expand Down
Loading