Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
jinsolp committed Jun 6, 2024
1 parent 0327ce5 commit 3f49752
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions cpp/test/neighbors/ann_nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "../test_utils.cuh"
#include "ann_utils.cuh"
#include "raft/util/cudart_utils.hpp"

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/neighbors/nn_descent.cuh>
Expand Down Expand Up @@ -94,6 +95,7 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
index_params.graph_degree = ps.graph_degree;
index_params.intermediate_graph_degree = 2 * ps.graph_degree;
index_params.max_iterations = 100;
index_params.return_distances = true;

auto database_view = raft::make_device_matrix_view<const DataT, int64_t>(
(const DataT*)database.data(), ps.n_rows, ps.dim);
Expand All @@ -105,16 +107,12 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
auto database_host_view = raft::make_host_matrix_view<const DataT, int64_t>(
(const DataT*)database_host.data_handle(), ps.n_rows, ps.dim);
auto index = nn_descent::build<DataT, IdxT>(handle_, index_params, database_host_view);
update_host(
indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_);
update_host(
distances_NNDescent.data(), index.distances().data_handle(), queries_size, stream_);
raft::copy(indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_);
raft::copy(distances_NNDescent.data(), index.distances().data_handle(), queries_size, stream_);
} else {
auto index = nn_descent::build<DataT, IdxT>(handle_, index_params, database_view);
update_host(
indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_);
update_host(
distances_NNDescent.data(), index.distances().data_handle(), queries_size, stream_);
raft::copy(indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_);
raft::copy(distances_NNDescent.data(), index.distances().data_handle(), queries_size, stream_);
};
}
resource::sync_stream(handle_);
Expand Down

0 comments on commit 3f49752

Please sign in to comment.