Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

matrix::select_k: move selection and warp-sort primitives #1085

Merged
merged 42 commits into from
Jan 23, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
39c10a9
Make warp-level bitonic sort public
achirkin Dec 9, 2022
6cda736
Move spatial::*::select_topk to matrix::select_k
achirkin Dec 9, 2022
c5631bf
Fix includes style
achirkin Dec 9, 2022
fb88433
Use cmake-format
achirkin Dec 9, 2022
f64325b
Refactored warpsort module and made tests for all implementations in …
achirkin Dec 13, 2022
20d01d7
Resort to UVM when radix buffers are too big
achirkin Dec 14, 2022
4813bae
Adjust the dummy_block_sort_t to the changes in the warpsort impl
achirkin Dec 14, 2022
6cdb79a
Fix incorrect include
achirkin Dec 14, 2022
870fc86
Add benchmarks
achirkin Dec 14, 2022
2af45bf
Update CMakeLists.txt style
achirkin Dec 14, 2022
5b336ee
Update CMakeLists.txt style
achirkin Dec 14, 2022
b3e5d9c
Add mdspanified interface
achirkin Dec 15, 2022
164157b
Remove benchmarks for the legacy interface
achirkin Dec 15, 2022
69c81dd
Remove a TODO comment about a seemingly resolved bug
achirkin Dec 15, 2022
d64b12b
Merge remote-tracking branch 'rapidsai/branch-23.02' into enh-matrix-…
achirkin Dec 15, 2022
9d4476a
Fix the changed include extension
achirkin Dec 15, 2022
3e40435
Fix includes in tests
achirkin Dec 16, 2022
e20578e
Merge remote-tracking branch 'rapidsai/branch-23.02' into enh-matrix-…
achirkin Dec 16, 2022
b2c79f5
Merge branch 'branch-23.02' into enh-matrix-topk
achirkin Dec 20, 2022
98e2c2a
Address comments: bitonic_sort
achirkin Dec 20, 2022
af4c146
Replace stream argument with handle_t
achirkin Dec 20, 2022
471828e
rename files to select.* -> select_k.*
achirkin Dec 20, 2022
f6ff223
Use raft macros
achirkin Dec 20, 2022
066208d
Try to pass null and non-null arguments to select_k
achirkin Dec 20, 2022
aeaa1ef
Remove raw-pointer api from the public namespace
achirkin Dec 20, 2022
685b6bf
Updates public docs (add example usage)
achirkin Dec 21, 2022
5c42209
Merge remote-tracking branch 'rapidsai/branch-23.02' into enh-matrix-…
achirkin Jan 9, 2023
2cea50d
Add device_mem_resource
achirkin Jan 9, 2023
a31e61e
Add Doxygen docs
achirkin Jan 10, 2023
a8c5a70
Merge remote-tracking branch 'rapidsai/branch-23.02' into enh-matrix-…
achirkin Jan 10, 2023
8a5978b
Revert the memory_resource param changes in the detail namespace to a…
achirkin Jan 10, 2023
8e58cab
Merge remote-tracking branch 'rapidsai/branch-23.02' into enh-matrix-…
achirkin Jan 11, 2023
a01a75f
Remove device_mem_resource
achirkin Jan 11, 2023
c6256b7
Merge branch 'branch-23.02' into enh-matrix-topk
achirkin Jan 16, 2023
c25e859
Merge branch 'branch-23.02' into enh-matrix-topk
cjnolet Jan 19, 2023
6e56106
Merge branch 'branch-23.02' into enh-matrix-topk
achirkin Jan 20, 2023
c78d9b0
Reference a TODO issue
achirkin Jan 20, 2023
a55a6cb
Merge branch 'enh-matrix-topk' of github.com:achirkin/raft into enh-m…
achirkin Jan 20, 2023
307b113
Deprecation notice
achirkin Jan 20, 2023
c0ce160
Add [in] annotation to all arguments
achirkin Jan 20, 2023
e2cc7ad
Merge branch 'branch-23.02' into enh-matrix-topk
achirkin Jan 23, 2023
dc3043c
Merge branch 'branch-23.02' into enh-matrix-topk
cjnolet Jan 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactored warpsort module and made tests for all implementations in detail namespace
  • Loading branch information
