Skip to content

Commit

Permalink
Batched index gather overlapped with H2D copies
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Jan 21, 2024
1 parent 7cfd9b5 commit 5485557
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 34 deletions.
76 changes: 75 additions & 1 deletion cpp/include/raft/matrix/detail/gather.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@
#pragma once

#include <functional>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/pinned_mdarray.hpp>
#include <raft/core/pinned_mdspan.hpp>
#include <raft/util/cuda_dev_essentials.cuh>
#include <raft/util/cudart_utils.hpp>

namespace raft {
namespace matrix {
namespace detail {
Expand Down Expand Up @@ -335,6 +341,74 @@ void gather_if(const InputIteratorT in,
gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream);
}

template <typename T, typename IdxT = int64_t>
void gather_buff(host_matrix_view<const T, IdxT> dataset,
                 host_vector_view<const IdxT, IdxT> indices,
                 IdxT offset,
                 pinned_matrix_view<T, IdxT> buff)
{
  raft::common::nvtx::range<common::nvtx::domain::raft> fun_scope("Gather vectors");

  // Number of rows still available starting at `offset`, capped by the buffer capacity.
  const IdxT n_rows = std::min<IdxT>(buff.extent(0), indices.extent(0) - offset);

  // Orphaned worksharing loop: when this function is invoked from inside an enclosing
  // `#pragma omp parallel` region, the rows below are divided among its threads; when
  // invoked outside any parallel region, the loop simply runs sequentially.
#pragma omp for
  for (IdxT row = 0; row < n_rows; row++) {
    const IdxT src_row = indices(offset + row);
    for (IdxT col = 0; col < buff.extent(1); col++) {
      buff(row, col) = dataset(src_row, col);
    }
  }
}

/**
 * Gather rows of a host dataset into device memory, batched so that the H2D copy of one
 * batch overlaps with gathering the next batch on the host.
 *
 * @param res     raft resources (provides the CUDA stream used for copies)
 * @param dataset host-resident input matrix [n_samples, n_dim]
 * @param indices device-resident row indices to gather [n_train]
 * @param output  device-resident output matrix [n_train, n_dim]
 */
template <typename T, typename IdxT>
void gather(raft::resources const& res,
            host_matrix_view<const T, IdxT> dataset,
            device_vector_view<const IdxT, IdxT> indices,
            raft::device_matrix_view<T, IdxT> output)
{
  IdxT n_dim   = output.extent(1);
  IdxT n_train = output.extent(0);
  // The gather itself runs on the host, so bring the indices over first.
  auto indices_host = raft::make_host_vector<IdxT, IdxT>(n_train);
  raft::copy(
    indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res));
  resource::sync_stream(res);

  const size_t max_batch_size = 32768;
  // Gather the vectors on the host into tmp buffers. Two pinned buffers let us overlap the
  // async H2D copy of one batch with gathering the next batch into the other buffer.
  raft::common::nvtx::push_range("subsample::alloc_buffers");
  auto out_tmp1 = raft::make_pinned_matrix<T, IdxT>(res, max_batch_size, n_dim);
  auto out_tmp2 = raft::make_pinned_matrix<T, IdxT>(res, max_batch_size, n_dim);
  auto view1    = out_tmp1.view();
  auto view2    = out_tmp2.view();
  raft::common::nvtx::pop_range();

