ivf_flat::index: hide implementation details (#747)
Hide the mutable `mdarray` members of `ivf_flat::index` behind immutable `mdspan` views.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #747
achirkin authored Aug 24, 2022
1 parent 85bbbab commit ab9a695
Showing 8 changed files with 347 additions and 218 deletions.
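The core of the change is the accessor pattern sketched below. This is a simplified, hypothetical illustration, not the actual RAFT declarations: `std::vector` and a raw `const` pointer stand in for the `mdarray` storage and the `mdspan` views, and none of the names are real `ivf_flat::index` members.

```cuda
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the pattern applied to ivf_flat::index in this commit:
// the owning, mutable storage stays private (as the mdarray members now do), and
// callers only ever receive read-only views (as the mdspan accessors now provide).
template <typename T>
class example_index {
 public:
  explicit example_index(std::size_t n) : data_(n) {}

  // Read-only view: external code can inspect the data but cannot resize, reallocate,
  // or otherwise break the index's internal invariants.
  const T* data_view() const noexcept { return data_.data(); }
  std::size_t size() const noexcept { return data_.size(); }

 private:
  std::vector<T> data_;  // owning, mutable storage; modified only by the index itself
};
```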
13 changes: 10 additions & 3 deletions cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
@@ -131,9 +131,16 @@ constexpr auto calc_minibatch_size(uint32_t n_clusters, size_t n_rows) -> uint32
/**
* @brief Given the data and labels, calculate cluster centers and sizes in one sweep.
*
* Let S_i = {x_k | x_k \in dataset & labels[k] == i} be the vectors in the dataset with label i.
* On exit centers_i = normalize(\sum_{x \in S_i} x), where `normalize` depends on the distance
* type.
* Let `S_i = {x_k | x_k \in dataset & labels[k] == i}` be the vectors in the dataset with label i.
*
* On exit,
* `centers_i = (\sum_{x \in S_i} x + w_i * center_i) / (|S_i| + w_i)`,
* where `w_i = reset_counters ? 0 : cluster_size[i]`.
*
* In other words, the updated cluster centers are a weighted average of the existing cluster
* center, and the coordinates of the points labeled with i. _This allows calling this function
* multiple times with different datasets with the same effect as if calling this function once
* on the combined dataset_.
*
* NB: `centers` and `cluster_sizes` must be accessible on GPU due to
* divide_along_rows/normalize_rows. The rest can be both, under assumption that all pointers are
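For intuition, a small worked instance of the update rule above (the numbers are purely illustrative): with `reset_counters == false`, suppose cluster `i` previously held 3 points averaging to `centers_i = (1, 1)` (so `w_i = cluster_size[i] = 3`), and the new batch contributes a single point `(5, 5)` labeled `i`. Then

`centers_i = ((5,5) + 3 * (1,1)) / (1 + 3) = (8,8) / 4 = (2,2)`,

which is exactly the mean of all four points, i.e. the same result as a single call on the combined dataset.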
18 changes: 9 additions & 9 deletions cpp/include/raft/spatial/knn/detail/ann_utils.cuh
@@ -368,33 +368,33 @@ __global__ void map_along_rows_kernel(
}

/**
* @brief Divide matrix values along rows by an integer value, skipping rows if the corresponding
* divisor is zero.
* @brief Map a binary function over a matrix and a vector element-wise, broadcasting the vector
* values along rows: `m[i, j] = op(m[i,j], v[i])`
*
* NB: device-only function
*
* @tparam Lambda
*
* @param n_rows
* @param n_cols
* @param[inout] a device pointer to a row-major matrix [n_rows, n_cols]
* @param[in] d device pointer to a vector [n_rows]
* @param map the binary operation to apply on every element of matrix rows and of the vector
* @param[inout] m device pointer to a row-major matrix [n_rows, n_cols]
* @param[in] v device pointer to a vector [n_rows]
* @param op the binary operation to apply on every element of matrix rows and of the vector
*/
template <typename Lambda>
inline void map_along_rows(uint32_t n_rows,
uint32_t n_cols,
float* a,
const uint32_t* d,
Lambda map,
float* m,
const uint32_t* v,
Lambda op,
rmm::cuda_stream_view stream)
{
dim3 threads(128, 1, 1);
dim3 blocks(
ceildiv<uint64_t>(static_cast<uint64_t>(n_rows) * static_cast<uint64_t>(n_cols), threads.x),
1,
1);
map_along_rows_kernel<<<blocks, threads, 0, stream>>>(n_rows, n_cols, a, d, map);
map_along_rows_kernel<<<blocks, threads, 0, stream>>>(n_rows, n_cols, m, v, op);
}

template <typename T>
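Based on the generalized signature above, a minimal call-site sketch (illustrative only, not part of this commit): it reproduces the behaviour the old doc comment described, dividing each row by the corresponding cluster size while skipping empty clusters. It assumes the caller sits inside the same `detail` code, so that `utils::map_along_rows` and the rmm headers are visible, and that nvcc is invoked with extended device lambdas enabled.

```cuda
// Illustrative sketch only. `centers` is a device, row-major [n_lists, dim] float matrix;
// `sizes` is a device vector of per-cluster counts [n_lists].
inline void divide_rows_by_cluster_size(uint32_t n_lists,
                                        uint32_t dim,
                                        float* centers,
                                        const uint32_t* sizes,
                                        rmm::cuda_stream_view stream)
{
  // m[i, j] = op(m[i, j], v[i]): divide each row by its cluster size, leaving empty
  // clusters (size == 0) untouched.
  utils::map_along_rows(
    n_lists,
    dim,
    centers,
    sizes,
    [] __device__(float mij, uint32_t si) { return si == 0 ? mij : mij / si; },
    stream);
}
```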
146 changes: 48 additions & 98 deletions cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh
@@ -108,39 +108,38 @@ inline auto extend(const handle_t& handle,
const index<T, IdxT>& orig_index,
const T* new_vectors,
const IdxT* new_indices,
IdxT n_rows,
rmm::cuda_stream_view stream) -> index<T, IdxT>
IdxT n_rows) -> index<T, IdxT>
{
auto n_lists = orig_index.n_lists;
auto dim = orig_index.dim;
auto stream = handle.get_stream();
auto n_lists = orig_index.n_lists();
auto dim = orig_index.dim();
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_flat::extend(%zu, %u)", size_t(n_rows), dim);

RAFT_EXPECTS(new_indices != nullptr || orig_index.size == 0,
RAFT_EXPECTS(new_indices != nullptr || orig_index.size() == 0,
"You must pass data indices when the index is non-empty.");

rmm::device_uvector<uint32_t> new_labels(n_rows, stream);
kmeans::predict(handle,
orig_index.centers.data(),
orig_index.centers().data_handle(),
n_lists,
dim,
new_vectors,
n_rows,
new_labels.data(),
orig_index.metric,
orig_index.metric(),
stream);

auto&& list_sizes = rmm::device_uvector<uint32_t>(n_lists, stream);
auto&& list_offsets = rmm::device_uvector<IdxT>(n_lists + 1, stream);
auto list_sizes_ptr = list_sizes.data();
auto list_offsets_ptr = list_offsets.data();
index<T, IdxT> ext_index(handle, orig_index.metric(), n_lists, dim);

auto&& centers = rmm::device_uvector<float>(size_t(n_lists) * size_t(dim), stream);
auto centers_ptr = centers.data();
auto list_sizes_ptr = ext_index.list_sizes().data_handle();
auto list_offsets_ptr = ext_index.list_offsets().data_handle();
auto centers_ptr = ext_index.centers().data_handle();

// Calculate the centers and sizes on the new data, starting from the original values
raft::copy(centers_ptr, orig_index.centers.data(), centers.size(), stream);
raft::copy(list_sizes_ptr, orig_index.list_sizes.data(), list_sizes.size(), stream);
raft::copy(centers_ptr, orig_index.centers().data_handle(), ext_index.centers().size(), stream);
raft::copy(
list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream);

kmeans::calc_centers_and_sizes(centers_ptr,
list_sizes_ptr,
@@ -160,146 +159,97 @@ inline auto extend(const handle_t& handle,
list_sizes_ptr,
list_sizes_ptr + n_lists,
list_offsets_ptr + 1,
[] __device__(IdxT s, uint32_t l) { return s + Pow2<WarpSize>::roundUp(l); });
[] __device__(IdxT s, uint32_t l) { return s + Pow2<kIndexGroupSize>::roundUp(l); });
update_host(&index_size, list_offsets_ptr + n_lists, 1, stream);
handle.sync_stream(stream);

auto&& data = rmm::device_uvector<T>(index_size * IdxT(dim), stream);
auto&& indices = rmm::device_uvector<IdxT>(index_size, stream);
ext_index.allocate(
handle, index_size, ext_index.metric() == raft::distance::DistanceType::L2Expanded);

// Populate index with the old data
if (orig_index.size > 0) {
utils::block_copy(orig_index.list_offsets.data(),
if (orig_index.size() > 0) {
utils::block_copy(orig_index.list_offsets().data_handle(),
list_offsets_ptr,
IdxT(n_lists),
orig_index.data.data(),
data.data(),
orig_index.data().data_handle(),
ext_index.data().data_handle(),
IdxT(dim),
stream);

utils::block_copy(orig_index.list_offsets.data(),
utils::block_copy(orig_index.list_offsets().data_handle(),
list_offsets_ptr,
IdxT(n_lists),
orig_index.indices.data(),
indices.data(),
orig_index.indices().data_handle(),
ext_index.indices().data_handle(),
IdxT(1),
stream);
}

// Copy the old sizes, so we can start from the current state of the index;
// we'll rebuild the `list_sizes_ptr` in the following kernel, using it as an atomic counter.
raft::copy(list_sizes_ptr, orig_index.list_sizes.data(), list_sizes.size(), stream);
raft::copy(
list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream);

const dim3 block_dim(256);
const dim3 grid_dim(raft::ceildiv<IdxT>(n_rows, block_dim.x));
build_index_kernel<<<grid_dim, block_dim, 0, stream>>>(new_labels.data(),
list_offsets_ptr,
new_vectors,
new_indices,
data.data(),
indices.data(),
ext_index.data().data_handle(),
ext_index.indices().data_handle(),
list_sizes_ptr,
n_rows,
dim,
orig_index.veclen);
ext_index.veclen());
RAFT_CUDA_TRY(cudaPeekAtLastError());

// Precompute the centers vector norms for L2Expanded distance
auto compute_norms = [&]() {
auto&& r = rmm::device_uvector<float>(n_lists, stream);
utils::dots_along_rows(n_lists, dim, centers.data(), r.data(), stream);
RAFT_LOG_TRACE_VEC(r.data(), 20);
return std::move(r);
};
auto&& center_norms = orig_index.metric == raft::distance::DistanceType::L2Expanded
? std::optional(compute_norms())
: std::nullopt;
if (ext_index.center_norms().has_value()) {
utils::dots_along_rows(n_lists,
dim,
ext_index.centers().data_handle(),
ext_index.center_norms()->data_handle(),
stream);
RAFT_LOG_TRACE_VEC(ext_index.center_norms()->data_handle(), std::min<uint32_t>(dim, 20));
}

// assemble the index
index<T, IdxT> new_index{{},
orig_index.veclen,
orig_index.metric,
index_size,
orig_index.dim,
orig_index.n_lists,
std::move(data),
std::move(indices),
std::move(list_sizes),
std::move(list_offsets),
std::move(centers),
std::move(center_norms)};

// check index invariants
new_index.check_consistency();

return new_index;
return ext_index;
}

/** See raft::spatial::knn::ivf_flat::build docs */
template <typename T, typename IdxT>
inline auto build(const handle_t& handle,
const index_params& params,
const T* dataset,
IdxT n_rows,
uint32_t dim,
rmm::cuda_stream_view stream) -> index<T, IdxT>
inline auto build(
const handle_t& handle, const index_params& params, const T* dataset, IdxT n_rows, uint32_t dim)
-> index<T, IdxT>
{
auto stream = handle.get_stream();
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_flat::build(%zu, %u)", size_t(n_rows), dim);
static_assert(std::is_same_v<T, float> || std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>,
"unsupported data type");
RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset");

// TODO: consider padding the dimensions and fixing veclen to its maximum possible value as a
// template parameter (https://github.com/rapidsai/raft/issues/711)
uint32_t veclen = 16 / sizeof(T);
while (dim % veclen != 0) {
veclen = veclen >> 1;
}
auto n_lists = static_cast<uint32_t>(params.n_lists);

// kmeans cluster ids for the dataset
auto&& centers = rmm::device_uvector<float>(size_t(n_lists) * size_t(dim), stream);
index<T, IdxT> index(handle, params, dim);
utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream);
utils::memzero(index.list_offsets().data_handle(), index.list_offsets().size(), stream);

// Predict labels of the whole dataset
kmeans::build_optimized_kmeans(handle,
params.kmeans_n_iters,
dim,
dataset,
n_rows,
centers.data(),
n_lists,
index.centers().data_handle(),
params.n_lists,
params.kmeans_trainset_fraction,
params.metric,
stream);

auto&& data = rmm::device_uvector<T>(0, stream);
auto&& indices = rmm::device_uvector<IdxT>(0, stream);
auto&& list_sizes = rmm::device_uvector<uint32_t>(n_lists, stream);
auto&& list_offsets = rmm::device_uvector<IdxT>(n_lists + 1, stream);
utils::memzero(list_sizes.data(), list_sizes.size(), stream);
utils::memzero(list_offsets.data(), list_offsets.size(), stream);

// assemble the index
index<T, IdxT> index{{},
veclen,
params.metric,
IdxT(0),
dim,
n_lists,
std::move(data),
std::move(indices),
std::move(list_sizes),
std::move(list_offsets),
std::move(centers),
std::nullopt};

// check index invariants
index.check_consistency();

// add the data if necessary
if (params.add_data_on_build) {
return extend<T, IdxT>(handle, index, dataset, nullptr, n_rows, stream);
return detail::extend<T, IdxT>(handle, index, dataset, nullptr, n_rows);
} else {
return index;
}
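Taken together, the detail-level entry points no longer take an explicit `rmm::cuda_stream_view` (the stream now comes from the handle), and they no longer assemble the index out of loose `device_uvector`s: the `index` object owns its storage and only exposes views. A hedged sketch of the resulting call flow, with illustrative types and sizes and the namespaces assumed from the diff above:

```cuda
// Illustrative only -- mirrors the new signatures shown in this diff.
raft::spatial::knn::ivf_flat::index_params params;
params.n_lists           = 1024;
params.add_data_on_build = true;  // build() then forwards to detail::extend() internally

// Cluster the dataset, allocate the index storage and (optionally) add the data in one go.
auto idx = detail::build<float, int64_t>(handle, params, dataset, n_rows, dim);

// Append further batches later; the stream is taken from `handle`, not passed explicitly.
auto idx2 = detail::extend<float, int64_t>(handle, idx, new_vectors, new_indices, n_new_rows);
```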
