ivf_flat::index: hide implementation details #747

Merged
147 commits, merged Aug 24, 2022

Changes from 1 commit

Commits (147)

35ab60d
initial commit and formatting cleanup
achirkin May 13, 2022
24e8c4d
update save/load index function to work with cuann benchmark suite, s…
achirkin May 16, 2022
cb8bcd2
Added benchmarks.
achirkin May 16, 2022
884723c
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 17, 2022
8c4a0a0
Add missing parameter docs
achirkin May 17, 2022
070fd05
Adapt to the changes in the warpsort api
achirkin May 17, 2022
83b6630
cleanup: use WarpSize constant
achirkin May 17, 2022
3a2703c
cleanup: remove unnecessary helpers
achirkin May 17, 2022
31bbaec
Use a more efficient warp_sort_filtered
achirkin May 17, 2022
4b40181
Recover files that have only non-relevant changes to reduce the size …
achirkin May 17, 2022
7e3041c
wip: replacing explicit allocations with rmm buffers
achirkin May 17, 2022
f6556b7
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 17, 2022
f75761f
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 18, 2022
dd558b4
Update cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh
achirkin May 18, 2022
94b3cbe
Update cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh
achirkin May 18, 2022
2be45a9
wip: replace cudaMemcpy with raft::copy
achirkin May 18, 2022
30c32a9
Simplified some cudaMemcpy invocations
achirkin May 18, 2022
c8e7b4d
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 19, 2022
150a438
Refactoring with helper functions
achirkin May 19, 2022
ddfb8cc
Make the scratch buf 3x L2 cache size
achirkin May 19, 2022
b788e2e
Remove serialization code for now
achirkin May 19, 2022
3e1c14d
remove obsolete comment
achirkin May 19, 2022
a001999
Add a missing sync
achirkin May 19, 2022
2d08271
Rename ann_quantized_faiss
achirkin May 19, 2022
0f88aaa
wip from manual allocations to rmm: updated some parts with pointer r…
achirkin May 19, 2022
363dfc9
wip from manual allocations to rmm
achirkin May 19, 2022
e5399f8
fix style
achirkin May 19, 2022
306f5bf
Set minimum memory pool size in radix_topk to 256 bytes
achirkin May 20, 2022
fd7d2ba
wip malloc-to-rmm: removed most of the manual allocations
achirkin May 20, 2022
403667a
misc cleanup
achirkin May 20, 2022
4c6d563
Refactoring; used raft::handle in place of cublas handle everywhere
achirkin May 20, 2022
3ae52ea
Fix the value type at runtime (use templates instead of runtime dtype)
achirkin May 20, 2022
6fecd7f
ceildiv
achirkin May 20, 2022
174854f
Use rmm's memory pool in place of explicitly allocated buffers
achirkin May 20, 2022
b45b14c
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 20, 2022
ca1aaad
Use raft logging
achirkin May 24, 2022
4228a02
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 24, 2022
70d84ec
Updated logging and nvtx markers
achirkin May 24, 2022
f9c12f8
clang-format
achirkin May 24, 2022
17968e4
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 24, 2022
957ac94
Use the recommended logger header
achirkin May 24, 2022
ccfbccc
Use warpsort for smaller k
achirkin May 25, 2022
7819397
Using raft helpers
achirkin May 25, 2022
510c467
Determine the template parameters Capacity and Veclen recursively
achirkin May 25, 2022
c5087be
wip: refactoring and reducing duplicate calls
achirkin May 26, 2022
f850a4a
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 27, 2022
c5f1c89
Refactor and document ann_ivf_flat_kernel
achirkin May 27, 2022
7b2b9ff
Documenting and refactoring the kernel
achirkin May 27, 2022
913edfb
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 30, 2022
b1208ed
Add a case of high dimensionality
achirkin May 31, 2022
a30ade5
Add more sync into the test to detect device errors
achirkin May 31, 2022
84db732
Add more sync into the test to detect device errors
achirkin May 31, 2022
346afb2
Allow large batch sizes and document more functions
achirkin May 31, 2022
fc201b5
Add a lower bound on expected recall
achirkin May 31, 2022
4021ea2
Compute required memory dynamically
achirkin May 31, 2022
ea8b1c4
readability quickfix
achirkin May 31, 2022
d8a034a
Correct the smem size for the warpsort and add launch bounds
achirkin May 31, 2022
d97d248
Add a couple of checks against floating point exceptions
achirkin Jun 1, 2022
2e64037
Don't run kmeans on empty dataset
achirkin Jun 2, 2022
9ed50ac
Order all ops by a cuda stream
achirkin Jun 2, 2022
1f9352c
Update comments
achirkin Jun 2, 2022
c048af2
Suggest replacing _cuann_sqsum
achirkin Jun 2, 2022
96f39a8
wip: refactoring utils
achirkin Jun 2, 2022
888daeb
minor comments
achirkin Jun 2, 2022
e6ff267
ann_utils refactoring, docs, and clang-tidy
achirkin Jun 3, 2022
426f713
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin Jun 7, 2022
bacb402
Refactor tests and reduce their memory footprint
achirkin Jun 7, 2022
4042b28
Refactored and documented ann_kmeans_balanced
achirkin Jun 7, 2022
bb5726b
Use memory_resource for temp data in kmeans
achirkin Jun 7, 2022
810c26b
Address clang-tidy and other refactoring suggestions
achirkin Jun 8, 2022
042c410
Move part of the index building onto gpu
achirkin Jun 8, 2022
7ace0fb
Document the index building kernel
achirkin Jun 15, 2022
e9c0d49
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jun 15, 2022
3515715
Added a dims padding todo
achirkin Jun 15, 2022
6bd6560
Move kmeans-related allocations and routines to ann_kmeans_balanced.cuh
achirkin Jun 15, 2022
2811814
Add documentation to the build_optimized_kmeans
achirkin Jun 15, 2022
fc3e46e
Using mdarrays and structured index
achirkin Jun 16, 2022
fb8c4b1
Fixed a memory leak and introduced a few assertions to check pointer …
achirkin Jun 17, 2022
f3b2cb2
Merge branch 'branch-22.08' into fea-knn-ivf-flat
cjnolet Jun 17, 2022
092d428
Refactoring build_optimized_kmeans
achirkin Jun 17, 2022
fbcb16b
A few smaller refactorings for kmeans
achirkin Jun 17, 2022
29ca199
Add docs to public methods of the handle
achirkin Jun 20, 2022
38b3cec
Made the metric be a part of the index struct and set the greater_ = …
achirkin Jun 21, 2022
d19bb5f
Do not persist grid_dim_x between searches
achirkin Jun 21, 2022
9094707
Refactor names according to clang-tidy
achirkin Jun 21, 2022
325e201
Refactor the usage of stream and params
achirkin Jun 21, 2022
2a3eb33
Refactor api to have symmetric index/search params
achirkin Jun 21, 2022
867beca
refactor away ivf_flat_index
achirkin Jun 22, 2022
059a6c0
Add the memory resource argument to warp_sort_topk
achirkin Jun 22, 2022
df17b5b
update docs
achirkin Jun 22, 2022
fe9ced1
Allow empty mesoclusters
achirkin Jun 23, 2022
91fdcbb
Add low-dimensional and non-veclen-aligned-dimensional test cases
achirkin Jun 23, 2022
be14c63
Refactor and document loadAndComputeDist
achirkin Jun 23, 2022
eeb4601
Minor renamings
achirkin Jun 23, 2022
025e5a5
Add 8bit int types to knn benchmarks
achirkin Jun 23, 2022
3821366
Fix incorrect data mapping for int8 types
achirkin Jun 24, 2022
d596842
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jun 24, 2022
a29baa7
Introduce kIndexGroupSize constant
achirkin Jun 27, 2022
546bef8
Cleanup ann_quantized
achirkin Jun 27, 2022
32d0d2e
Add several type aliases and helpers for creating mdarrays
achirkin Jun 27, 2022
5f427c0
Remove unnecessary inlines and fix docs
achirkin Jun 28, 2022
c581fe2
More refactoring and a few forceinlines
achirkin Jun 28, 2022
805e78c
Add a helper for creating pool_memory_resource when it makes sense
achirkin Jun 29, 2022
a4973e6
Force move the mdarrays when creating index to avoid copying them
achirkin Jun 29, 2022
68c267e
Minor refactorings
achirkin Jun 29, 2022
f2b8ed8
Add nvtx annotations to the outermost ANN calls for better performanc…
achirkin Jun 29, 2022
f91c7f7
Add a few more test cases and annotations for them
achirkin Jun 29, 2022
84b1c5b
Fix a typo
achirkin Jun 29, 2022
afc1f6a
Move ensure_integral_extents to the detail folder
achirkin Jun 30, 2022
3a10f86
Lift the requirement to have query pointers aligned with Veclen
achirkin Jun 30, 2022
9f5c64c
Merge branch 'branch-22.08' into enh-mdarray-helpers
achirkin Jun 30, 2022
1afd667
Use move semantics for the index everywhere, but try to keep it const…
achirkin Jun 30, 2022
73ce9e1
Update documentation
achirkin Jun 30, 2022
2a45645
Remove the debug path USE_FAISS
achirkin Jun 30, 2022
75a48b4
Add a type trait for checking if the conversion between two numeric t…
achirkin Jul 1, 2022
ed25cae
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jul 1, 2022
388200c
Support 32bit and unsigned indices in bruteforce KNN
achirkin Jul 1, 2022
f08df83
Merge branch 'enh-mdarray-helpers' into fea-knn-ivf-flat
achirkin Jul 1, 2022
9200886
Merge branch 'enh-knn-bruteforce-uint32' into fea-knn-ivf-flat
achirkin Jul 1, 2022
14bfe02
Make index type a template parameter
achirkin Jul 1, 2022
1283cbe
Revert the api changes as much as possible and deprecate the old api
achirkin Jul 1, 2022
e73b259
Remove the stream argument from the public API
achirkin Jul 4, 2022
8e7ffb8
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jul 5, 2022
5f5dc0d
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jul 5, 2022
03ebbe0
Simplify kmeans::predict a little bit
achirkin Jul 6, 2022
cde7f97
Factor out predict from the other ops in kmeans for use outside of th…
achirkin Jul 7, 2022
305bbcd
Add new function extend(index, new_vecs, new_inds) to ivf_flat
achirkin Jul 20, 2022
76c383f
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jul 21, 2022
7f640a9
Improve the docs
achirkin Jul 21, 2022
2e9eda5
Fix using non-existing log function
achirkin Jul 21, 2022
dc62a0f
Hide all data components from ivf_flat::index and expose immutable views
achirkin Jul 21, 2022
fb841c3
Replace thrust::exclusive_scan with thrust::inclusive_scan to avoid a…
achirkin Jul 22, 2022
04bb5dc
Merge branch 'fea-knn-ivf-flat' into enh-knn-ivf-flat-hide-impl
achirkin Jul 22, 2022
c95ea85
ann_common.h: remove deps on cuda code, so that the file can be inclu…
achirkin Jul 22, 2022
0c72ee8
ann_common.h: remove deps on cuda code, so that the file can be inclu…
achirkin Jul 22, 2022
0196695
Make helper overloads inline for linking in cuml
achirkin Jul 22, 2022
eb15639
Split processing.hpp into *.cuh and *.hpp to avoid incomplete types
achirkin Jul 22, 2022
e4b2b39
WIP: investigating segmentation fault in cuml test
achirkin Jul 25, 2022
6bc0fcb
Revert the wip-changes from the last commit
achirkin Jul 26, 2022
f599aaf
Merge remote-tracking branch 'origin/fea-knn-ivf-flat' into enh-knn-i…
achirkin Jul 26, 2022
a191410
Merge branch 'branch-22.08' into enh-knn-ivf-flat-hide-impl
achirkin Jul 28, 2022
317ddf3
Enhance documentation
achirkin Jul 28, 2022
114fb63
Fix a couple of typos in docs
achirkin Jul 28, 2022
1d283ae
Change the data indexing to size_t to make sure the total size (size*…
achirkin Jul 28, 2022
a9bd2d6
Merge branch 'branch-22.08' into enh-knn-ivf-flat-hide-impl
achirkin Aug 2, 2022
f9d55a7
Make ivf_flat::index look a little bit more like knn::sparse api
achirkin Aug 2, 2022
fef6dac
Test both overloads of
achirkin Aug 2, 2022
Lift the requirement to have query pointers aligned with Veclen
achirkin committed Jun 30, 2022
commit 3a10f86930a2468fb1e15d6275800a8529aed008
214 changes: 94 additions & 120 deletions cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh
@@ -32,6 +32,7 @@
#include <raft/distance/distance.cuh>
#include <raft/distance/distance_type.hpp>
#include <raft/pow2_utils.cuh>
#include <raft/vectorized.cuh>

#ifdef USE_FAISS
#include <faiss/gpu/utils/Comparators.cuh>
@@ -52,70 +53,49 @@ using raft::spatial::knn::ivf_flat::search_params;
constexpr int kThreadsPerBlock = 128;

/**
* @brief Copy Veclen elements of type T from `query` to `query_shared` at position `loadDim *
* Veclen`.
* @brief Copy `n` elements per block from one place to another.
*
* @param[in] query a pointer to a device global memory
* @param[out] query_shared a pointer to a device shared memory
* @param loadDim position at which to start copying elements.
* @param[out] out target pointer (unique per block)
* @param[in] in source pointer
* @param n number of elements to copy
*/
template <typename T, int Veclen>
__device__ __forceinline__ void queryLoadToShmem(const T* const& query,
T* query_shared,
const int loadDim)
template <int VecBytes = 16, typename T>
__device__ inline void copy_vectorized(T* out, const T* in, uint32_t n)
{
T queryReg[Veclen];
const int loadIndex = loadDim * Veclen;
ldg(queryReg, query + loadIndex);
sts(&query_shared[loadIndex], queryReg);
}

template <>
__device__ __forceinline__ void queryLoadToShmem<uint8_t, 8>(const uint8_t* const& query,
uint8_t* query_shared,
const int loadDim)
{
constexpr int veclen = 2; // 8 uint8_t
uint32_t queryReg[veclen];
const int loadIndex = loadDim * veclen;
ldg(queryReg, reinterpret_cast<uint32_t const*>(query) + loadIndex);
sts(reinterpret_cast<uint32_t*>(query_shared) + loadIndex, queryReg);
}

template <>
__device__ __forceinline__ void queryLoadToShmem<uint8_t, 16>(const uint8_t* const& query,
uint8_t* query_shared,
const int loadDim)
{
constexpr int veclen = 4; // 16 uint8_t
uint32_t queryReg[veclen];
const int loadIndex = loadDim * veclen;
ldg(queryReg, reinterpret_cast<uint32_t const*>(query) + loadIndex);
sts(reinterpret_cast<uint32_t*>(query_shared) + loadIndex, queryReg);
}

template <>
__device__ __forceinline__ void queryLoadToShmem<int8_t, 8>(const int8_t* const& query,
int8_t* query_shared,
const int loadDim)
{
constexpr int veclen = 2; // 8 int8_t
int32_t queryReg[veclen];
const int loadIndex = loadDim * veclen;
ldg(queryReg, reinterpret_cast<int32_t const*>(query) + loadIndex);
sts(reinterpret_cast<int32_t*>(query_shared) + loadIndex, queryReg);
}

template <>
__device__ __forceinline__ void queryLoadToShmem<int8_t, 16>(const int8_t* const& query,
int8_t* query_shared,
const int loadDim)
{
constexpr int veclen = 4; // 16 int8_t
int32_t queryReg[veclen];
const int loadIndex = loadDim * veclen;
ldg(queryReg, reinterpret_cast<int32_t const*>(query) + loadIndex);
sts(reinterpret_cast<int32_t*>(query_shared) + loadIndex, queryReg);
constexpr int VecElems = VecBytes / sizeof(T); // NOLINT
using align_bytes = Pow2<(size_t)VecBytes>;
if constexpr (VecElems > 1) {
using align_elems = Pow2<VecElems>;
if (!align_bytes::areSameAlignOffsets(out, in)) {
return copy_vectorized<(VecBytes >> 1), T>(out, in, n);
}
{ // process unaligned head
uint32_t head = align_bytes::roundUp(in) - in;
if (head > 0) {
copy_vectorized<sizeof(T), T>(out, in, head);
n -= head;
in += head;
out += head;
}
}
{ // process main part vectorized
using vec_t = typename IOType<T, VecElems>::Type;
copy_vectorized<sizeof(vec_t), vec_t>(
reinterpret_cast<vec_t*>(out), reinterpret_cast<const vec_t*>(in), align_elems::div(n));
}
{ // process unaligned tail
uint32_t tail = align_elems::mod(n);
if (tail > 0) {
n -= tail;
copy_vectorized<sizeof(T), T>(out + n, in + n, tail);
}
}
}
if constexpr (VecElems <= 1) {
for (int i = threadIdx.x; i < n; i += blockDim.x) {
out[i] = in[i];
}
}
}
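
For context, copy_vectorized above chooses the widest vector load that the alignment of both pointers allows, falls back recursively to half the width when the pointers disagree, and splits the copy into an unaligned head, a vectorized middle, and a scalar tail. Below is a minimal standalone sketch of that head/middle/tail decomposition, hardcoded to float with 16-byte float4 vectors instead of the recursive template and raft's Pow2/IOType helpers; the function name copy_floats_vectorized and the same-alignment assumption are illustrative, not part of the diff:

// Illustrative sketch: block-cooperative copy of n floats, vectorizing the aligned middle.
// Assumes `out` and `in` share the same 16-byte alignment offset; the real copy_vectorized
// falls back to narrower vectors recursively when they do not.
__device__ inline void copy_floats_vectorized(float* out, const float* in, uint32_t n)
{
  constexpr uint32_t kVecElems = 4;  // float4 == 16 bytes
  // Head: scalar copies up to the first 16-byte boundary of `in`.
  uint32_t offset = uint32_t((reinterpret_cast<size_t>(in) / sizeof(float)) % kVecElems);
  uint32_t head   = offset == 0 ? 0 : kVecElems - offset;
  if (head > n) head = n;
  for (uint32_t i = threadIdx.x; i < head; i += blockDim.x) { out[i] = in[i]; }
  in += head; out += head; n -= head;
  // Middle: full float4 loads/stores, distributed across the block.
  const uint32_t n_vecs = n / kVecElems;
  const float4* in4 = reinterpret_cast<const float4*>(in);
  float4* out4      = reinterpret_cast<float4*>(out);
  for (uint32_t i = threadIdx.x; i < n_vecs; i += blockDim.x) { out4[i] = in4[i]; }
  // Tail: leftover scalar elements after the last full vector.
  for (uint32_t i = n_vecs * kVecElems + threadIdx.x; i < n; i += blockDim.x) { out[i] = in[i]; }
}

In the kernel below, this staging of the query into shared memory is what lets every warp reuse the query data without repeated global loads while it streams the interleaved list data.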

/**
@@ -213,7 +193,7 @@ struct loadAndComputeDist {
for (int k = 0; k < Veclen; k++) {
compute_dist(dist, shfl(queryReg, d + k, WarpSize), enc[k]);
}
} // end for d < dim - dimBlocks
}
}
};

@@ -292,7 +272,7 @@ struct loadAndComputeDist<kUnroll, Lambda, uint8_veclen, uint8_t, uint32_t> {
uint32_t q = shfl(queryReg, (d / 4) + k, WarpSize);
compute_dist(dist, q, enc[k]);
}
} // end for d < dim - dimBlocks
}
}
};

@@ -354,7 +334,7 @@ struct loadAndComputeDist<kUnroll, Lambda, 4, uint8_t, uint32_t> {
uint32_t enc = reinterpret_cast<unsigned const*>(data)[lane_id];
uint32_t q = shfl(queryReg, d / veclen, WarpSize);
compute_dist(dist, q, enc);
} // end for d < dim - dimBlocks
}
}
};

@@ -553,7 +533,7 @@ struct loadAndComputeDist<kUnroll, Lambda, int8_veclen, int8_t, int32_t> {
int32_t q = shfl(queryReg, (d / 4) + k, WarpSize); // Here 4 is for 1 - int;
compute_dist(dist, q, enc[k]);
}
} // end for d < dim - dimBlocks
}
}
};

@@ -672,7 +652,7 @@ template <int Capacity, int Veclen, bool Ascending, typename T, typename AccT, t
__global__ void __launch_bounds__(kThreadsPerBlock)
interleaved_scan_kernel(Lambda compute_dist,
const uint32_t query_smem_elems,
const T* queries,
const T* query,
const uint32_t* coarse_index,
const uint32_t* list_index,
const T* list_data,
@@ -685,6 +665,24 @@
float* distances)
{
extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[];
// Using shared memory for the (part of the) query;
// This allows to save on global memory bandwidth when reading index and query
// data at the same time.
// Its size is `query_smem_elems`.
T* query_shared = reinterpret_cast<T*>(interleaved_scan_kernel_smem);
// Make the query input and output point to this block's shared query
{
const int query_id = blockIdx.y;
query += query_id * dim;
neighbors += query_id * k * gridDim.x + blockIdx.x * k;
distances += query_id * k * gridDim.x + blockIdx.x * k;
coarse_index += query_id * n_probes;
}

// Copy a part of the query into shared memory for faster processing
copy_vectorized(query_shared, query, std::min(dim, query_smem_elems));
__syncthreads();

#ifdef USE_FAISS
// temporary use of FAISS blockSelect for development purpose of k <= 32
// for comparison purpose
@@ -702,86 +700,62 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
queue(identity, keyMax, smemK, smemV, k);

#else
topk::block_sort<topk::warp_sort_immediate, Capacity, Ascending, float, size_t> queue(
topk::block_sort<topk::warp_sort_filtered, Capacity, Ascending, float, size_t> queue(
k, interleaved_scan_kernel_smem + query_smem_elems * sizeof(T));
#endif

const int query_id = blockIdx.y;
{
// Using shared memory for the (part of the) query;
// This allows to save on global memory bandwidth when reading index and query
// data at the same time.
// Its size is `query_smem_elems`.
T* query_shared = reinterpret_cast<T*>(interleaved_scan_kernel_smem);

using align_warp = Pow2<WarpSize>;
const int lane_id = align_warp::mod(threadIdx.x);
const int warp_id = align_warp::div(threadIdx.x);

/// Set the address
auto query = queries + query_id * dim;
constexpr int kGroupSize = WarpSize;

// How many full warps needed to compute the distance (without remainder)
const int full_warps_along_dim = align_warp::roundDown(dim);
const uint32_t full_warps_along_dim = align_warp::roundDown(dim);

int shm_assisted_dim = (dim < query_smem_elems) ? dim : query_smem_elems;

// load the query data from global to shared memory
for (int i = threadIdx.x; i * Veclen < shm_assisted_dim; i += blockDim.x) {
queryLoadToShmem<T, Veclen>(query, query_shared, i);
}
__syncthreads();
shm_assisted_dim = (dim > query_smem_elems) ? query_smem_elems : full_warps_along_dim;
const uint32_t shm_assisted_dim =
(dim > query_smem_elems) ? query_smem_elems : full_warps_along_dim;

// Every CUDA block scans one cluster at a time.
for (int probe_id = blockIdx.x; probe_id < n_probes; probe_id += gridDim.x) {
const uint32_t list_id =
coarse_index[query_id * n_probes + probe_id]; // The id of cluster(list)

/**
* Uses shared memory
*/
// The start address of the full value of vector for each cluster(list) interleaved
auto vecsBase = list_data + size_t(list_prefix_interleave[list_id]) * dim;
// The start address of index of vector for each cluster(list) interleaved
auto indexBase = list_index + list_prefix_interleave[list_id];
const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list)
const size_t list_offset = list_prefix_interleave[list_id];

// The number of vectors in each cluster(list); [nlist]
const uint32_t list_length = list_lengths[list_id];

// The number of interleaved groups to be processed
const uint32_t num_groups = ceildiv<uint32_t>(list_length, WarpSize);
const uint32_t num_groups =
align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2

constexpr int kUnroll = WarpSize / Veclen;
constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize;
// Every warp reads WarpSize vectors and computes the distances to them.
// Then, the distances and corresponding ids are distributed among the threads,
// and each thread adds one (id, dist) pair to the filtering queue.
for (uint32_t block = warp_id; block < num_groups; block += kNumWarps) {
for (uint32_t group_id = align_warp::div(threadIdx.x); group_id < num_groups;
group_id += kNumWarps) {
AccT dist = 0;
// This is where this warp begins reading data (start position of an interleaved group)
const T* data = list_data + (list_offset + group_id * kIndexGroupSize) * dim;

// This is the vector a given lane/thread handles
const uint32_t vec = block * WarpSize + lane_id;
bool valid = vec < list_length;
size_t idx = (valid) ? (size_t)indexBase[vec] : (size_t)lane_id;
// This is where this warp begins reading data
const T* data =
vecsBase + size_t(block) * kGroupSize * dim; // Start position of this block
const uint32_t vec_id = group_id * WarpSize + lane_id;
const bool valid = vec_id < list_length;

// Process first shm_assisted_dim dimensions (always using shared memory)
if (valid) {
for (int pos = 0; pos < shm_assisted_dim; pos += WarpSize) {
loadAndComputeDist<kUnroll, decltype(compute_dist), Veclen, T, AccT> lc(dist,
compute_dist);
loadAndComputeDist<kUnroll, decltype(compute_dist), Veclen, T, AccT> lc(dist,
compute_dist);
for (int pos = 0; pos < shm_assisted_dim;
pos += WarpSize, data += kIndexGroupSize * WarpSize) {
lc.runLoadShmemCompute(data, query_shared, lane_id, pos);
data += WarpSize * kGroupSize;
}
}

if (dim > query_smem_elems) {
// The default path - using shfl ops - for dimensions beyond query_smem_elems
loadAndComputeDist<kUnroll, decltype(compute_dist), Veclen, T, AccT> lc(dist,
compute_dist);
for (int pos = shm_assisted_dim; pos < full_warps_along_dim; pos += WarpSize) { //
for (int pos = shm_assisted_dim; pos < full_warps_along_dim; pos += WarpSize) {
lc.runLoadShflAndCompute(data, query, pos, lane_id);
}
lc.runLoadShflAndComputeRemainder(data, query, lane_id, dim, full_warps_along_dim);
@@ -790,15 +764,16 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
if (valid) {
loadAndComputeDist<1, decltype(compute_dist), Veclen, T, AccT> lc(dist, compute_dist);
for (int pos = full_warps_along_dim; pos < dim;
pos += Veclen, data += kGroupSize * Veclen) {
pos += Veclen, data += kIndexGroupSize * Veclen) {
lc.runLoadShmemCompute(data, query_shared, lane_id, pos);
}
}
}

// Enqueue one element per thread
constexpr float kDummy = Ascending ? upper_bound<float>() : lower_bound<float>();
float val = valid ? static_cast<float>(dist) : kDummy;
const float val = valid ? static_cast<float>(dist) : kDummy;
const size_t idx = valid ? static_cast<size_t>(list_index[list_offset + vec_id]) : 0;
queue.add(val, idx);
}
}
@@ -808,15 +783,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
#ifdef USE_FAISS
queue.reduce();
for (int i = threadIdx.x; i < k; i += kThreadsPerBlock) {
neighbors[query_id * k * gridDim.x + blockIdx.x * k + i] = (size_t)smemV[i];
distances[query_id * k * gridDim.x + blockIdx.x * k + i] = smemK[i];
neighbors[i] = (size_t)smemV[i];
distances[i] = smemK[i];
}
#else
queue.done();
queue.store(distances + query_id * k * gridDim.x + blockIdx.x * k,
neighbors + query_id * k * gridDim.x + blockIdx.x * k);
queue.store(distances, neighbors);
#endif
} // end kernel
}

/**
* Configure the gridDim.x to maximize GPU occupancy, but reduce the output size
@@ -850,8 +824,6 @@ void launch_kernel(Lambda lambda,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream)
{
RAFT_EXPECTS(reinterpret_cast<size_t>(queries) % (Veclen * sizeof(T)) == 0,
"Queries data is not aligned to the vector load size (Veclen).");
RAFT_EXPECTS(Veclen == index.veclen,
"Configured Veclen does not match the index interleaving pattern.");
constexpr auto kKernel = interleaved_scan_kernel<Capacity, Veclen, Ascending, T, AccT, Lambda>;
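
The doc comment a few hunks up ("Configure the gridDim.x to maximize GPU occupancy, but reduce the output size") describes what the elided body of launch_kernel does with grid_dim_x; that calculation is not shown in this excerpt. As a rough, hypothetical sketch of such occupancy-driven sizing using the standard CUDA occupancy API (smem_size here stands for the kernel's dynamic shared memory requirement and is an assumed variable):

// Hypothetical sketch only: size the grid so one wave of blocks fills the GPU,
// then cap the x-dimension (probes processed in parallel per query) by n_probes.
int num_blocks_per_sm = 0;
RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
  &num_blocks_per_sm, kKernel, kThreadsPerBlock, smem_size));
int device_id = 0, num_sms = 0;
RAFT_CUDA_TRY(cudaGetDevice(&device_id));
RAFT_CUDA_TRY(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device_id));
const uint32_t blocks_per_wave = uint32_t(num_blocks_per_sm) * uint32_t(num_sms);
uint32_t per_query = n_queries > 0 ? blocks_per_wave / n_queries : blocks_per_wave;
if (per_query < 1) per_query = 1;
grid_dim_x = per_query < n_probes ? per_query : n_probes;
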
@@ -1230,6 +1202,8 @@ void search_impl(const handle_t& handle,
stream,
search_mr);
} else {
// NB: this branch can only be triggered once `ivfflat_interleaved_scan` above supports larger
// `k` values (kMaxCapacity limit as a dependency of topk::block_sort)
topk::radix_topk<AccT, size_t, 11, 512>(refined_distances_dev.data(),
refined_indices_dev.data(),
n_queries,
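
Zooming out from the diff: per the PR title and the commits above ("Hide all data components from ivf_flat::index and expose immutable views", "Add new function extend(index, new_vecs, new_inds) to ivf_flat", "Remove the stream argument from the public API"), callers interact with ivf_flat through an opaque index plus the free functions build/extend/search. The following is a rough usage sketch only; the exact signatures live in the public headers changed by this PR and are assumed here, not quoted from them:

#include <raft/handle.hpp>
#include <raft/spatial/knn/ivf_flat.cuh>  // assumed public header; not shown in this excerpt

namespace ivf_flat = raft::spatial::knn::ivf_flat;

void ivf_flat_example(const raft::handle_t& handle,
                      const float* dataset, uint64_t n_rows, uint32_t dim,
                      const float* queries, uint32_t n_queries, uint32_t k,
                      uint64_t* neighbors, float* distances)
{
  ivf_flat::index_params build_params;
  build_params.n_lists = 1024;  // number of inverted lists (clusters)

  // build() returns the index by value; after this PR its data members are private
  // and only immutable views / accessors are exposed.
  auto index = ivf_flat::build(handle, build_params, dataset, n_rows, dim);

  // extend(index, new_vecs, new_inds) adds vectors to an existing index
  // (see the Jul 20 commit above); omitted here.

  ivf_flat::search_params search_params;
  search_params.n_probes = 32;  // clusters scanned per query

  // No explicit stream argument: work is ordered on the stream owned by the handle.
  ivf_flat::search(handle, search_params, index, queries, n_queries, k, neighbors, distances);
}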