Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable distance return for NN Descent #2345

Merged
merged 31 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
242d725
enable nn descent dist return
jinsolp May 29, 2024
0977947
change bool to int
jinsolp May 29, 2024
ea30245
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp May 29, 2024
a7749ac
add test for distance
jinsolp May 30, 2024
a753e45
test for indices and distances with one func
jinsolp May 30, 2024
1d96e17
fix styling
jinsolp May 30, 2024
66a7678
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp May 30, 2024
11379f7
change return_distances to bool
jinsolp May 31, 2024
e993d7c
change distances to optional
jinsolp May 31, 2024
4578dc0
fix styling:
jinsolp May 31, 2024
c9048a5
handle bad access error
jinsolp Jun 5, 2024
bdb74a5
remove unnecessary dist allocation
jinsolp Jun 5, 2024
c71e97a
change to device matrix
jinsolp Jun 5, 2024
0327ce5
change template param for index
jinsolp Jun 5, 2024
3f49752
update test
jinsolp Jun 6, 2024
50e2cf8
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp Jun 6, 2024
a6a8ad2
remove comment
jinsolp Jun 6, 2024
2cfda3c
fix styling
jinsolp Jun 6, 2024
2efae1a
fix header
jinsolp Jun 6, 2024
931158c
add documentation for return_distances
jinsolp Jun 6, 2024
9f17b5c
add documentation
jinsolp Jun 7, 2024
e33019c
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp Jun 7, 2024
e28c2f9
add tparam doc
jinsolp Jun 10, 2024
2d74f89
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp Jun 11, 2024
fd65442
remove redundancy
jinsolp Jun 12, 2024
3ce8cdf
return optional for distances()
jinsolp Jun 13, 2024
4ec3ab0
fix styling
jinsolp Jun 13, 2024
b0347a7
remove raft_expects
jinsolp Jun 13, 2024
84abd96
fix type template and bring back raft_expects
jinsolp Jun 13, 2024
75b82df
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp Jun 13, 2024
ec44b4d
fix raft_expects
jinsolp Jun 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix styling
  • Loading branch information
jinsolp committed May 30, 2024
commit 1d96e1761aff5aeb25c571e368c7ce97526e5da1
31 changes: 22 additions & 9 deletions cpp/include/raft/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,11 @@ class GNND {
GNND(const GNND&) = delete;
GNND& operator=(const GNND&) = delete;

void build(Data_t* data, const Index_t nrow, Index_t* output_graph, bool return_distances, DistData_t *output_distances);
void build(Data_t* data,
const Index_t nrow,
Index_t* output_graph,
bool return_distances,
DistData_t* output_distances);
~GNND() = default;
using ID_t = InternalID_t<Index_t>;

Expand Down Expand Up @@ -1212,7 +1216,11 @@ void GNND<Data_t, Index_t>::local_join(cudaStream_t stream)
}

template <typename Data_t, typename Index_t>
void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* output_graph, bool return_distances, DistData_t *output_distances)
void GNND<Data_t, Index_t>::build(Data_t* data,
const Index_t nrow,
Index_t* output_graph,
bool return_distances,
DistData_t* output_distances)
{
using input_t = typename std::remove_const<Data_t>::type;

Expand Down Expand Up @@ -1339,7 +1347,7 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out
// Reuse graph_.h_dists as the buffer for shrink the lists in graph
static_assert(sizeof(decltype(*(graph_.h_dists.data_handle()))) >= sizeof(Index_t));

if(return_distances) {
if (return_distances) {
raft::copy(output_distances,
graph_.h_dists.data_handle(),
nrow_ * build_config_.node_degree,
Expand Down Expand Up @@ -1412,9 +1420,9 @@ void build(raft::resources const& res,

auto int_graph = raft::make_host_matrix<int, int64_t, row_major>(
jinsolp marked this conversation as resolved.
Show resolved Hide resolved
dataset.extent(0), static_cast<int64_t>(extended_graph_degree));
auto distances_graph = raft::make_host_matrix<DistData_t, int64_t, row_major>(0,0);
if(params.return_distances) {

auto distances_graph = raft::make_host_matrix<DistData_t, int64_t, row_major>(0, 0);
jinsolp marked this conversation as resolved.
Show resolved Hide resolved
if (params.return_distances) {
distances_graph = raft::make_host_matrix<DistData_t, int64_t, row_major>(
dataset.extent(0), static_cast<int64_t>(extended_graph_degree));
}
Expand All @@ -1427,16 +1435,21 @@ void build(raft::resources const& res,
.termination_threshold = params.termination_threshold};

GNND<const T, int> nnd(res, build_config);
nnd.build(dataset.data_handle(), dataset.extent(0), int_graph.data_handle(), params.return_distances, distances_graph.data_handle());
nnd.build(dataset.data_handle(),
dataset.extent(0),
int_graph.data_handle(),
params.return_distances,
distances_graph.data_handle());

#pragma omp parallel for
for (size_t i = 0; i < static_cast<size_t>(dataset.extent(0)); i++) {
for (size_t j = 0; j < graph_degree; j++) {
auto graph = idx.graph().data_handle();
graph[i * graph_degree + j] = int_graph.data_handle()[i * extended_graph_degree + j];
if(params.return_distances) {
if (params.return_distances) {
auto dist_graph = idx.distances().data_handle();
dist_graph[i * graph_degree + j] = distances_graph.data_handle()[i * extended_graph_degree + j];
dist_graph[i * graph_degree + j] =
distances_graph.data_handle()[i * extended_graph_degree + j];
}
}
}
Expand Down
10 changes: 8 additions & 2 deletions cpp/test/neighbors/ann_nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,14 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
}

double min_recall = ps.min_recall;
EXPECT_TRUE(eval_neighbours(
indices_naive, indices_NNDescent, distances_naive, distances_NNDescent, ps.n_rows, ps.graph_degree, 0.001, min_recall));
EXPECT_TRUE(eval_neighbours(indices_naive,
indices_NNDescent,
distances_naive,
distances_NNDescent,
ps.n_rows,
ps.graph_degree,
0.001,
min_recall));
}
}

Expand Down
Loading