Skip to content

Commit

Permalink
matrix::select_k: move selection and warp-sort primitives (#1085)
Browse files Browse the repository at this point in the history
Refactor and move a set of implementations for batch-selecting top K largest/smallest values:

  - Move device warp-wide primitives `bitonic_sort.cuh` to the public `raft::util` namespace, add tests.
  - Create a new public `matrix::select_k` interface.
  - Deprecate the legacy public `raft::spatial::knn::select_k` interface.
  - Copy/adapt `select_k` tests.
  - Move/adapt `select_k` benchmarks.
  - Rework the internals of `select_warpsort.cuh` to enable more implementations.

Closes #853

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1085
  • Loading branch information
achirkin authored Jan 23, 2023
1 parent 5a6cb09 commit 0076101
Show file tree
Hide file tree
Showing 21 changed files with 1,631 additions and 476 deletions.
6 changes: 4 additions & 2 deletions cpp/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ if(BUILD_BENCH)
bench/main.cpp
)

ConfigureBench(NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/matrix/gather.cu bench/main.cpp)
ConfigureBench(
NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/matrix/gather.cu bench/matrix/select_k.cu
bench/main.cpp
)

ConfigureBench(
NAME RANDOM_BENCH PATH bench/random/make_blobs.cu bench/random/permute.cu bench/random/rng.cu
Expand All @@ -127,7 +130,6 @@ if(BUILD_BENCH)
bench/neighbors/knn/ivf_pq_int8_t_int64_t.cu
bench/neighbors/knn/ivf_pq_uint8_t_uint32_t.cu
bench/neighbors/refine.cu
bench/neighbors/selection.cu
bench/main.cpp
OPTIONAL
DIST
Expand Down
133 changes: 133 additions & 0 deletions cpp/bench/matrix/select_k.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/**
* TODO: reconsider how to organize shared test+bench files better
* Related Issue: https://github.com/rapidsai/raft/issues/1153
* (although this header does not depend on any gtest headers)
*/
#include "../../test/matrix/select_k.cuh"

#include <common/benchmark.hpp>

#include <raft/core/handle.hpp>
#include <raft/random/rng.cuh>
#include <raft/sparse/detail/utils.h>
#include <raft/util/cudart_utils.hpp>

#include <raft/matrix/detail/select_radix.cuh>
#include <raft/matrix/detail/select_warpsort.cuh>
#include <raft/matrix/select_k.cuh>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

namespace raft::matrix {

using namespace raft::bench; // NOLINT

template <typename KeyT, typename IdxT, select::Algo Algo>
struct selection : public fixture {
explicit selection(const select::params& p)
: params_(p),
in_dists_(p.batch_size * p.len, stream),
in_ids_(p.batch_size * p.len, stream),
out_dists_(p.batch_size * p.k, stream),
out_ids_(p.batch_size * p.k, stream)
{
raft::sparse::iota_fill(in_ids_.data(), IdxT(p.batch_size), IdxT(p.len), stream);
raft::random::RngState state{42};
raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0));
}

void run_benchmark(::benchmark::State& state) override // NOLINT
{
handle_t handle{stream};
using_pool_memory_res res;
try {
std::ostringstream label_stream;
label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k;
state.SetLabel(label_stream.str());
loop_on_state(state, [this, &handle]() {
select::select_k_impl<KeyT, IdxT>(handle,
Algo,
in_dists_.data(),
in_ids_.data(),
params_.batch_size,
params_.len,
params_.k,
out_dists_.data(),
out_ids_.data(),
params_.select_min);
});
} catch (raft::exception& e) {
state.SkipWithError(e.what());
}
}

private:
const select::params params_;
rmm::device_uvector<KeyT> in_dists_, out_dists_;
rmm::device_uvector<IdxT> in_ids_, out_ids_;
};

const std::vector<select::params> kInputs{
{20000, 500, 1, true}, {20000, 500, 2, true}, {20000, 500, 4, true},
{20000, 500, 8, true}, {20000, 500, 16, true}, {20000, 500, 32, true},
{20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true},

{1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true},
{1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true},
{1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true},

{100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true},
{100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true},
{100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true},

{10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true},
{10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true},
{10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true},
};

#define SELECTION_REGISTER(KeyT, IdxT, A) \
namespace BENCHMARK_PRIVATE_NAME(selection) \
{ \
using SelectK = selection<KeyT, IdxT, select::Algo::A>; \
RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \
}