achirkin committed Dec 13, 2022
commit f64325bade684375bfdda483134b453b6cceed66
123 changes: 99 additions & 24 deletions cpp/include/raft/matrix/detail/select_warpsort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ class warp_sort {
/** The number of elements to select. */
const int k;

/** Extra memory required per-block for keeping the state (shared or global). */
constexpr static auto mem_required(uint32_t block_size) -> size_t { return 0; }

/**
* Construct the warp_sort empty queue.
*
Expand Down Expand Up @@ -269,8 +272,9 @@ class warp_sort_filtered : public warp_sort<Capacity, Ascending, T, IdxT> {
using warp_sort<Capacity, Ascending, T, IdxT>::kDummy;
using warp_sort<Capacity, Ascending, T, IdxT>::kWarpWidth;
using warp_sort<Capacity, Ascending, T, IdxT>::k;
using warp_sort<Capacity, Ascending, T, IdxT>::mem_required;

__device__ warp_sort_filtered(int k, T limit)
explicit __device__ warp_sort_filtered(int k, T limit = kDummy)
: warp_sort<Capacity, Ascending, T, IdxT>(k), buf_len_(0), k_th_(limit)
{
#pragma unroll
Expand All @@ -280,9 +284,9 @@ class warp_sort_filtered : public warp_sort<Capacity, Ascending, T, IdxT> {
}
}

__device__ __forceinline__ explicit warp_sort_filtered(int k)
: warp_sort_filtered<Capacity, Ascending, T, IdxT>(k, kDummy)
/**
 * Build a filtered warp-sort queue for block-wide use.
 * The unnamed buffer pointer is accepted (and ignored) so this factory has the
 * same signature as the implementations that do need extra per-block memory
 * (cf. warp_sort_distributed_ext::init_blockwide); for this class
 * mem_required() is zero, so no buffer is consumed.
 */
__device__ __forceinline__ static auto init_blockwide(int k, uint8_t* = nullptr, T limit = kDummy)
{
return warp_sort_filtered<Capacity, Ascending, T, IdxT>{k, limit};
}

__device__ void add(T val, IdxT idx)
Expand Down Expand Up @@ -367,8 +371,9 @@ class warp_sort_distributed : public warp_sort<Capacity, Ascending, T, IdxT> {
using warp_sort<Capacity, Ascending, T, IdxT>::kDummy;
using warp_sort<Capacity, Ascending, T, IdxT>::kWarpWidth;
using warp_sort<Capacity, Ascending, T, IdxT>::k;
using warp_sort<Capacity, Ascending, T, IdxT>::mem_required;

__device__ warp_sort_distributed(int k, T limit)
explicit __device__ warp_sort_distributed(int k, T limit = kDummy)
: warp_sort<Capacity, Ascending, T, IdxT>(k),
buf_val_(kDummy),
buf_idx_(IdxT{}),
Expand All @@ -377,9 +382,9 @@ class warp_sort_distributed : public warp_sort<Capacity, Ascending, T, IdxT> {
{
}

__device__ __forceinline__ explicit warp_sort_distributed(int k)
: warp_sort_distributed<Capacity, Ascending, T, IdxT>(k, kDummy)
__device__ __forceinline__ static auto init_blockwide(int k, uint8_t* = nullptr, T limit = kDummy)
{
return warp_sort_distributed<Capacity, Ascending, T, IdxT>{k, limit};
}

__device__ void add(T val, IdxT idx)
Expand Down Expand Up @@ -468,7 +473,12 @@ class warp_sort_distributed_ext : public warp_sort<Capacity, Ascending, T, IdxT>
using warp_sort<Capacity, Ascending, T, IdxT>::kWarpWidth;
using warp_sort<Capacity, Ascending, T, IdxT>::k;

__device__ warp_sort_distributed_ext(int k, T* val_buf, IdxT* idx_buf, T limit)
/**
 * Extra per-block memory (shared or global) this implementation needs:
 * one key (T) and one index (IdxT) per thread of the block, backing the
 * val_buf/idx_buf staging arrays carved out in init_blockwide.
 */
constexpr static auto mem_required(uint32_t block_size) -> size_t
{
return (sizeof(T) + sizeof(IdxT)) * block_size;
}

__device__ warp_sort_distributed_ext(int k, T* val_buf, IdxT* idx_buf, T limit = kDummy)
: warp_sort<Capacity, Ascending, T, IdxT>(k),
val_buf_(val_buf),
idx_buf_(idx_buf),
Expand All @@ -478,9 +488,21 @@ class warp_sort_distributed_ext : public warp_sort<Capacity, Ascending, T, IdxT>
val_buf_[laneId()] = kDummy;
}

__device__ __forceinline__ warp_sort_distributed_ext(int k, T* val_buf, IdxT* idx_buf)
: warp_sort_distributed_ext<Capacity, Ascending, T, IdxT>(k, val_buf, idx_buf, kDummy)
/**
 * Block-wide factory: carves per-warp key/index staging buffers out of `shmem`
 * (which must provide at least mem_required(blockDim.x) bytes; the kernel is
 * expected to pass suitably aligned dynamic shared memory) and constructs the
 * queue on top of them.
 */
__device__ static auto init_blockwide(int k, uint8_t* shmem, T limit = kDummy)
{
T* val_buf = nullptr;
IdxT* idx_buf = nullptr;
// Place the more strictly aligned of the two types first, so the second
// array starts naturally aligned without needing explicit padding.
if constexpr (alignof(T) >= alignof(IdxT)) {
val_buf = reinterpret_cast<T*>(shmem);
idx_buf = reinterpret_cast<IdxT*>(val_buf + blockDim.x);
} else {
idx_buf = reinterpret_cast<IdxT*>(shmem);
val_buf = reinterpret_cast<T*>(idx_buf + blockDim.x);
}
// Shift both pointers to this warp's slice: every lane of a warp computes the
// same offset (the warp's first thread index), so lanes share one buffer region.
auto warp_offset = Pow2<WarpSize>::roundDown(threadIdx.x);
val_buf += warp_offset;
idx_buf += warp_offset;
return warp_sort_distributed_ext<Capacity, Ascending, T, IdxT>{k, val_buf, idx_buf, limit};
}

__device__ void add(T val, IdxT idx)
Expand Down Expand Up @@ -518,6 +540,7 @@ class warp_sort_distributed_ext : public warp_sort<Capacity, Ascending, T, IdxT>
merge_buf_();
buf_len_ = 0;
}
__syncthreads();
}

