Skip to content

Commit

Permalink
Removing code that explicitly compares equality of rmm memory resourc…
Browse files Browse the repository at this point in the history
…es (#2047)

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Divye Gala (https://github.com/divyegala)

URL: #2047
  • Loading branch information
cjnolet authored Dec 9, 2023
1 parent 5e80c1d commit addb059
Show file tree
Hide file tree
Showing 11 changed files with 22 additions and 145 deletions.
8 changes: 2 additions & 6 deletions cpp/include/raft/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include <raft/util/device_atomics.cuh>
#include <raft/util/integer_utils.hpp>

#include <raft/core/resource/device_memory_resource.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_scalar.hpp>
#include <rmm/device_vector.hpp>
Expand Down Expand Up @@ -970,16 +971,11 @@ void build_hierarchical(const raft::resources& handle,
IdxT n_mesoclusters = std::min(n_clusters, static_cast<IdxT>(std::sqrt(n_clusters) + 0.5));
RAFT_LOG_DEBUG("build_hierarchical: n_mesoclusters: %u", n_mesoclusters);

// TODO: Remove the explicit managed memory- we shouldn't be creating this on the user's behalf.
rmm::mr::managed_memory_resource managed_memory;
rmm::mr::device_memory_resource* device_memory = resource::get_workspace_resource(handle);
auto [max_minibatch_size, mem_per_row] =
calc_minibatch_size<MathT>(n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);
auto pool_guard =
raft::get_pool_memory_resource(device_memory, mem_per_row * size_t(max_minibatch_size));
if (pool_guard) {
RAFT_LOG_DEBUG("build_hierarchical: using pool memory resource with initial size %zu bytes",
mem_per_row * size_t(max_minibatch_size));
}

// Precompute the L2 norm of the dataset if relevant.
const MathT* dataset_norm = nullptr;
Expand Down
58 changes: 0 additions & 58 deletions cpp/include/raft/core/resource/detail/device_memory_resource.hpp

This file was deleted.

1 change: 1 addition & 0 deletions cpp/include/raft/matrix/detail/select_k-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ void select_k(raft::resources const& handle,
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);

if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); }
auto stream = raft::resource::get_cuda_stream(handle);
auto algo = choose_select_k_algorithm(batch_size, len, k);

Expand Down
14 changes: 2 additions & 12 deletions cpp/include/raft/matrix/detail/select_radix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,8 @@ void radix_topk(const T* in,
static_assert(calc_num_passes<T, BitsPerPass>() > 1);
constexpr int num_buckets = calc_num_buckets<BitsPerPass>();

if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); }

auto kernel = radix_kernel<T, IdxT, BitsPerPass, BlockSize, false>;
const size_t max_chunk_size =
calc_chunk_size<T, IdxT, BlockSize>(batch_size, len, sm_cnt, kernel, false);
Expand All @@ -843,15 +845,7 @@ void radix_topk(const T* in,
}
const IdxT buf_len = calc_buf_len<T>(len);

size_t req_aux = max_chunk_size * (sizeof(Counter<T, IdxT>) + num_buckets * sizeof(IdxT));
size_t req_buf = max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT));
size_t mem_req = req_aux + req_buf + 256 * 6; // might need extra memory for alignment

auto pool_guard = raft::get_pool_memory_resource(mr, mem_req);
if (pool_guard) {
RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes",
mem_req);
}

rmm::device_uvector<Counter<T, IdxT>> counters(max_chunk_size, stream, mr);
rmm::device_uvector<IdxT> histograms(max_chunk_size * num_buckets, stream, mr);
Expand Down Expand Up @@ -1120,10 +1114,6 @@ void radix_topk_one_block(const T* in,
const size_t max_chunk_size =
calc_chunk_size<T, IdxT, BlockSize>(batch_size, len, sm_cnt, kernel, true);

auto pool_guard =
raft::get_pool_memory_resource(mr, max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)));
if (pool_guard) { RAFT_LOG_DEBUG("radix::select_k: using pool memory resource"); }

