ivf_flat::index: hide implementation details #747

Merged
147 commits, merged Aug 24, 2022
Changes from 2 commits
Commits (147)
35ab60d
initial commit and formatting cleanup
achirkin May 13, 2022
24e8c4d
update save/load index function to work with cuann benchmark suite, s…
achirkin May 16, 2022
cb8bcd2
Added benchmarks.
achirkin May 16, 2022
884723c
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 17, 2022
8c4a0a0
Add missing parameter docs
achirkin May 17, 2022
070fd05
Adapt to the changes in the warpsort api
achirkin May 17, 2022
83b6630
cleanup: use WarpSize constant
achirkin May 17, 2022
3a2703c
cleanup: remove unnecessary helpers
achirkin May 17, 2022
31bbaec
Use a more efficient warp_sort_filtered
achirkin May 17, 2022
4b40181
Recover files that have only non-relevant changes to reduce the size …
achirkin May 17, 2022
7e3041c
wip: replacing explicit allocations with rmm buffers
achirkin May 17, 2022
f6556b7
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 17, 2022
f75761f
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 18, 2022
dd558b4
Update cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh
achirkin May 18, 2022
94b3cbe
Update cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh
achirkin May 18, 2022
2be45a9
wip: replace cudaMemcpy with raft::copy
achirkin May 18, 2022
30c32a9
Simplified some cudaMemcpy invocations
achirkin May 18, 2022
c8e7b4d
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 19, 2022
150a438
Refactoring with helper functions
achirkin May 19, 2022
ddfb8cc
Make the scratch buf 3x L2 cache size
achirkin May 19, 2022
b788e2e
Remove serialization code for now
achirkin May 19, 2022
3e1c14d
remove obsolete comment
achirkin May 19, 2022
a001999
Add a missing sync
achirkin May 19, 2022
2d08271
Rename ann_quantized_faiss
achirkin May 19, 2022
0f88aaa
wip from manual allocations to rmm: updated some parts with pointer r…
achirkin May 19, 2022
363dfc9
wip from manual allocations to rmm
achirkin May 19, 2022
e5399f8
fix style
achirkin May 19, 2022
306f5bf
Set minimum memory pool size in radix_topk to 256 bytes
achirkin May 20, 2022
fd7d2ba
wip malloc-to-rmm: removed most of the manual allocations
achirkin May 20, 2022
403667a
misc cleanup
achirkin May 20, 2022
4c6d563
Refactoring; used raft::handle in place of cublas handle everywhere
achirkin May 20, 2022
3ae52ea
Fix the value type at runtime (use templates instead of runtime dtype)
achirkin May 20, 2022
6fecd7f
ceildiv
achirkin May 20, 2022
174854f
Use rmm's memory pool in place of explicitly allocated buffers
achirkin May 20, 2022
b45b14c
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 20, 2022
ca1aaad
Use raft logging
achirkin May 24, 2022
4228a02
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 24, 2022
70d84ec
Updated logging and nvtx markers
achirkin May 24, 2022
f9c12f8
clang-format
achirkin May 24, 2022
17968e4
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 24, 2022
957ac94
Use the recommended logger header
achirkin May 24, 2022
ccfbccc
Use warpsort for smaller k
achirkin May 25, 2022
7819397
Using raft helpers
achirkin May 25, 2022
510c467
Determine the template parameters Capacity and Veclen recursively
achirkin May 25, 2022
c5087be
wip: refactoring and reducing duplicate calls
achirkin May 26, 2022
f850a4a
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 27, 2022
c5f1c89
Refactor and document ann_ivf_flat_kernel
achirkin May 27, 2022
7b2b9ff
Documenting and refactoring the kernel
achirkin May 27, 2022
913edfb
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin May 30, 2022
b1208ed
Add a case of high dimensionality
achirkin May 31, 2022
a30ade5
Add more sync into the test to detect device errors
achirkin May 31, 2022
84db732
Add more sync into the test to detect device errors
achirkin May 31, 2022
346afb2
Allow large batch sizes and document more functions
achirkin May 31, 2022
fc201b5
Add a lower bound on expected recall
achirkin May 31, 2022
4021ea2
Compute required memory dynamically
achirkin May 31, 2022
ea8b1c4
readability quickfix
achirkin May 31, 2022
d8a034a
Correct the smem size for the warpsort and add launch bounds
achirkin May 31, 2022
d97d248
Add a couple of checks against floating point exceptions
achirkin Jun 1, 2022
2e64037
Don't run kmeans on empty dataset
achirkin Jun 2, 2022
9ed50ac
Order all ops by a cuda stream
achirkin Jun 2, 2022
1f9352c
Update comments
achirkin Jun 2, 2022
c048af2
Suggest replacing _cuann_sqsum
achirkin Jun 2, 2022
96f39a8
wip: refactoring utils
achirkin Jun 2, 2022
888daeb
minor comments
achirkin Jun 2, 2022
e6ff267
ann_utils refactoring, docs, and clang-tidy
achirkin Jun 3, 2022
426f713
Merge branch 'branch-22.06' into fea-knn-ivf-flat
achirkin Jun 7, 2022
bacb402
Refactor tests and reduce their memory footprint
achirkin Jun 7, 2022
4042b28
Refactored and documented ann_kmeans_balanced
achirkin Jun 7, 2022
bb5726b
Use memory_resource for temp data in kmeans
achirkin Jun 7, 2022
810c26b
Address clang-tidy and other refactoring suggestions
achirkin Jun 8, 2022
042c410
Move part of the index building onto gpu
achirkin Jun 8, 2022
7ace0fb
Document the index building kernel
achirkin Jun 15, 2022
e9c0d49
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jun 15, 2022
3515715
Added a dims padding todo
achirkin Jun 15, 2022
6bd6560
Move kmeans-related allocations and routines to ann_kmeans_balanced.cuh
achirkin Jun 15, 2022
2811814
Add documentation to the build_optimized_kmeans
achirkin Jun 15, 2022
fc3e46e
Using mdarrays and structured index
achirkin Jun 16, 2022
fb8c4b1
Fixed a memory leak and introduced a few assertions to check pointer …
achirkin Jun 17, 2022
f3b2cb2
Merge branch 'branch-22.08' into fea-knn-ivf-flat
cjnolet Jun 17, 2022
092d428
Refactoring build_optimized_kmeans
achirkin Jun 17, 2022
fbcb16b
A few smaller refactorings for kmeans
achirkin Jun 17, 2022
29ca199
Add docs to public methods of the handle
achirkin Jun 20, 2022
38b3cec
Made the metric a part of the index struct and set the greater_ = …
achirkin Jun 21, 2022
d19bb5f
Do not persist grid_dim_x between searches
achirkin Jun 21, 2022
9094707
Refactor names according to clang-tidy
achirkin Jun 21, 2022
325e201
Refactor the usage of stream and params
achirkin Jun 21, 2022
2a3eb33
Refactor api to have symmetric index/search params
achirkin Jun 21, 2022
867beca
refactor away ivf_flat_index
achirkin Jun 22, 2022
059a6c0
Add the memory resource argument to warp_sort_topk
achirkin Jun 22, 2022
df17b5b
update docs
achirkin Jun 22, 2022
fe9ced1
Allow empty mesoclusters
achirkin Jun 23, 2022
91fdcbb
Add low-dimensional and non-veclen-aligned-dimensional test cases
achirkin Jun 23, 2022
be14c63
Refactor and document loadAndComputeDist
achirkin Jun 23, 2022
eeb4601
Minor renamings
achirkin Jun 23, 2022
025e5a5
Add 8bit int types to knn benchmarks
achirkin Jun 23, 2022
3821366
Fix incorrect data mapping for int8 types
achirkin Jun 24, 2022
d596842
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jun 24, 2022
a29baa7
Introduce kIndexGroupSize constant
achirkin Jun 27, 2022
546bef8
Cleanup ann_quantized
achirkin Jun 27, 2022
32d0d2e
Add several type aliases and helpers for creating mdarrays
achirkin Jun 27, 2022
5f427c0
Remove unnecessary inlines and fix docs
achirkin Jun 28, 2022
c581fe2
More refactoring and a few forceinlines
achirkin Jun 28, 2022
805e78c
Add a helper for creating pool_memory_resource when it makes sense
achirkin Jun 29, 2022
a4973e6
Force move the mdarrays when creating index to avoid copying them
achirkin Jun 29, 2022
68c267e
Minor refactorings
achirkin Jun 29, 2022
f2b8ed8
Add nvtx annotations to the outermost ANN calls for better performanc…
achirkin Jun 29, 2022
f91c7f7
Add a few more test cases and annotations for them
achirkin Jun 29, 2022
84b1c5b
Fix a typo
achirkin Jun 29, 2022
afc1f6a
Move ensure_integral_extents to the detail folder
achirkin Jun 30, 2022
3a10f86
Lift the requirement to have query pointers aligned with Veclen
achirkin Jun 30, 2022
9f5c64c
Merge branch 'branch-22.08' into enh-mdarray-helpers
achirkin Jun 30, 2022
1afd667
Use move semantics for the index everywhere, but try to keep it const…
achirkin Jun 30, 2022
73ce9e1
Update documentation
achirkin Jun 30, 2022
2a45645
Remove the debug path USE_FAISS
achirkin Jun 30, 2022
75a48b4
Add a type trait for checking if the conversion between two numeric t…
achirkin Jul 1, 2022
ed25cae
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jul 1, 2022
388200c
Support 32bit and unsigned indices in bruteforce KNN
achirkin Jul 1, 2022
f08df83
Merge branch 'enh-mdarray-helpers' into fea-knn-ivf-flat
achirkin Jul 1, 2022
9200886
Merge branch 'enh-knn-bruteforce-uint32' into fea-knn-ivf-flat
achirkin Jul 1, 2022
14bfe02
Make index type a template parameter
achirkin Jul 1, 2022
1283cbe
Revert the api changes as much as possible and deprecate the old api
achirkin Jul 1, 2022
e73b259
Remove the stream argument from the public API
achirkin Jul 4, 2022
8e7ffb8
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jul 5, 2022
5f5dc0d
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jul 5, 2022
03ebbe0
Simplify kmeans::predict a little bit
achirkin Jul 6, 2022
cde7f97
Factor out predict from the other ops in kmeans for use outside of th…
achirkin Jul 7, 2022
305bbcd
Add new function extend(index, new_vecs, new_inds) to ivf_flat
achirkin Jul 20, 2022
76c383f
Merge branch 'branch-22.08' into fea-knn-ivf-flat
achirkin Jul 21, 2022
7f640a9
Improve the docs
achirkin Jul 21, 2022
2e9eda5
Fix using non-existing log function
achirkin Jul 21, 2022
dc62a0f
Hide all data components from ivf_flat::index and expose immutable views
achirkin Jul 21, 2022
fb841c3
Replace thrust::exclusive_scan with thrust::inclusive_scan to avoid a…
achirkin Jul 22, 2022
04bb5dc
Merge branch 'fea-knn-ivf-flat' into enh-knn-ivf-flat-hide-impl
achirkin Jul 22, 2022
c95ea85
ann_common.h: remove deps on cuda code, so that the file can be inclu…
achirkin Jul 22, 2022
0c72ee8
ann_common.h: remove deps on cuda code, so that the file can be inclu…
achirkin Jul 22, 2022
0196695
Make helper overloads inline for linking in cuml
achirkin Jul 22, 2022
eb15639
Split processing.hpp into *.cuh and *.hpp to avoid incomplete types
achirkin Jul 22, 2022
e4b2b39
WIP: investigating segmentation fault in cuml test
achirkin Jul 25, 2022
6bc0fcb
Revert the wip-changes from the last commit
achirkin Jul 26, 2022
f599aaf
Merge remote-tracking branch 'origin/fea-knn-ivf-flat' into enh-knn-i…
achirkin Jul 26, 2022
a191410
Merge branch 'branch-22.08' into enh-knn-ivf-flat-hide-impl
achirkin Jul 28, 2022
317ddf3
Enhance documentation
achirkin Jul 28, 2022
114fb63
Fix a couple of typos in docs
achirkin Jul 28, 2022
1d283ae
Change the data indexing to size_t to make sure the total size (size*…
achirkin Jul 28, 2022
a9bd2d6
Merge branch 'branch-22.08' into enh-knn-ivf-flat-hide-impl
achirkin Aug 2, 2022
f9d55a7
Make ivf_flat::index look a little bit more like knn::sparse api
achirkin Aug 2, 2022
fef6dac
Test both overloads of
achirkin Aug 2, 2022
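
Taken together, the commit history converges on a self-contained ivf_flat API: the index owns its device data behind immutable views (commit dc62a0f), new vectors are appended via extend(index, new_vecs, new_inds) (commit 305bbcd), the index/search parameter structs are symmetric (commit 2a3eb33), and the raw stream argument is gone from the public API (commit e73b259). The following is a minimal sketch of that shape pieced together from the commit messages; header paths and exact signatures are assumptions, not the verbatim PR code.

#include <raft/handle.hpp>
#include <raft/spatial/knn/ivf_flat.cuh>  // assumed location of the ivf_flat API

void ivf_flat_example(const raft::handle_t& handle,
                      const float* dataset, uint64_t n_rows, uint32_t dim,
                      const float* queries, uint32_t n_queries, uint32_t k,
                      uint64_t* out_idxs, float* out_dists)
{
  namespace ivf = raft::spatial::knn::ivf_flat;

  ivf::index_params ip;
  ip.n_lists = 4096;
  ip.metric  = raft::distance::DistanceType::L2Expanded;

  // The index owns its data; callers only see immutable views (dc62a0f).
  auto idx = ivf::build(handle, ip, dataset, n_rows, dim);

  // Appending vectors later goes through extend (305bbcd), e.g.:
  //   idx = ivf::extend(handle, idx, new_vecs, new_inds, n_new);

  ivf::search_params sp;
  sp.n_probes = 20;  // inverted lists to scan per query
  ivf::search(handle, sp, idx, queries, n_queries, k, out_idxs, out_dists);
}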
2 changes: 1 addition & 1 deletion cpp/bench/common/benchmark.hpp
@@ -40,7 +40,7 @@ struct using_pool_memory_res {
private:
rmm::mr::device_memory_resource* orig_res_;
rmm::mr::cuda_memory_resource cuda_res_;
rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_res_;
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_res_;

public:
using_pool_memory_res(size_t initial_size, size_t max_size)
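The one-line change above widens the pool's upstream from the concrete cuda_memory_resource to the device_memory_resource base class, so the same member can sit on top of any upstream — for instance the managed-memory resource that the benchmark's MANAGED transfer strategy uses below. A minimal sketch of what that buys, assuming RMM's usual constructor taking an upstream pointer plus initial/maximum pool sizes:

#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

void run_with_pool(bool use_managed, std::size_t initial_size, std::size_t max_size)
{
  static rmm::mr::cuda_memory_resource cuda_res;
  static rmm::mr::managed_memory_resource managed_res;

  // Type-erased upstream: with pool_memory_resource<cuda_memory_resource>
  // the managed branch would not compile.
  rmm::mr::device_memory_resource* upstream =
    use_managed ? static_cast<rmm::mr::device_memory_resource*>(&managed_res)
                : &cuda_res;

  rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool(
    upstream, initial_size, max_size);
  // ... install as the current device resource, run the benchmark, restore ...
}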
82 changes: 18 additions & 64 deletions cpp/bench/spatial/knn.cu
@@ -14,25 +14,22 @@
* limitations under the License.
*/

#include <optional>

#include <common/benchmark.hpp>
#include <raft/spatial/knn/ann.cuh>
#include <raft/spatial/knn/knn.cuh>

#include <raft/random/rng.cuh>
#include <raft/spatial/knn/knn.cuh>
#if defined RAFT_NN_COMPILED
#include <raft/spatial/knn/specializations.hpp>
#include <raft/spatial/knn/specializations.cuh>
#endif

#include <raft/random/rng.cuh>
#include <raft/sparse/detail/utils.h>

#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

#include <rmm/mr/host/new_delete_resource.hpp>
#include <rmm/mr/host/pinned_memory_resource.hpp>

#include <optional>

namespace raft::bench::spatial {

struct params {
@@ -41,14 +38,14 @@ struct params {
/** Number of dimensions in the dataset. */
size_t n_dims;
/** The batch size -- number of KNN searches. */
size_t n_probes;
size_t n_queries;
/** Number of nearest neighbours to find for every probe. */
size_t k;
};

auto operator<<(std::ostream& os, const params& p) -> std::ostream&
{
os << p.n_samples << "#" << p.n_dims << "#" << p.n_probes << "#" << p.k;
os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k;
return os;
}
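
The rename above fixes a misleading field: this benchmark parameter is the number of KNN searches in one batch, while n_probes already means something else for IVF indexes — the number of inverted lists scanned per query, as in the search_params used elsewhere in this file. A short disambiguation (the ivf_flat header path is an assumption):

#include <raft/spatial/knn/ivf_flat.cuh>  // assumed header for the ivf_flat types

void disambiguate()
{
  std::size_t n_queries = 1000;  // params::n_queries: searches run per batch

  raft::spatial::knn::ivf_flat::search_params sp;
  sp.n_probes = 20;              // IVF parameter: inverted lists scanned per query
  (void)n_queries; (void)sp;
}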

@@ -129,44 +126,6 @@ struct host_uvector {
T* arr_;
};

template <typename ValT, typename IdxT>
struct ivf_flat_knn {
using dist_t = float;

raft::spatial::knn::knnIndex index;
raft::spatial::knn::ivf_flat::index_params index_params;
raft::spatial::knn::ivf_flat::search_params search_params;
params ps;

ivf_flat_knn(const raft::handle_t& handle, const params& ps, const ValT* data) : ps(ps)
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;
raft::spatial::knn::approx_knn_build_index<ValT, IdxT>(const_cast<raft::handle_t&>(handle),
&index,
index_params,
const_cast<ValT*>(data),
(IdxT)ps.n_samples,
(IdxT)ps.n_dims);
}

void search(const raft::handle_t& handle,
const ValT* search_items,
dist_t* out_dists,
IdxT* out_idxs)
{
search_params.n_probes = 20;
raft::spatial::knn::approx_knn_search<ValT, IdxT>(const_cast<raft::handle_t&>(handle),
out_dists,
out_idxs,
&index,
search_params,
(IdxT)ps.k,
const_cast<ValT*>(search_items),
(IdxT)ps.n_probes);
}
};

template <typename ValT, typename IdxT>
struct brute_force_knn {
using dist_t = ValT;
@@ -191,7 +150,7 @@ struct brute_force_knn {
sizes,
ps.n_dims,
const_cast<ValT*>(search_items),
ps.n_probes,
ps.n_queries,
out_idxs,
out_dists,
ps.k);
@@ -206,9 +165,9 @@ struct knn : public fixture {
scope_(scope),
dev_mem_res_(strategy == TransferStrategy::MANAGED),
data_host_(0),
search_items_(p.n_probes * p.n_dims, stream),
out_dists_(p.n_probes * p.k, stream),
out_idxs_(p.n_probes * p.k, stream)
search_items_(p.n_queries * p.n_dims, stream),
out_dists_(p.n_queries * p.k, stream),
out_idxs_(p.n_queries * p.k, stream)
{
raft::random::RngState state{42};
gen_data(state, search_items_, search_items_.size(), stream);
@@ -233,9 +192,8 @@ struct knn : public fixture {
size_t n,
rmm::cuda_stream_view stream)
{
constexpr T kRangeMax = T(std::min<double>(
raft::spatial::knn::detail::utils::config<T>::kDivisor, std::numeric_limits<T>::max()));
constexpr T kRangeMin = std::is_signed_v<T> ? -kRangeMax : T(0);
constexpr T kRangeMax = std::is_integral_v<T> ? std::numeric_limits<T>::max() : T(1);
constexpr T kRangeMin = std::is_integral_v<T> ? std::numeric_limits<T>::min() : T(-1);
if constexpr (std::is_integral_v<T>) {
raft::random::uniformInt(state, vec.data(), n, kRangeMin, kRangeMax, stream);
} else {
@@ -352,15 +310,12 @@ struct knn : public fixture {
const std::vector<params> kInputs{
{2000000, 128, 1000, 32}, {10000000, 128, 1000, 32}, {10000, 8192, 1000, 32}};

const std::vector<TransferStrategy> kAllStrategies{TransferStrategy::NO_COPY,
TransferStrategy::COPY_PLAIN,
TransferStrategy::COPY_PINNED,
TransferStrategy::MAP_PINNED,
TransferStrategy::MANAGED};
const std::vector<TransferStrategy> kAllStrategies{
TransferStrategy::NO_COPY, TransferStrategy::MAP_PINNED, TransferStrategy::MANAGED};
const std::vector<TransferStrategy> kNoCopyOnly{TransferStrategy::NO_COPY};

const std::vector<Scope> kScopeFull{Scope::BUILD_SEARCH};
const std::vector<Scope> kAllScopes{Scope::BUILD, Scope::SEARCH, Scope::BUILD_SEARCH};
const std::vector<Scope> kAllScopes{Scope::BUILD_SEARCH, Scope::SEARCH, Scope::BUILD};

#define KNN_REGISTER(ValT, IdxT, ImplT, inputs, strats, scope) \
namespace BENCHMARK_PRIVATE_NAME(knn) \
@@ -370,8 +325,7 @@ const std::vector<Scope> kAllScopes{Scope::BUILD, Scope::SEARCH, Scope::BUILD_SE
}

KNN_REGISTER(float, int64_t, brute_force_knn, kInputs, kAllStrategies, kScopeFull);
KNN_REGISTER(float, int64_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes);
KNN_REGISTER(int8_t, int64_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes);
KNN_REGISTER(uint8_t, int64_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes);

KNN_REGISTER(float, uint32_t, brute_force_knn, kInputs, kNoCopyOnly, kScopeFull);

} // namespace raft::bench::spatial
6 changes: 5 additions & 1 deletion cpp/include/raft/spatial/knn/detail/haversine_distance.cuh
@@ -72,7 +72,11 @@ __global__ void haversine_knn_kernel(value_idx* out_inds,

faiss::gpu::
BlockSelect<value_t, value_idx, false, faiss::gpu::Comparator<value_t>, warp_q, thread_q, tpb>
heap(faiss::gpu::Limits<value_t>::getMax(), -1, smemK, smemV, k);
heap(faiss::gpu::Limits<value_t>::getMax(),
std::numeric_limits<value_idx>::max(),
smemK,
smemV,
k);

// Grid is exactly sized to rows available
int limit = faiss::gpu::utils::roundDown(n_index_rows, faiss::gpu::kWarpSize);
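Above, the heap's placeholder index changes from a literal -1 to std::numeric_limits<value_idx>::max(). With value_idx now allowed to be unsigned (see the uint32_t instantiations later in this PR), the literal would rely on implicit wrap-around and trips narrowing checks, whereas numeric_limits spells the same all-ones pattern portably. A small self-contained illustration of the distinction:

#include <cstdint>
#include <limits>

template <typename value_idx>
constexpr value_idx placeholder_index()
{
  // Works uniformly for int64_t, uint32_t, ...; a braced `value_idx{-1}`
  // would be ill-formed for unsigned types due to narrowing.
  return std::numeric_limits<value_idx>::max();
}

static_assert(placeholder_index<std::uint32_t>() == 0xFFFFFFFFu, "all-ones sentinel");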
6 changes: 4 additions & 2 deletions cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh
@@ -268,11 +268,11 @@ void brute_force_knn_impl(
int device;
RAFT_CUDA_TRY(cudaGetDevice(&device));

rmm::device_uvector<std::int64_t> trans(id_ranges->size(), userStream);
rmm::device_uvector<IdxType> trans(id_ranges->size(), userStream);
raft::update_device(trans.data(), id_ranges->data(), id_ranges->size(), userStream);

rmm::device_uvector<value_t> all_D(0, userStream);
rmm::device_uvector<std::int64_t> all_I(0, userStream);
rmm::device_uvector<IdxType> all_I(0, userStream);

value_t* out_D = res_D;
IdxType* out_I = res_I;
@@ -342,6 +342,8 @@ void brute_force_knn_impl(
args.numQueries = n;
args.outDistances = out_d_ptr;
args.outIndices = out_i_ptr;
args.outIndicesType = sizeof(IdxType) == 4 ? faiss::gpu::IndicesDataType::I32
: faiss::gpu::IndicesDataType::I64;

/**
* @todo: Until FAISS supports pluggable allocation strategies,
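The new outIndicesType assignment is what makes 32-bit result indices possible here: FAISS is told the width of the output index buffer instead of always assuming 64 bits. A defensive variant of the same dispatch — the static_assert and the helper wrapper are this note's additions, not part of the PR:

template <typename IdxType>
void set_out_indices_type(faiss::gpu::GpuDistanceParams& args)
{
  // Widths other than 4 or 8 bytes have no FAISS counterpart, so fail early.
  static_assert(sizeof(IdxType) == 4 || sizeof(IdxType) == 8,
                "FAISS GPU brute force supports only 32- and 64-bit indices");
  args.outIndicesType = sizeof(IdxType) == 4 ? faiss::gpu::IndicesDataType::I32
                                             : faiss::gpu::IndicesDataType::I64;
}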
32 changes: 32 additions & 0 deletions cpp/include/raft/spatial/knn/specializations/knn.cuh
@@ -50,6 +50,38 @@ extern template void brute_force_knn<long, float, unsigned int>(raft::handle_t c
std::vector<long>* translations,
distance::DistanceType metric,
float metric_arg);

extern template void brute_force_knn<uint32_t, float, int>(raft::handle_t const& handle,
std::vector<float*>& input,
std::vector<int>& sizes,
int D,
float* search_items,
int n,
uint32_t* res_I,
float* res_D,
int k,
bool rowMajorIndex,
bool rowMajorQuery,
std::vector<uint32_t>* translations,
distance::DistanceType metric,
float metric_arg);

extern template void brute_force_knn<uint32_t, float, unsigned int>(
raft::handle_t const& handle,
std::vector<float*>& input,
std::vector<unsigned int>& sizes,
unsigned int D,
float* search_items,
unsigned int n,
uint32_t* res_I,
float* res_D,
unsigned int k,
bool rowMajorIndex,
bool rowMajorQuery,
std::vector<uint32_t>* translations,
distance::DistanceType metric,
float metric_arg);

}; // namespace knn
}; // namespace spatial
}; // namespace raft
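
These extern template declarations pair with the explicit instantiations added in cpp/src/nn/specializations/knn.cu below: every translation unit that includes this header promises not to instantiate brute_force_knn for these types itself, and the linker resolves calls against the single prebuilt copy. The pattern in miniature (illustrative names, not RAFT's):

// widget.hpp — seen by every consumer
template <typename T>
void frob(T value);                       // declaration only

extern template void frob<float>(float);  // suppress per-TU instantiation

// widget.cpp — compiled once
template <typename T>
void frob(T value) { /* heavy definition */ }

template void frob<float>(float);         // the one explicit instantiation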
30 changes: 30 additions & 0 deletions cpp/src/nn/specializations/knn.cu
@@ -51,6 +51,36 @@ template void brute_force_knn<long, float, unsigned int>(raft::handle_t const& h
distance::DistanceType metric,
float metric_arg);

template void brute_force_knn<uint32_t, float, int>(raft::handle_t const& handle,
std::vector<float*>& input,
std::vector<int>& sizes,
int D,
float* search_items,
int n,
uint32_t* res_I,
float* res_D,
int k,
bool rowMajorIndex,
bool rowMajorQuery,
std::vector<uint32_t>* translations,
distance::DistanceType metric,
float metric_arg);

template void brute_force_knn<uint32_t, float, unsigned int>(raft::handle_t const& handle,
std::vector<float*>& input,
std::vector<unsigned int>& sizes,
unsigned int D,
float* search_items,
unsigned int n,
uint32_t* res_I,
float* res_D,
unsigned int k,
bool rowMajorIndex,
bool rowMajorQuery,
std::vector<uint32_t>* translations,
distance::DistanceType metric,
float metric_arg);

}; // namespace knn
}; // namespace spatial
}; // namespace raft
24 changes: 15 additions & 9 deletions cpp/test/spatial/knn.cu
@@ -16,8 +16,8 @@

#include "../test_utils.h"

#include <raft/core/logger.hpp>
#include <raft/distance/distance_type.hpp>

#include <raft/spatial/knn/knn.cuh>
#if defined RAFT_NN_COMPILED
#include <raft/spatial/knn/specializations.cuh>
@@ -40,8 +40,9 @@ struct KNNInputs {
std::vector<int> labels;
};

template <typename IdxT>
__global__ void build_actual_output(
int* output, int n_rows, int k, const int* idx_labels, const int64_t* indices)
int* output, int n_rows, int k, const int* idx_labels, const IdxT* indices)
{
int element = threadIdx.x + blockDim.x * blockIdx.x;
if (element >= n_rows * k) return;
@@ -60,7 +61,7 @@ __global__ void build_expected_output(int* output, int n_rows, int k, const int*
}
}

template <typename T>
template <typename T, typename IdxT>
class KNNTest : public ::testing::TestWithParam<KNNInputs> {
public:
KNNTest()
@@ -79,9 +80,11 @@ class KNNTest : public ::testing::TestWithParam<KNNInputs> {
protected:
void testBruteForce()
{
#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG)
raft::print_device_vector("Input array: ", input_.data(), rows_ * cols_, std::cout);
std::cout << "K: " << k_ << "\n";
std::cout << "K: " << k_ << std::endl;
raft::print_device_vector("Labels array: ", search_labels_.data(), rows_, std::cout);
#endif

std::vector<float*> input_vec;
std::vector<int> sizes_vec;
@@ -131,7 +134,7 @@ class KNNTest : public ::testing::TestWithParam<KNNInputs> {
RAFT_CUDA_TRY(cudaMemsetAsync(input_.data(), 0, input_.size() * sizeof(float), stream));
RAFT_CUDA_TRY(
cudaMemsetAsync(search_data_.data(), 0, search_data_.size() * sizeof(float), stream));
RAFT_CUDA_TRY(cudaMemsetAsync(indices_.data(), 0, indices_.size() * sizeof(int64_t), stream));
RAFT_CUDA_TRY(cudaMemsetAsync(indices_.data(), 0, indices_.size() * sizeof(IdxT), stream));
RAFT_CUDA_TRY(cudaMemsetAsync(distances_.data(), 0, distances_.size() * sizeof(float), stream));
RAFT_CUDA_TRY(
cudaMemsetAsync(search_labels_.data(), 0, search_labels_.size() * sizeof(int), stream));
@@ -165,7 +168,7 @@ class KNNTest : public ::testing::TestWithParam<KNNInputs> {
int cols_;
rmm::device_uvector<float> input_;
rmm::device_uvector<float> search_data_;
rmm::device_uvector<int64_t> indices_;
rmm::device_uvector<IdxT> indices_;
rmm::device_uvector<float> distances_;
int k_;

@@ -191,10 +194,13 @@ const std::vector<KNNInputs> inputs = {
2,
{0, 0, 0, 0, 0, 1, 1, 1, 1, 1}}};

typedef KNNTest<float> KNNTestF;
TEST_P(KNNTestF, BruteForce) { this->testBruteForce(); }
typedef KNNTest<float, int64_t> KNNTestFint64_t;
TEST_P(KNNTestFint64_t, BruteForce) { this->testBruteForce(); }
typedef KNNTest<float, uint32_t> KNNTestFuint32_t;
TEST_P(KNNTestFuint32_t, BruteForce) { this->testBruteForce(); }

INSTANTIATE_TEST_CASE_P(KNNTest, KNNTestF, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(KNNTest, KNNTestFint64_t, ::testing::ValuesIn(inputs));
INSTANTIATE_TEST_CASE_P(KNNTest, KNNTestFuint32_t, ::testing::ValuesIn(inputs));
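
With the fixture now parameterized on IdxT, covering another index type is a three-line affair. A hypothetical extension — int32_t is not part of this PR, and it would additionally require a matching brute_force_knn instantiation:

// Hypothetical: exercise 32-bit signed indices the same way.
typedef KNNTest<float, int32_t> KNNTestFint32_t;
TEST_P(KNNTestFint32_t, BruteForce) { this->testBruteForce(); }
INSTANTIATE_TEST_CASE_P(KNNTest, KNNTestFint32_t, ::testing::ValuesIn(inputs));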

} // namespace knn
} // namespace spatial