SELECTION_REGISTER(float, int, kPublicApi); // NOLINT
SELECTION_REGISTER(float, int, kRadix8bits); // NOLINT
SELECTION_REGISTER(float, int, kRadix11bits); // NOLINT
SELECTION_REGISTER(float, int, kWarpAuto); // NOLINT
SELECTION_REGISTER(float, int, kWarpImmediate); // NOLINT
SELECTION_REGISTER(float, int, kWarpFiltered); // NOLINT
SELECTION_REGISTER(float, int, kWarpDistributed); // NOLINT
SELECTION_REGISTER(float, int, kWarpDistributedShm); // NOLINT

SELECTION_REGISTER(double, int, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, int, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, int, kWarpAuto); // NOLINT

SELECTION_REGISTER(double, size_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, size_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, size_t, kWarpImmediate); // NOLINT
SELECTION_REGISTER(double, size_t, kWarpFiltered); // NOLINT
SELECTION_REGISTER(double, size_t, kWarpDistributed); // NOLINT
SELECTION_REGISTER(double, size_t, kWarpDistributedShm); // NOLINT

} // namespace raft::matrix
123 changes: 0 additions & 123 deletions cpp/bench/neighbors/selection.cu

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,76 +16,76 @@

#pragma once

#include "topk/radix_topk.cuh"
#include "topk/warpsort_topk.cuh"
#include "select_radix.cuh"
#include "select_warpsort.cuh"

#include <raft/core/nvtx.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>

namespace raft::spatial::knn::detail {
namespace raft::matrix::detail {

/**
* Select k smallest or largest key/values from each row in the input data.
*
* If you think of the input data `in_keys` as a row-major matrix with len columns and
* batch_size rows, then this function selects k smallest/largest values in each row and fills
* in the row-major matrix `out` of size (batch_size, k).
* If you think of the input data `in_val` as a row-major matrix with `len` columns and
* `batch_size` rows, then this function selects `k` smallest/largest values in each row and fills
* in the row-major matrix `out_val` of size (batch_size, k).
*
* @tparam T
* the type of the keys (what is being compared).
* @tparam IdxT
* the index type (what is being selected together with the keys).
*
* @param[in] in
* @param[in] in_val
* contiguous device array of inputs of size (len * batch_size);
* these are compared and selected.
* @param[in] in_idx
* contiguous device array of inputs of size (len * batch_size);
* typically, these are indices of the corresponding in_keys.
* typically, these are indices of the corresponding in_val.
* @param batch_size
* number of input rows, i.e. the batch size.
* @param len
* length of a single input array (row); also sometimes referred as n_cols.
* Invariant: len >= k.
* @param k
* the number of outputs to select in each input row.
* @param[out] out
* @param[out] out_val
* contiguous device array of outputs of size (k * batch_size);
* the k smallest/largest values from each row of the `in_keys`.
* the k smallest/largest values from each row of the `in_val`.
* @param[out] out_idx
* contiguous device array of outputs of size (k * batch_size);
* the payload selected together with `out`.
* the payload selected together with `out_val`.
* @param select_min
* whether to select k smallest (true) or largest (false) keys.
* @param stream
* @param mr an optional memory resource to use across the calls (you can provide a large enough
* memory pool here to avoid memory allocations within the call).
*/
template <typename T, typename IdxT>
void select_topk(const T* in,
const IdxT* in_idx,
size_t batch_size,
size_t len,
int k,
T* out,
IdxT* out_idx,
bool select_min,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = nullptr)
void select_k(const T* in_val,
const IdxT* in_idx,
size_t batch_size,
size_t len,
int k,
T* out_val,
IdxT* out_idx,
bool select_min,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = nullptr)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"matrix::select_topk(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);
"matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);
// TODO (achirkin): investigate the trade-off for a wider variety of inputs.
const bool radix_faster = batch_size >= 64 && len >= 102400 && k >= 128;
if (k <= raft::spatial::knn::detail::topk::kMaxCapacity && !radix_faster) {
topk::warp_sort_topk<T, IdxT>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr);
if (k <= select::warpsort::kMaxCapacity && !radix_faster) {
select::warpsort::select_k<T, IdxT>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
} else {
topk::radix_topk<T, IdxT, (sizeof(T) >= 4 ? 11 : 8), 512>(
in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr);
select::radix::select_k<T, IdxT, (sizeof(T) >= 4 ? 11 : 8), 512>(
in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
}
}

} // namespace raft::spatial::knn::detail
} // namespace raft::matrix::detail
Loading

0 comments on commit 0076101

Please sign in to comment.