Removing code that explicitly compares equality of rmm memory resources #2047

Merged: 14 commits merged on Dec 9, 2023
Removing internal managed memory and pool memory resources - we should only be using workspace resource and current device resource internally to the code
cjnolet committed Dec 7, 2023
commit e0e4a306d7d454b80cf3fe251fa273015bc46ac5
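The change applies one pattern throughout: instead of constructing rmm::mr::managed_memory_resource or pool resources locally inside each function, temporary buffers are allocated from the workspace resource already attached to the raft::resources handle. Below is a minimal before/after sketch of that pattern, not code from this PR; the function name and buffer size are illustrative, and it assumes only the raft/rmm calls visible in the diff (raft::resource::get_workspace_resource, rmm::device_uvector).

#include <cstddef>

#include <raft/core/resources.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>

#include <rmm/device_uvector.hpp>

void scratch_buffer_example(raft::resources const& handle, std::size_t n_rows)
{
  auto stream = raft::resource::get_cuda_stream(handle);

  // Before (the pattern this commit removes): build a managed/pool resource on the spot.
  //   rmm::mr::managed_memory_resource managed_memory;
  //   rmm::device_uvector<float> buf(n_rows, stream, &managed_memory);

  // After: route temporary allocations through the handle's workspace resource.
  rmm::mr::device_memory_resource* device_memory = raft::resource::get_workspace_resource(handle);
  rmm::device_uvector<float> buf(n_rows, stream, device_memory);

  // ... fill and use buf on `stream`; it is returned to the workspace resource on scope exit ...
}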
cpp/include/raft/cluster/detail/kmeans_balanced.cuh (15 changes: 4 additions & 11 deletions)
@@ -44,11 +44,11 @@
#include <raft/util/device_atomics.cuh>
#include <raft/util/integer_utils.hpp>

+ #include <raft/core/resource/device_memory_resource.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_scalar.hpp>
#include <rmm/device_vector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
- #include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

#include <thrust/gather.h>
@@ -970,16 +970,9 @@ void build_hierarchical(const raft::resources& handle,
IdxT n_mesoclusters = std::min(n_clusters, static_cast<IdxT>(std::sqrt(n_clusters) + 0.5));
RAFT_LOG_DEBUG("build_hierarchical: n_mesoclusters: %u", n_mesoclusters);

- rmm::mr::managed_memory_resource managed_memory;
rmm::mr::device_memory_resource* device_memory = resource::get_workspace_resource(handle);
auto [max_minibatch_size, mem_per_row] =
calc_minibatch_size<MathT>(n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);
- auto pool_guard =
-   raft::get_pool_memory_resource(device_memory, mem_per_row * size_t(max_minibatch_size));
- if (pool_guard) {
-   RAFT_LOG_DEBUG("build_hierarchical: using pool memory resource with initial size %zu bytes",
-                  mem_per_row * size_t(max_minibatch_size));
- }

// Precompute the L2 norm of the dataset if relevant.
const MathT* dataset_norm = nullptr;
@@ -1006,8 +999,8 @@
CounterT;

// build coarse clusters (mesoclusters)
- rmm::device_uvector<LabelT> mesocluster_labels_buf(n_rows, stream, &managed_memory);
- rmm::device_uvector<CounterT> mesocluster_sizes_buf(n_mesoclusters, stream, &managed_memory);
+ rmm::device_uvector<LabelT> mesocluster_labels_buf(n_rows, stream, device_memory);
+ rmm::device_uvector<CounterT> mesocluster_sizes_buf(n_mesoclusters, stream, device_memory);
{
rmm::device_uvector<MathT> mesocluster_centers_buf(n_mesoclusters * dim, stream, device_memory);
build_clusters(handle,
@@ -1063,7 +1056,7 @@ void build_hierarchical(const raft::resources& handle,
fine_clusters_nums_max,
cluster_centers,
mapping_op,
- &managed_memory,
+ device_memory,
device_memory);
RAFT_EXPECTS(n_clusters_done == n_clusters, "Didn't process all clusters.");

cpp/include/raft/neighbors/detail/ivf_pq_build.cuh (38 changes: 6 additions & 32 deletions)
@@ -46,11 +46,9 @@
#include <raft/util/pow2_utils.cuh>
#include <raft/util/vectorized.cuh>

+ #include <raft/core/resource/device_memory_resource.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
- #include <rmm/mr/device/managed_memory_resource.hpp>
- #include <rmm/mr/device/per_device_resource.hpp>
- #include <rmm/mr/device/pool_memory_resource.hpp>

#include <thrust/extrema.h>
#include <thrust/scan.h>
@@ -1566,13 +1564,7 @@ void extend(raft::resources const& handle,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
"Unsupported data type");

- rmm::mr::device_memory_resource* device_memory = nullptr;
- auto pool_guard = raft::get_pool_memory_resource(device_memory, 1024 * 1024);
- if (pool_guard) { RAFT_LOG_DEBUG("ivf_pq::extend: using pool memory resource"); }
-
- rmm::mr::managed_memory_resource managed_memory_upstream;
- rmm::mr::pool_memory_resource<rmm::mr::managed_memory_resource> managed_memory(
-   &managed_memory_upstream, 1024 * 1024);
+ rmm::mr::device_memory_resource* device_memory = raft::resource::get_workspace_resource(handle);

// The spec defines how the clusters look like
auto spec = list_spec<uint32_t, IdxT>{
@@ -1590,14 +1582,8 @@
size_t free_mem, total_mem;
RAFT_CUDA_TRY(cudaMemGetInfo(&free_mem, &total_mem));

- // Decide on an approximate threshold when we'd better start saving device memory by using
- // managed allocations for large device buffers
rmm::mr::device_memory_resource* labels_mr = device_memory;
rmm::mr::device_memory_resource* batches_mr = device_memory;
- if (n_rows * (index->dim() * sizeof(T) + index->pq_dim() + sizeof(IdxT) + sizeof(uint32_t)) >
-     free_mem) {
-   labels_mr = &managed_memory;
- }
// Allocate a buffer for the new labels (classifying the new data)
rmm::device_uvector<uint32_t> new_data_labels(n_rows, stream, labels_mr);
if (labels_mr == device_memory) { free_mem -= sizeof(uint32_t) * n_rows; }
@@ -1631,7 +1617,7 @@
}
if (size_factor * max_batch_size > free_mem) {
// if that still doesn't fit, resort to the UVM
- batches_mr = &managed_memory;
+ batches_mr = device_memory;
Member (review comment on the added line above; see the sketch after this hunk):
This if-else now feels a little redundant, doesn't it?

max_batch_size = kReasonableMaxBatchSize;
} else {
// If we're keeping the batches in device memory, update the available mem tracker.
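For context on the review comment above: batches_mr is already initialized to device_memory a few lines earlier, so the assignment inside the branch is now a no-op. A sketch of the simplification the comment seems to point at, not code from the PR (it assumes the else branch only updates the free-memory tracker, as its comment indicates):

if (size_factor * max_batch_size > free_mem) {
  // The batches no longer fit in free device memory: cap the batch size
  // rather than falling back to a different (managed) memory resource.
  max_batch_size = kReasonableMaxBatchSize;
} else {
  // Batches stay in device memory; update the available-memory tracker.
  free_mem -= size_factor * max_batch_size;
}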
@@ -1777,22 +1763,10 @@ auto build(raft::resources const& handle,
size_t n_rows_train = n_rows / trainset_ratio;

auto* device_memory = resource::get_workspace_resource(handle);
- rmm::mr::managed_memory_resource managed_memory_upstream;
- rmm::mr::pool_memory_resource<rmm::mr::managed_memory_resource> managed_memory(
-   &managed_memory_upstream, 1024 * 1024);
-
- // If the trainset is small enough to comfortably fit into device memory, put it there.
- // Otherwise, use the managed memory.
- constexpr size_t kTolerableRatio = 4;
- rmm::mr::device_memory_resource* big_memory_resource = &managed_memory;
- if (sizeof(float) * n_rows_train * index.dim() * kTolerableRatio <
-     resource::get_workspace_free_bytes(handle)) {
-   big_memory_resource = device_memory;
- }

// Besides just sampling, we transform the input dataset into floats to make it easier
// to use gemm operations from cublas.
- rmm::device_uvector<float> trainset(n_rows_train * index.dim(), stream, big_memory_resource);
+ rmm::device_uvector<float> trainset(n_rows_train * index.dim(), stream, device_memory);
// TODO: a proper sampling
if constexpr (std::is_same_v<T, float>) {
RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(),
@@ -1890,7 +1864,7 @@ auto build(raft::resources const& handle,
trainset.data(),
labels.data(),
params.kmeans_n_iters,
- &managed_memory);
+ device_memory);
break;
case codebook_gen::PER_CLUSTER:
train_per_cluster(handle,
@@ -1899,7 +1873,7 @@
trainset.data(),
labels.data(),
params.kmeans_n_iters,
&managed_memory);
device_memory);
break;
default: RAFT_FAIL("Unreachable code");
}