Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add random subsampling for IVF methods #2077

Merged
merged 11 commits into from
Jan 23, 2024
Prev Previous commit
Next Next commit
Fix nvtx markers
  • Loading branch information
tfeher committed Jan 21, 2024
commit 790c0e6e267bc0501639cf9e436fd0c94dcd5582
12 changes: 5 additions & 7 deletions cpp/include/raft/matrix/detail/gather.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <functional>
#include <raft/common/nvtx.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdarray.hpp>
Expand All @@ -26,6 +27,7 @@
#include <raft/core/pinned_mdspan.hpp>
#include <raft/util/cuda_dev_essentials.cuh>
#include <raft/util/cudart_utils.hpp>

namespace raft {
namespace matrix {
namespace detail {
Expand Down Expand Up @@ -347,8 +349,7 @@ void gather_buff(host_matrix_view<const T, IdxT> dataset,
IdxT offset,
pinned_matrix_view<T, IdxT> buff)
{
raft::common::nvtx::range<common::nvtx::domain::raft> fun_scope("Gather vectors");

raft::common::nvtx::range<common::nvtx::domain::raft> fun_scope("gather_host_buff");
IdxT batch_size = std::min<IdxT>(buff.extent(0), indices.extent(0) - offset);

#pragma omp for
Expand All @@ -366,6 +367,7 @@ void gather(raft::resources const& res,
device_vector_view<const IdxT, IdxT> indices,
raft::device_matrix_view<T, IdxT> output)
{
raft::common::nvtx::range<common::nvtx::domain::raft> fun_scope("gather");
IdxT n_dim = output.extent(1);
IdxT n_train = output.extent(0);
auto indices_host = raft::make_host_vector<IdxT, IdxT>(n_train);
Expand All @@ -376,11 +378,7 @@ void gather(raft::resources const& res,
const size_t max_batch_size = 32768;
// Gather the vector on the host in tmp buffers. We use two buffers to overlap H2D sync
// and gathering the data.
raft::common::nvtx::push_range("subsample::alloc_buffers");
// rmm::mr::pinned_memory_resource mr_pinned;
// auto out_tmp1 = make_host_mdarray<T>(res, mr_pinned, make_extents<IdxT>(max_batch_size,
// n_dim)); auto out_tmp2 = make_host_mdarray<T>(res, mr_pinned,
// make_extents<IdxT>(max_batch_size, n_dim));
raft::common::nvtx::push_range("gather::alloc_buffers");
auto out_tmp1 = raft::make_pinned_matrix<T, IdxT>(res, max_batch_size, n_dim);
auto out_tmp2 = raft::make_pinned_matrix<T, IdxT>(res, max_batch_size, n_dim);
auto view1 = out_tmp1.view();
Expand Down