Skip to content

Commit

Permalink
Use slicing kernel to copy distances inside NN Descent (#2380)
Browse files Browse the repository at this point in the history
This make use of raft's slicing kernel within NN Descent build.
I found that my previous implementation was inefficient (merged in [this PR](#2345)).

### Improvements
Time to call NN Descent `build()` with `return_distances=True` before and after using this kernel
| Dataset | Before |After|
| ------------- | ------------- |---|
| mnist (60000, 784)  | 1550ms  | 1020ms|
| sift (1M, 128)  | 11342ms  |5546ms|
| gist (1M, 960)  | 13508ms |9278ms|

Authors:
  - Jinsol Park (https://github.com/jinsolp)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Micka (https://github.com/lowener)

URL: #2380
  • Loading branch information
jinsolp authored Jul 22, 2024
1 parent 1ec2e35 commit 706eb39
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions cpp/include/raft/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@
#include "../nn_descent_types.hpp"

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/error.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/matrix/slice.cuh>
#include <raft/neighbors/detail/cagra/device_common.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>
#include <raft/util/arch.cuh> // raft::util::arch::SM_*
Expand Down Expand Up @@ -1365,12 +1368,22 @@ void GNND<Data_t, Index_t, epilogue_op>::build(Data_t* data,
static_assert(sizeof(decltype(*(graph_.h_dists.data_handle()))) >= sizeof(Index_t));
if (return_distances) {
for (size_t i = 0; i < (size_t)nrow_; i++) {
raft::copy(output_distances + i * build_config_.output_graph_degree,
graph_.h_dists.data_handle() + i * build_config_.node_degree,
build_config_.output_graph_degree,
raft::resource::get_cuda_stream(res));
}
auto graph_d_dists = raft::make_device_matrix<DistData_t, int64_t, raft::row_major>(
res, nrow_, build_config_.node_degree);
raft::copy(graph_d_dists.data_handle(),
graph_.h_dists.data_handle(),
nrow_ * build_config_.node_degree,
raft::resource::get_cuda_stream(res));
auto output_dist_view = raft::make_device_matrix_view<DistData_t, int64_t, raft::row_major>(
output_distances, nrow_, build_config_.output_graph_degree);
raft::matrix::slice_coordinates coords{static_cast<int64_t>(0),
static_cast<int64_t>(0),
static_cast<int64_t>(nrow_),
static_cast<int64_t>(build_config_.output_graph_degree)};
raft::matrix::slice<DistData_t, int64_t, raft::row_major>(
res, raft::make_const_mdspan(graph_d_dists.view()), output_dist_view, coords);
}
Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists.data_handle();
Expand Down

0 comments on commit 706eb39

Please sign in to comment.