rmm::device_uvector<char> bufs(
max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr);

Expand Down
4 changes: 1 addition & 3 deletions cpp/include/raft/matrix/detail/select_warpsort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -988,9 +988,7 @@ void select_k_(int num_of_block,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = nullptr)
{
auto pool_guard = raft::get_pool_memory_resource(
mr, num_of_block * k * batch_size * 2 * std::max(sizeof(T), sizeof(IdxT)));
if (pool_guard) { RAFT_LOG_DEBUG("warpsort::select_k: using pool memory resource"); }
if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); }

rmm::device_uvector<T> tmp_val(num_of_block * k * batch_size, stream, mr);
rmm::device_uvector<IdxT> tmp_idx(num_of_block * k * batch_size, stream, mr);
Expand Down
6 changes: 1 addition & 5 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/resource/detail/device_memory_resource.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>

Expand All @@ -48,7 +47,6 @@ void build_knn_graph(raft::resources const& res,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::build");
RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded,
"Currently only L2Expanded metric is supported");

Expand Down Expand Up @@ -125,9 +123,7 @@ void build_knn_graph(raft::resources const& res,
bool first = true;
const auto start_clock = std::chrono::system_clock::now();

rmm::mr::device_memory_resource* device_memory = nullptr;
auto pool_guard = raft::get_pool_memory_resource(device_memory, 1024 * 1024);
if (pool_guard) { RAFT_LOG_DEBUG("ivf_pq using pool memory resource"); }
rmm::mr::device_memory_resource* device_memory = raft::resource::get_workspace_resource(res);

raft::spatial::knn::detail::utils::batch_load_iterator<DataT> vec_batches(
dataset.data_handle(),
Expand Down
2 changes: 0 additions & 2 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/resource/detail/device_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <rmm/cuda_stream_view.hpp>
Expand Down Expand Up @@ -110,7 +109,6 @@ void search_main(raft::resources const& res,
raft::device_matrix_view<DistanceT, int64_t, row_major> distances,
CagraSampleFilterT sample_filter = CagraSampleFilterT())
{
resource::detail::warn_non_pool_workspace(res, "raft::neighbors::cagra::search");
RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n",
static_cast<size_t>(index.dataset().extent(0)),
static_cast<size_t>(index.dataset().extent(1)));
Expand Down
7 changes: 1 addition & 6 deletions cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ inline void search(raft::resources const& handle,
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim());

if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); }
RAFT_EXPECTS(params.n_probes > 0,
"n_probes (number of clusters to probe in the search) must be positive.");
auto n_probes = std::min<uint32_t>(params.n_probes, index.n_lists());
Expand All @@ -233,12 +234,6 @@ inline void search(raft::resources const& handle,
raft::div_rounding_up_safe<uint64_t>(
kExpectedWsSize, 16ull * uint64_t{n_probes} * k + 4ull * index.dim()));

auto pool_guard = raft::get_pool_memory_resource(mr, max_queries * n_probes * k * 16);
if (pool_guard) {
RAFT_LOG_DEBUG("ivf_flat::search: using pool memory resource with initial size %zu bytes",
n_queries * n_probes * k * 16ull);
}

