Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removing code that explicitly compares equality of rmm memory resources #2047

Merged
merged 14 commits into from
Dec 9, 2023
Prev Previous commit
Next Next commit
Even more fixes...
  • Loading branch information
cjnolet committed Dec 7, 2023
commit 6b583244d35734ecd35cf93d8c3624be6131b5c4
22 changes: 7 additions & 15 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1582,11 +1582,9 @@ void extend(raft::resources const& handle,
size_t free_mem, total_mem;
RAFT_CUDA_TRY(cudaMemGetInfo(&free_mem, &total_mem));

rmm::mr::device_memory_resource* labels_mr = device_memory;
rmm::mr::device_memory_resource* batches_mr = device_memory;
// Allocate a buffer for the new labels (classifying the new data)
rmm::device_uvector<uint32_t> new_data_labels(n_rows, stream, labels_mr);
if (labels_mr == device_memory) { free_mem -= sizeof(uint32_t) * n_rows; }
rmm::device_uvector<uint32_t> new_data_labels(n_rows, stream, device_memory);
free_mem -= sizeof(uint32_t) * n_rows;

// Calculate the batch size for the input data if it's not accessible directly from the device
constexpr size_t kReasonableMaxBatchSize = 65536;
Expand Down Expand Up @@ -1615,19 +1613,13 @@ void extend(raft::resources const& handle,
while (size_factor * max_batch_size > free_mem && max_batch_size > 128) {
max_batch_size >>= 1;
}
if (size_factor * max_batch_size > free_mem) {
// if that still doesn't fit, resort to the UVM
batches_mr = device_memory;
max_batch_size = kReasonableMaxBatchSize;
} else {
// If we're keeping the batches in device memory, update the available mem tracker.
free_mem -= size_factor * max_batch_size;
}
// If we're keeping the batches in device memory, update the available mem tracker.
free_mem -= size_factor * max_batch_size;
}

// Predict the cluster labels for the new data, in batches if necessary
utils::batch_load_iterator<T> vec_batches(
new_vectors, n_rows, index->dim(), max_batch_size, stream, batches_mr);
new_vectors, n_rows, index->dim(), max_batch_size, stream, device_memory);
// Release the placeholder memory, because we don't intend to allocate any more long-living
// temporary buffers before we allocate the index data.
// This memory could potentially speed up UVM accesses, if any.
Expand Down Expand Up @@ -1700,7 +1692,7 @@ void extend(raft::resources const& handle,
// By this point, the index state is updated and valid except it doesn't contain the new data
// Fill the extended index with the new data (possibly, in batches)
utils::batch_load_iterator<IdxT> idx_batches(
new_indices, n_rows, 1, max_batch_size, stream, batches_mr);
new_indices, n_rows, 1, max_batch_size, stream, device_memory);
for (const auto& vec_batch : vec_batches) {
const auto& idx_batch = *idx_batches++;
process_and_fill_codes(handle,
Expand All @@ -1711,7 +1703,7 @@ void extend(raft::resources const& handle,
: std::variant<IdxT, const IdxT*>(IdxT(idx_batch.offset())),
new_data_labels.data() + vec_batch.offset(),
IdxT(vec_batch.size()),
batches_mr);
device_memory);
}
}

Expand Down
Loading