private:
Expand Down Expand Up @@ -562,8 +585,10 @@ class warp_sort_immediate : public warp_sort<Capacity, Ascending, T, IdxT> {
using warp_sort<Capacity, Ascending, T, IdxT>::kDummy;
using warp_sort<Capacity, Ascending, T, IdxT>::kWarpWidth;
using warp_sort<Capacity, Ascending, T, IdxT>::k;
using warp_sort<Capacity, Ascending, T, IdxT>::mem_required;

__device__ warp_sort_immediate(int k) : warp_sort<Capacity, Ascending, T, IdxT>(k), buf_len_(0)
explicit __device__ warp_sort_immediate(int k)
: warp_sort<Capacity, Ascending, T, IdxT>(k), buf_len_(0)
{
#pragma unroll
for (int i = 0; i < kMaxArrLen; i++) {
Expand All @@ -572,6 +597,11 @@ class warp_sort_immediate : public warp_sort<Capacity, Ascending, T, IdxT> {
}
}

/**
 * Factory matching the init_blockwide interface of the other warp-sort
 * implementations; warp_sort_immediate needs no extra per-block memory
 * (mem_required() == 0), so the buffer pointer is ignored.
 */
__device__ __forceinline__ static auto init_blockwide(int k, uint8_t* = nullptr)
{
return warp_sort_immediate<Capacity, Ascending, T, IdxT>{k};
}

__device__ void add(T val, IdxT idx)
{
// NB: the loop is used here to ensure the constant indexing,
Expand Down Expand Up @@ -631,12 +661,8 @@ class block_sort {
using queue_t = WarpSortWarpWide<Capacity, Ascending, T, IdxT>;

template <typename... Args>
__device__ block_sort(int k, uint8_t* smem_buf, Args... args) : queue_(k, args...)
__device__ block_sort(int k, Args... args) : queue_(queue_t::init_blockwide(k, args...))
{
val_smem_ = reinterpret_cast<T*>(smem_buf);
const int num_of_warp = subwarp_align::div(blockDim.x);
idx_smem_ = reinterpret_cast<IdxT*>(
smem_buf + Pow2<256>::roundUp(ceildiv(num_of_warp, 2) * sizeof(T) * k));
}

__device__ void add(T val, IdxT idx) { queue_.add(val, idx); }
Expand All @@ -647,22 +673,26 @@ class block_sort {
*
* Here we tree-merge the results using the shared memory and block sync.
*/
__device__ void done()
__device__ void done(uint8_t* smem_buf)
{
queue_.done();

int nwarps = subwarp_align::div(blockDim.x);
auto val_smem = reinterpret_cast<T*>(smem_buf);
auto idx_smem = reinterpret_cast<IdxT*>(
smem_buf + Pow2<256>::roundUp(ceildiv(nwarps, 2) * sizeof(T) * queue_.k));

const int warp_id = subwarp_align::div(threadIdx.x);
// NB: there is no need for the second __synchthreads between .load_sorted and .store:
// we shift the pointers every iteration, such that individual warps either access the same
// locations or do not overlap with any of the other warps. The access patterns within warps
// are different for the two functions, but .load_sorted implies warp sync at the end, so
// there is no need for __syncwarp either.
for (int shift_mask = ~0, nwarps = subwarp_align::div(blockDim.x), split = (nwarps + 1) >> 1;
nwarps > 1;
for (int shift_mask = ~0, split = (nwarps + 1) >> 1; nwarps > 1;
nwarps = split, split = (nwarps + 1) >> 1) {
if (warp_id < nwarps && warp_id >= split) {
int dst_warp_shift = (warp_id - (split & shift_mask)) * queue_.k;
queue_.store(val_smem_ + dst_warp_shift, idx_smem_ + dst_warp_shift);
queue_.store(val_smem + dst_warp_shift, idx_smem + dst_warp_shift);
}
__syncthreads();

Expand All @@ -672,7 +702,7 @@ class block_sort {
// The last argument serves as a condition for loading
// -- to make sure threads within a full warp do not diverge on `bitonic::merge()`
queue_.load_sorted(
val_smem_ + src_warp_shift, idx_smem_ + src_warp_shift, warp_id < nwarps - split);
val_smem + src_warp_shift, idx_smem + src_warp_shift, warp_id < nwarps - split);
}
}
}
Expand All @@ -686,8 +716,6 @@ class block_sort {
private:
using subwarp_align = Pow2<queue_t::kWarpWidth>;
queue_t queue_;
T* val_smem_;
IdxT* idx_smem_;
};

/**
Expand All @@ -705,7 +733,10 @@ __launch_bounds__(256) __global__
void block_kernel(const T* in, const IdxT* in_idx, IdxT len, int k, T* out, IdxT* out_idx)
{
extern __shared__ __align__(256) uint8_t smem_buf_bytes[];
block_sort<WarpSortClass, Capacity, Ascending, T, IdxT> queue(k, smem_buf_bytes);
using bq_t = block_sort<WarpSortClass, Capacity, Ascending, T, IdxT>;
uint8_t* warp_smem = bq_t::queue_t::mem_required(blockDim.x) > 0 ? smem_buf_bytes : nullptr;
bq_t queue(k, warp_smem);

in += blockIdx.y * len;
if (in_idx != nullptr) { in_idx += blockIdx.y * len; }

Expand All @@ -716,7 +747,7 @@ __launch_bounds__(256) __global__
(i < len && in_idx != nullptr) ? __ldcs(in_idx + i) : i);
}

queue.done();
queue.done(smem_buf_bytes);
const int block_id = blockIdx.x + gridDim.x * blockIdx.y;
queue.store(out + block_id * k, out_idx + block_id * k);
}
Expand Down Expand Up @@ -827,6 +858,18 @@ struct LaunchThreshold<warp_sort_filtered> {
static constexpr int len_factor_for_single_block = 32;
};

// Launch-configuration thresholds for warp_sort_distributed; the factors
// mirror those of warp_sort_filtered above.
template <>
struct LaunchThreshold<warp_sort_distributed> {
static constexpr int len_factor_for_multi_block = 2;
static constexpr int len_factor_for_single_block = 32;
};

// Launch-configuration thresholds for warp_sort_distributed_ext; same factors
// as the non-ext distributed variant.
template <>
struct LaunchThreshold<warp_sort_distributed_ext> {
static constexpr int len_factor_for_multi_block = 2;
static constexpr int len_factor_for_single_block = 32;
};

template <>
struct LaunchThreshold<warp_sort_immediate> {
static constexpr int len_factor_for_choosing = 4;
Expand Down Expand Up @@ -943,6 +986,8 @@ void select_k_(int num_of_block,
int block_dim = num_of_warp * warp_width;
int smem_size = calc_smem_size_for_block_wide<T, IdxT>(num_of_warp, k);

smem_size = std::max<int>(smem_size, WarpSortClass<1, true, T, IdxT>::mem_required(block_dim));

launch_setup<WarpSortClass, T, IdxT>::kernel(k,
select_min,
batch_size,
Expand Down Expand Up @@ -973,6 +1018,36 @@ void select_k_(int num_of_block,
}
}

/**
 * Convenience wrapper: derives a launch configuration (number of blocks and
 * warps) for the given problem size via calc_launch_parameter, then forwards
 * all arguments to select_k_ with the chosen WarpSortClass implementation.
 *
 * @param in      input keys, `len` elements per batch row
 * @param in_idx  optional input indices; may be nullptr, in which case the
 *                kernel substitutes element positions as indices (see the
 *                `in_idx != nullptr` handling in block_kernel)
 * @param batch_size  number of independent rows
 * @param len     number of elements per row
 * @param k       number of elements to select per row
 * @param out     output keys, k per row
 * @param out_idx output indices, k per row
 * @param select_min  selection direction flag (smallest vs largest — see the
 *                public select_k documentation for the exact convention)
 * @param stream  CUDA stream on which the work is enqueued
 * @param mr      optional device memory resource for temporary allocations;
 *                NOTE(review): presumably nullptr means "use the default rmm
 *                resource" — confirm against select_k_'s handling
 */
template <typename T, typename IdxT, template <int, bool, typename, typename> class WarpSortClass>
void select_k_impl(const T* in,
const IdxT* in_idx,
size_t batch_size,
size_t len,
int k,
T* out,
IdxT* out_idx,
bool select_min,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = nullptr)
{
int num_of_block = 0;
int num_of_warp = 0;
calc_launch_parameter<WarpSortClass, T, IdxT>(batch_size, len, k, &num_of_block, &num_of_warp);

select_k_<WarpSortClass, T, IdxT>(num_of_block,
num_of_warp,
in,
in_idx,
batch_size,
len,
k,
out,
out_idx,
select_min,
stream,
mr);
}

/**
* Select k smallest or largest key/values from each row in the input data.
*
Expand Down
11 changes: 7 additions & 4 deletions cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
Ascending,
float,
IdxT>;
block_sort_t queue(k, interleaved_scan_kernel_smem + query_smem_elems * sizeof(T));
block_sort_t queue(k);

{
using align_warp = Pow2<WarpSize>;
Expand Down Expand Up @@ -781,7 +781,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
}

// finalize and store selected neighbours
queue.done();
__syncthreads();
queue.done(interleaved_scan_kernel_smem);
queue.store(distances, neighbors);
}

Expand Down Expand Up @@ -832,8 +833,10 @@ void launch_kernel(Lambda lambda,
std::min<int>(max_query_smem / sizeof(T), Pow2<Veclen * WarpSize>::roundUp(index.dim()));
int smem_size = query_smem_elems * sizeof(T);
constexpr int kSubwarpSize = std::min<int>(Capacity, WarpSize);
smem_size += raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide<AccT, IdxT>(
kThreadsPerBlock / kSubwarpSize, k);
auto block_merge_mem =
raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide<AccT, IdxT>(
kThreadsPerBlock / kSubwarpSize, k);
smem_size += std::max<int>(smem_size, block_merge_mem);

// power-of-two less than cuda limit (for better addr alignment)
constexpr uint32_t kMaxGridY = 32768;
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ __global__ void ivfpq_compute_similarity_kernel(uint32_t n_rows,
constexpr OutT kDummy = upper_bound<OutT>();
OutT query_kth = kDummy;
if constexpr (kManageLocalTopK) { query_kth = OutT(query_kths[query_ix]); }
local_topk_t block_topk(topk, smem_buf, query_kth);
local_topk_t block_topk(topk, nullptr, query_kth);
OutT early_stop_limit = kDummy;
switch (metric) {
// If the metric is non-negative, we can use the query_kth approximation as an early stop
Expand Down Expand Up @@ -843,7 +843,7 @@ __global__ void ivfpq_compute_similarity_kernel(uint32_t n_rows,
if constexpr (kManageLocalTopK) {
// sync threads before the topk merging operation, because we reuse smem_buf
__syncthreads();
block_topk.done();
block_topk.done(smem_buf);
block_topk.store(out_scores, out_indices);
if (threadIdx.x == 0) { atomicMin(query_kths + query_ix, float(out_scores[topk - 1])); }
} else {
Expand Down
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ if(BUILD_TESTS)
test/matrix/matrix.cu
test/matrix/norm.cu
test/matrix/reverse.cu
test/matrix/select.cu
test/matrix/slice.cu
test/matrix/triangular.cu
test/spectral_matrix.cu
Expand Down
Loading