for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) {
uint32_t queries_batch = min(max_queries, n_queries - offset_q);

Expand Down
62 changes: 13 additions & 49 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#pragma once

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/spatial/knn/detail/ann_utils.cuh>

#include <raft/neighbors/detail/ivf_pq_codepacking.cuh>
Expand All @@ -29,7 +28,6 @@
#include <raft/core/logger.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/detail/device_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/add.cuh>
Expand All @@ -48,11 +46,10 @@
#include <raft/util/pow2_utils.cuh>
#include <raft/util/vectorized.cuh>

#include <raft/core/resource/device_memory_resource.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <thrust/extrema.h>
#include <thrust/scan.h>
Expand Down Expand Up @@ -1559,7 +1556,6 @@ void extend(raft::resources const& handle,
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_pq::extend(%zu, %u)", size_t(n_rows), index->dim());

resource::detail::warn_non_pool_workspace(handle, "raft::ivf_pq::extend");
auto stream = resource::get_cuda_stream(handle);
const auto n_clusters = index->n_lists();

Expand All @@ -1569,13 +1565,7 @@ void extend(raft::resources const& handle,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
"Unsupported data type");

rmm::mr::device_memory_resource* device_memory = nullptr;
auto pool_guard = raft::get_pool_memory_resource(device_memory, 1024 * 1024);
if (pool_guard) { RAFT_LOG_DEBUG("ivf_pq::extend: using pool memory resource"); }

rmm::mr::managed_memory_resource managed_memory_upstream;
rmm::mr::pool_memory_resource<rmm::mr::managed_memory_resource> managed_memory(
&managed_memory_upstream, 1024 * 1024);
rmm::mr::device_memory_resource* device_memory = raft::resource::get_workspace_resource(handle);

// The spec defines how the clusters look like
auto spec = list_spec<uint32_t, IdxT>{
Expand All @@ -1593,17 +1583,9 @@ void extend(raft::resources const& handle,
size_t free_mem, total_mem;
RAFT_CUDA_TRY(cudaMemGetInfo(&free_mem, &total_mem));

// Decide on an approximate threshold when we'd better start saving device memory by using
// managed allocations for large device buffers
rmm::mr::device_memory_resource* labels_mr = device_memory;
rmm::mr::device_memory_resource* batches_mr = device_memory;
if (n_rows * (index->dim() * sizeof(T) + index->pq_dim() + sizeof(IdxT) + sizeof(uint32_t)) >
free_mem) {
labels_mr = &managed_memory;
}
// Allocate a buffer for the new labels (classifying the new data)
rmm::device_uvector<uint32_t> new_data_labels(n_rows, stream, labels_mr);
if (labels_mr == device_memory) { free_mem -= sizeof(uint32_t) * n_rows; }
rmm::device_uvector<uint32_t> new_data_labels(n_rows, stream, device_memory);
free_mem -= sizeof(uint32_t) * n_rows;

// Calculate the batch size for the input data if it's not accessible directly from the device
constexpr size_t kReasonableMaxBatchSize = 65536;
Expand Down Expand Up @@ -1632,19 +1614,13 @@ void extend(raft::resources const& handle,
while (size_factor * max_batch_size > free_mem && max_batch_size > 128) {
max_batch_size >>= 1;
}
if (size_factor * max_batch_size > free_mem) {
// if that still doesn't fit, resort to the UVM
batches_mr = &managed_memory;
max_batch_size = kReasonableMaxBatchSize;
} else {
// If we're keeping the batches in device memory, update the available mem tracker.
free_mem -= size_factor * max_batch_size;
}
// If we're keeping the batches in device memory, update the available mem tracker.
free_mem -= size_factor * max_batch_size;
}

// Predict the cluster labels for the new data, in batches if necessary
utils::batch_load_iterator<T> vec_batches(
new_vectors, n_rows, index->dim(), max_batch_size, stream, batches_mr);
new_vectors, n_rows, index->dim(), max_batch_size, stream, device_memory);
// Release the placeholder memory, because we don't intend to allocate any more long-living
// temporary buffers before we allocate the index data.
// This memory could potentially speed up UVM accesses, if any.
Expand Down Expand Up @@ -1717,7 +1693,7 @@ void extend(raft::resources const& handle,
// By this point, the index state is updated and valid except it doesn't contain the new data
// Fill the extended index with the new data (possibly, in batches)
utils::batch_load_iterator<IdxT> idx_batches(
new_indices, n_rows, 1, max_batch_size, stream, batches_mr);
new_indices, n_rows, 1, max_batch_size, stream, device_memory);
for (const auto& vec_batch : vec_batches) {
const auto& idx_batch = *idx_batches++;
process_and_fill_codes(handle,
Expand All @@ -1728,7 +1704,7 @@ void extend(raft::resources const& handle,
: std::variant<IdxT, const IdxT*>(IdxT(idx_batch.offset())),
new_data_labels.data() + vec_batch.offset(),
IdxT(vec_batch.size()),
batches_mr);
device_memory);
}
}

Expand Down Expand Up @@ -1758,7 +1734,6 @@ auto build(raft::resources const& handle,
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_pq::build(%zu, %u)", size_t(n_rows), dim);
resource::detail::warn_non_pool_workspace(handle, "raft::ivf_pq::build");
static_assert(std::is_same_v<T, float> || std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
"Unsupported data type");

Expand All @@ -1782,21 +1757,10 @@ auto build(raft::resources const& handle,

auto* device_memory = resource::get_workspace_resource(handle);
rmm::mr::managed_memory_resource managed_memory_upstream;
rmm::mr::pool_memory_resource<rmm::mr::managed_memory_resource> managed_memory(
&managed_memory_upstream, 1024 * 1024);

// If the trainset is small enough to comfortably fit into device memory, put it there.
// Otherwise, use the managed memory.
constexpr size_t kTolerableRatio = 4;
rmm::mr::device_memory_resource* big_memory_resource = &managed_memory;
if (sizeof(float) * n_rows_train * index.dim() * kTolerableRatio <
resource::get_workspace_free_bytes(handle)) {
big_memory_resource = device_memory;
}

// Besides just sampling, we transform the input dataset into floats to make it easier
// to use gemm operations from cublas.
rmm::device_uvector<float> trainset(n_rows_train * index.dim(), stream, big_memory_resource);
rmm::device_uvector<float> trainset(n_rows_train * index.dim(), stream, device_memory);
// TODO: a proper sampling
if constexpr (std::is_same_v<T, float>) {
RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(),
Expand Down Expand Up @@ -1865,7 +1829,7 @@ auto build(raft::resources const& handle,
handle, kmeans_params, trainset_const_view, centers_view, utils::mapping<float>{});

// Trainset labels are needed for training PQ codebooks
rmm::device_uvector<uint32_t> labels(n_rows_train, stream, big_memory_resource);
rmm::device_uvector<uint32_t> labels(n_rows_train, stream, device_memory);
auto centers_const_view = raft::make_device_matrix_view<const float, IdxT>(
cluster_centers, index.n_lists(), index.dim());
auto labels_view = raft::make_device_vector_view<uint32_t, IdxT>(labels.data(), n_rows_train);
Expand Down Expand Up @@ -1894,7 +1858,7 @@ auto build(raft::resources const& handle,
trainset.data(),
labels.data(),
params.kmeans_n_iters,
&managed_memory);
&managed_memory_upstream);
break;
case codebook_gen::PER_CLUSTER:
train_per_cluster(handle,
Expand All @@ -1903,7 +1867,7 @@ auto build(raft::resources const& handle,
trainset.data(),
labels.data(),
params.kmeans_n_iters,
&managed_memory);
&managed_memory_upstream);
break;
default: RAFT_FAIL("Unreachable code");
}
Expand Down
2 changes: 0 additions & 2 deletions cpp/include/raft/neighbors/detail/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include <raft/core/logger.hpp>
#include <raft/core/nvtx.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/detail/device_memory_resource.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
Expand Down Expand Up @@ -747,7 +746,6 @@ inline void search(raft::resources const& handle,
params.n_probes,
k,
index.dim());
resource::detail::warn_non_pool_workspace(handle, "raft::ivf_pq::search");

RAFT_EXPECTS(
params.internal_distance_dtype == CUDA_R_16F || params.internal_distance_dtype == CUDA_R_32F,
Expand Down
Loading

0 comments on commit addb059

Please sign in to comment.