  // Prime the pipeline: fill the first buffer (sequential, outside the parallel region).
  gather_buff(dataset, make_const_mdspan(indices_host.view()), (IdxT)0, view1);
#pragma omp parallel
  for (IdxT device_offset = 0; device_offset < n_train; device_offset += max_batch_size) {
    IdxT batch_size = std::min<IdxT>(max_batch_size, n_train - device_offset);

    // Only the master thread enqueues the async H2D copy of the already-gathered batch.
#pragma omp master
    raft::copy(output.data_handle() + device_offset * n_dim,
               view1.data_handle(),
               batch_size * n_dim,
               resource::get_cuda_stream(res));
    // Meanwhile all threads cooperate (via the worksharing loop inside gather_buff)
    // on gathering the next batch on the host into the other buffer.
    IdxT host_offset = device_offset + batch_size;
    batch_size       = std::min<IdxT>(max_batch_size, n_train - host_offset);
    if (batch_size > 0) {
      gather_buff(dataset, make_const_mdspan(indices_host.view()), host_offset, view2);
    }
    // The buffer being copied must not be reused until the copy has completed.
#pragma omp master
    resource::sync_stream(res);
#pragma omp barrier
    // view1/view2 are shared across the team, so exactly one thread must swap them.
    // Previously every thread executed std::swap concurrently on the shared views — a data
    // race that could leave both views aliasing the same buffer. `omp single` serializes
    // the swap, and its implicit end-of-construct barrier guarantees all threads observe
    // the swapped views before starting the next iteration.
#pragma omp single
    std::swap(view1, view2);
  }
}

} // namespace detail
} // namespace matrix
} // namespace raft
43 changes: 10 additions & 33 deletions cpp/include/raft/spatial/knn/detail/ann_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@

#pragma once

#include <raft/common/nvtx.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdarray.hpp>

#include <raft/core/logger.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/matrix/copy.cuh>
#include <raft/matrix/gather.cuh>
#include <raft/random/sample_without_replacement.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>
Expand Down Expand Up @@ -601,10 +603,6 @@ auto get_subsample_indices(raft::resources const& res, IdxT n_samples, IdxT n_su
std::nullopt,
train_indices.view(),
std::nullopt);

thrust::sort(resource::get_thrust_policy(res),
train_indices.data_handle(),
train_indices.data_handle() + n_subsamples);
return train_indices;
}

Expand All @@ -618,42 +616,21 @@ void subsample(raft::resources const& res,
{
IdxT n_dim = output.extent(1);
IdxT n_train = output.extent(0);
if (n_train == n_samples) {
RAFT_LOG_INFO("No subsampling");
raft::copy(output.data_handle(), input, n_dim * n_samples, resource::get_cuda_stream(res));
return;
}
RAFT_LOG_DEBUG("Random subsampling");

raft::device_vector<IdxT, IdxT> train_indices =
get_subsample_indices<IdxT>(res, n_samples, n_train, seed);

cudaPointerAttributes attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, input));
T* ptr = reinterpret_cast<T*>(attr.devicePointer);
if (ptr != nullptr) {
raft::matrix::copy_rows(res,
raft::make_device_matrix_view<const T, IdxT>(ptr, n_samples, n_dim),
output,
raft::make_const_mdspan(train_indices.view()));
raft::matrix::gather(res,
raft::make_device_matrix_view<const T, IdxT>(ptr, n_samples, n_dim),
raft::make_const_mdspan(train_indices.view()),
output);
} else {
auto dataset = raft::make_host_matrix_view<const T, IdxT>(input, n_samples, n_dim);
auto train_indices_host = raft::make_host_vector<IdxT, IdxT>(n_train);
raft::copy(train_indices_host.data_handle(),
train_indices.data_handle(),
n_train,
resource::get_cuda_stream(res));
resource::sync_stream(res);
auto out_tmp = raft::make_host_matrix<T, IdxT>(n_train, n_dim);
#pragma omp parallel for
for (IdxT i = 0; i < n_train; i++) {
IdxT in_idx = train_indices_host(i);
for (IdxT k = 0; k < n_dim; k++) {
out_tmp(i, k) = dataset(in_idx, k);
}
}
raft::copy(
output.data_handle(), out_tmp.data_handle(), output.size(), resource::get_cuda_stream(res));
resource::sync_stream(res);
auto dataset = raft::make_host_matrix_view<const T, IdxT>(input, n_samples, n_dim);
raft::matrix::detail::gather(res, dataset, make_const_mdspan(train_indices.view()), output);
}
}
} // namespace raft::spatial::knn::detail::utils

0 comments on commit 5485557

Please sign in to comment.