Add random subsampling for IVF methods #2077

Merged
merged 11 commits into from Jan 23, 2024
Batched index gather overlapped with H2D copies
tfeher committed Jan 21, 2024
commit 548555766b9acb485000c24e32816d6d874f58b5
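
The core pattern of this commit, gathering one batch of rows into pinned host memory while the previous batch's H2D copy is in flight, can be sketched in isolation. A minimal sketch, assuming float data and a hypothetical helper name (batched_h2d_copy); the committed implementation below uses OpenMP threads for the gather and RAFT pinned mdarrays instead of raw allocations:

#include <cuda_runtime.h>
#include <algorithm>
#include <cstring>

// Double-buffered H2D pipeline: while buf[cur] is being copied to the device
// asynchronously, the next batch is staged into the other pinned buffer.
void batched_h2d_copy(const float* host_src, float* dev_dst, size_t n_rows, size_t dim)
{
  const size_t batch = 1024;
  float* buf[2];
  cudaMallocHost(&buf[0], batch * dim * sizeof(float));  // pinned memory enables async H2D copies
  cudaMallocHost(&buf[1], batch * dim * sizeof(float));
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  int cur = 0;
  // Stage batch 0 before entering the pipeline (a memcpy stands in for the gather).
  std::memcpy(buf[cur], host_src, std::min(batch, n_rows) * dim * sizeof(float));

  for (size_t off = 0; off < n_rows; off += batch) {
    size_t cur_rows = std::min(batch, n_rows - off);
    // Launch the H2D copy of the already-staged batch...
    cudaMemcpyAsync(dev_dst + off * dim, buf[cur], cur_rows * dim * sizeof(float),
                    cudaMemcpyHostToDevice, stream);
    // ...and overlap it with staging the next batch into the other buffer.
    size_t next_off = off + cur_rows;
    if (next_off < n_rows) {
      std::memcpy(buf[1 - cur], host_src + next_off * dim,
                  std::min(batch, n_rows - next_off) * dim * sizeof(float));
    }
    cudaStreamSynchronize(stream);  // buf[cur] may be reused only after the copy lands
    cur = 1 - cur;                  // swap roles of the two buffers
  }
  cudaStreamDestroy(stream);
  cudaFreeHost(buf[0]);
  cudaFreeHost(buf[1]);
}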
76 changes: 75 additions & 1 deletion cpp/include/raft/matrix/detail/gather.cuh
@@ -17,9 +17,15 @@
#pragma once

#include <functional>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/pinned_mdarray.hpp>
#include <raft/core/pinned_mdspan.hpp>
#include <raft/util/cuda_dev_essentials.cuh>
#include <raft/util/cudart_utils.hpp>

namespace raft {
namespace matrix {
namespace detail {
@@ -335,6 +341,74 @@ void gather_if(const InputIteratorT in,
gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream);
}

/**
 * Gather a batch of dataset rows (selected by indices, starting at offset) into a
 * pinned staging buffer. Expected to be called from within an omp parallel region:
 * the row loop is then work-shared across the thread team by the orphaned `omp for`.
 */
template <typename T, typename IdxT = int64_t>
void gather_buff(host_matrix_view<const T, IdxT> dataset,
                 host_vector_view<const IdxT, IdxT> indices,
                 IdxT offset,
                 pinned_matrix_view<T, IdxT> buff)
{
  raft::common::nvtx::range<common::nvtx::domain::raft> fun_scope("Gather vectors");

  IdxT batch_size = std::min<IdxT>(buff.extent(0), indices.extent(0) - offset);

#pragma omp for
  for (IdxT i = 0; i < batch_size; i++) {
    IdxT in_idx = indices(offset + i);
    for (IdxT k = 0; k < buff.extent(1); k++) {
      buff(i, k) = dataset(in_idx, k);
    }
  }
}

/**
 * Gather rows of a host-resident dataset, selected by device-resident indices,
 * into device memory, in batches, overlapping the host-side gather of one batch
 * with the H2D copy of the previous one.
 */
template <typename T, typename IdxT>
void gather(raft::resources const& res,
            host_matrix_view<const T, IdxT> dataset,
            device_vector_view<const IdxT, IdxT> indices,
            raft::device_matrix_view<T, IdxT> output)
{
  IdxT n_dim        = output.extent(1);
  IdxT n_train      = output.extent(0);
  auto indices_host = raft::make_host_vector<IdxT, IdxT>(n_train);
  raft::copy(
    indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res));
  resource::sync_stream(res);

  const size_t max_batch_size = 32768;
  // Gather the vectors on the host into tmp buffers. We use two pinned buffers to overlap
  // the H2D copy of one batch with gathering the next.
  raft::common::nvtx::push_range("subsample::alloc_buffers");
  // rmm::mr::pinned_memory_resource mr_pinned;
  // auto out_tmp1 = make_host_mdarray<T>(res, mr_pinned, make_extents<IdxT>(max_batch_size, n_dim));
  // auto out_tmp2 = make_host_mdarray<T>(res, mr_pinned, make_extents<IdxT>(max_batch_size, n_dim));
  auto out_tmp1 = raft::make_pinned_matrix<T, IdxT>(res, max_batch_size, n_dim);
  auto out_tmp2 = raft::make_pinned_matrix<T, IdxT>(res, max_batch_size, n_dim);
  auto view1    = out_tmp1.view();
  auto view2    = out_tmp2.view();
  raft::common::nvtx::pop_range();

  // Stage the first batch, then pipeline: the master thread issues the async H2D copy
  // of the staged buffer while the whole team gathers the next batch into the other one.
  gather_buff(dataset, make_const_mdspan(indices_host.view()), (IdxT)0, view1);
#pragma omp parallel
  for (IdxT device_offset = 0; device_offset < n_train; device_offset += max_batch_size) {
    IdxT batch_size = std::min<IdxT>(max_batch_size, n_train - device_offset);
#pragma omp master
    raft::copy(output.data_handle() + device_offset * n_dim,
               view1.data_handle(),
               batch_size * n_dim,
               resource::get_cuda_stream(res));
    // Start gathering the next batch on the host.
    IdxT host_offset = device_offset + batch_size;
    batch_size       = std::min<IdxT>(max_batch_size, n_train - host_offset);
    if (batch_size > 0) {
      gather_buff(dataset, make_const_mdspan(indices_host.view()), host_offset, view2);
    }
#pragma omp master
    resource::sync_stream(res);  // the H2D copy must land before view1 is refilled
#pragma omp barrier
    std::swap(view1, view2);  // the freshly gathered buffer becomes the next copy source
  }
}

} // namespace detail
} // namespace matrix
} // namespace raft
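
For reference, a usage sketch of the new detail::gather entry point (not part of the diff; the example function, sizes, and element types are illustrative, and the fill step is elided):

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/matrix/detail/gather.cuh>

void subsample_host_dataset_example()
{
  raft::device_resources res;
  int64_t n_rows = 100000, n_dim = 128, n_train = 10000;

  // Host-resident dataset, device-resident row indices, device-resident output.
  auto dataset = raft::make_host_matrix<float, int64_t>(n_rows, n_dim);
  auto indices = raft::make_device_vector<int64_t, int64_t>(res, n_train);
  auto output  = raft::make_device_matrix<float, int64_t>(res, n_train, n_dim);

  // ... fill dataset and indices ...

  // Copies the selected rows to the device, overlapping the host-side gather
  // of each batch with the H2D copy of the previous one.
  raft::matrix::detail::gather(res,
                               raft::make_const_mdspan(dataset.view()),
                               raft::make_const_mdspan(indices.view()),
                               output.view());
}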
43 changes: 10 additions & 33 deletions cpp/include/raft/spatial/knn/detail/ann_utils.cuh
@@ -16,13 +17,15 @@

#pragma once

#include <raft/common/nvtx.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdarray.hpp>

#include <raft/core/logger.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/matrix/copy.cuh>
#include <raft/matrix/gather.cuh>
#include <raft/random/sample_without_replacement.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>
@@ -601,10 +603,6 @@ auto get_subsample_indices(raft::resources const& res, IdxT n_samples, IdxT n_subsamples,
     std::nullopt,
     train_indices.view(),
     std::nullopt);
-
-  thrust::sort(resource::get_thrust_policy(res),
-               train_indices.data_handle(),
-               train_indices.data_handle() + n_subsamples);
   return train_indices;
 }

@@ -618,42 +616,21 @@ void subsample(raft::resources const& res,
 {
   IdxT n_dim   = output.extent(1);
   IdxT n_train = output.extent(0);
-  if (n_train == n_samples) {
-    RAFT_LOG_INFO("No subsampling");
-    raft::copy(output.data_handle(), input, n_dim * n_samples, resource::get_cuda_stream(res));
-    return;
-  }
-  RAFT_LOG_DEBUG("Random subsampling");

   raft::device_vector<IdxT, IdxT> train_indices =
     get_subsample_indices<IdxT>(res, n_samples, n_train, seed);

   cudaPointerAttributes attr;
   RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, input));
   T* ptr = reinterpret_cast<T*>(attr.devicePointer);
   if (ptr != nullptr) {
-    raft::matrix::copy_rows(res,
-                            raft::make_device_matrix_view<const T, IdxT>(ptr, n_samples, n_dim),
-                            output,
-                            raft::make_const_mdspan(train_indices.view()));
+    raft::matrix::gather(res,
+                         raft::make_device_matrix_view<const T, IdxT>(ptr, n_samples, n_dim),
+                         raft::make_const_mdspan(train_indices.view()),
+                         output);
   } else {
-    auto dataset = raft::make_host_matrix_view<const T, IdxT>(input, n_samples, n_dim);
-    auto train_indices_host = raft::make_host_vector<IdxT, IdxT>(n_train);
-    raft::copy(train_indices_host.data_handle(),
-               train_indices.data_handle(),
-               n_train,
-               resource::get_cuda_stream(res));
-    resource::sync_stream(res);
-    auto out_tmp = raft::make_host_matrix<T, IdxT>(n_train, n_dim);
-#pragma omp parallel for
-    for (IdxT i = 0; i < n_train; i++) {
-      IdxT in_idx = train_indices_host(i);
-      for (IdxT k = 0; k < n_dim; k++) {
-        out_tmp(i, k) = dataset(in_idx, k);
-      }
-    }
-    raft::copy(
-      output.data_handle(), out_tmp.data_handle(), output.size(), resource::get_cuda_stream(res));
-    resource::sync_stream(res);
+    auto dataset = raft::make_host_matrix_view<const T, IdxT>(input, n_samples, n_dim);
+    raft::matrix::detail::gather(res, dataset, make_const_mdspan(train_indices.view()), output);
   }
 }
} // namespace raft::spatial::knn::detail::utils
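
The device/host dispatch in subsample above hinges on cudaPointerGetAttributes, which reports a device-accessible alias for device, managed, or registered host allocations. A standalone sketch of the same check, with a hypothetical helper name:

#include <cuda_runtime.h>

// Returns a device-accessible alias of ptr if one exists (device, managed, or
// pinned-and-mapped host memory), or nullptr for plain pageable host memory,
// in which case the caller falls back to the batched host-side gather path.
template <typename T>
T* device_accessible_ptr(const T* ptr)
{
  cudaPointerAttributes attr;
  cudaPointerGetAttributes(&attr, ptr);  // succeeds even for unregistered host pointers
  return reinterpret_cast<T*>(attr.devicePointer);
}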