Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Benchmark brute force knn #2063

Merged
merged 12 commits into from
Dec 20, 2023
Next Next commit
Benchmark brute force knn
Add our bfknn code to the raft-ann-bench project
  • Loading branch information
benfred committed Dec 13, 2023
commit 49970f4e81fe5703329245080dbda76da4da4887
5 changes: 2 additions & 3 deletions cpp/bench/ann/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
option(RAFT_ANN_BENCH_USE_FAISS_GPU_FLAT "Include faiss' brute-force knn algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_FLAT "Include faiss' ivf flat algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_PQ "Include faiss' ivf pq algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_FAISS_CPU_FLAT
"Include faiss' cpu brute-force knn algorithm in benchmark" ON
)
option(RAFT_ANN_BENCH_USE_FAISS_CPU_FLAT "Include faiss' cpu brute-force algorithm in benchmark" ON)

option(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_FLAT "Include faiss' cpu ivf flat algorithm in benchmark"
Expand All @@ -30,6 +27,7 @@ option(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_PQ "Include faiss' cpu ivf pq algorithm
option(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT "Include raft's ivf flat algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ "Include raft's ivf pq algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_CAGRA "Include raft's CAGRA in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE "Include raft's brute force knn in benchmark" ON)
option(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB "Include raft's CAGRA in benchmark" ON)
option(RAFT_ANN_BENCH_USE_HNSWLIB "Include hnsw algorithm in benchmark" ON)
option(RAFT_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" ON)
Expand All @@ -55,6 +53,7 @@ if(BUILD_CPU_ONLY)
set(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OFF)
set(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OFF)
set(RAFT_ANN_BENCH_USE_RAFT_CAGRA OFF)
set(RAFT_ANN_BENCH_USE_RAFT_CAGRA OFF)
set(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB OFF)
set(RAFT_ANN_BENCH_USE_GGNN OFF)
else()
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include <nlohmann/json.hpp>

#undef WARP_SIZE
#ifdef RAFT_ANN_BENCH_USE_RAFT_BFKNN
#ifdef RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE
#include "raft_wrapper.h"
#endif
#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT
Expand Down
9 changes: 6 additions & 3 deletions cpp/bench/ann/src/raft/raft_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <cmath>
#include <memory>
#include <raft/core/logger.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <stdexcept>
#include <string>
Expand All @@ -47,8 +48,10 @@ std::unique_ptr<raft::bench::ann::ANN<T>> create_algo(const std::string& algo,
std::unique_ptr<raft::bench::ann::ANN<T>> ann;

if constexpr (std::is_same_v<T, float>) {
#ifdef RAFT_ANN_BENCH_USE_RAFT_BFKNN
if (algo == "raft_bfknn") { ann = std::make_unique<raft::bench::ann::RaftGpu<T>>(metric, dim); }
#ifdef RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE
if (algo == "raft_brute_force") {
ann = std::make_unique<raft::bench::ann::RaftGpu<T>>(metric, dim);
}
#endif
}

Expand Down Expand Up @@ -85,7 +88,7 @@ template <typename T>
std::unique_ptr<typename raft::bench::ann::ANN<T>::AnnSearchParam> create_search_param(
const std::string& algo, const nlohmann::json& conf)
{
#ifdef RAFT_ANN_BENCH_USE_RAFT_BFKNN
#ifdef RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE
if (algo == "raft_brute_force") {
auto param = std::make_unique<typename raft::bench::ann::ANN<T>::AnnSearchParam>();
return param;
Expand Down
71 changes: 33 additions & 38 deletions cpp/bench/ann/src/raft/raft_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

#include <cassert>
#include <memory>
#include <raft/core/device_resources.hpp>
#include <raft/distance/detail/distance.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/detail/fused_l2_knn.cuh>
#include <raft/neighbors/brute_force.cuh>
#include <raft/neighbors/brute_force_serialize.cuh>
#include <stdexcept>
#include <string>
#include <type_traits>
Expand All @@ -30,20 +32,16 @@ namespace raft_temp {

inline raft::distance::DistanceType parse_metric_type(raft::bench::ann::Metric metric)
{
if (metric == raft::bench::ann::Metric::kInnerProduct) {
return raft::distance::DistanceType::InnerProduct;
} else if (metric == raft::bench::ann::Metric::kEuclidean) {
return raft::distance::DistanceType::L2Expanded;
} else {
throw std::runtime_error("raft supports only metric type of inner product and L2");
switch (metric) {
case raft::bench::ann::Metric::kInnerProduct: return raft::distance::DistanceType::InnerProduct;
case raft::bench::ann::Metric::kEuclidean: return raft::distance::DistanceType::L2Expanded;
}
}

} // namespace raft_temp

namespace raft::bench::ann {

// brute force fused L2 KNN - RAFT
// brute force KNN - RAFT
template <typename T>
class RaftGpu : public ANN<T> {
public:
Expand Down Expand Up @@ -74,9 +72,11 @@ class RaftGpu : public ANN<T> {
}
void set_search_dataset(const T* dataset, size_t nrow) override;
void save(const std::string& file) const override;
void load(const std::string&) override { return; };
void load(const std::string&) override;

protected:
raft::device_resources handle_;
std::optional<raft::neighbors::brute_force::index<T>> index_;
raft::distance::DistanceType metric_type_;
int device_;
const T* dataset_;
Expand All @@ -85,17 +85,18 @@ class RaftGpu : public ANN<T> {

template <typename T>
RaftGpu<T>::RaftGpu(Metric metric, int dim)
: ANN<T>(metric, dim), metric_type_(raft_temp::parse_metric_type(metric))
: ANN<T>(metric, dim),
metric_type_(raft_temp::parse_metric_type(metric)),
handle_(cudaStreamPerThread)
{
static_assert(std::is_same_v<T, float>, "raft support only float type");
assert(metric_type_ == raft::distance::DistanceType::L2Expanded);
RAFT_CUDA_TRY(cudaGetDevice(&device_));
}

template <typename T>
void RaftGpu<T>::build(const T*, size_t, cudaStream_t)
void RaftGpu<T>::build(const T* dataset, size_t nrow, cudaStream_t)
{
// as this is brute force algo so no index building required
auto dataset_view = raft::make_host_matrix_view<const T, int64_t>(dataset, nrow, this->dim_);
index_.emplace(raft::neighbors::brute_force::build(handle_, dataset_view));
return;
}

Expand All @@ -115,15 +116,13 @@ void RaftGpu<T>::set_search_dataset(const T* dataset, size_t nrow)
template <typename T>
void RaftGpu<T>::save(const std::string& file) const
{
// create a empty index file as no index to store.
std::fstream fp;
fp.open(file.c_str(), std::ios::out);
if (!fp) {
printf("Error in creating file!!!\n");
;
return;
}
fp.close();
raft::neighbors::brute_force::serialize<T>(handle_, file, *index_);
}

template <typename T>
void RaftGpu<T>::load(const std::string& file)
{
index_ = raft::neighbors::brute_force::deserialize<T>(handle_, file);
}

template <typename T>
Expand All @@ -134,20 +133,16 @@ void RaftGpu<T>::search(const T* queries,
float* distances,
cudaStream_t stream) const
{
// TODO: Integrate new `raft::brute_force::index` (from
// https://github.com/rapidsai/raft/pull/1817)
raft::spatial::knn::detail::fusedL2Knn(this->dim_,
reinterpret_cast<int64_t*>(neighbors),
distances,
dataset_,
queries,
nrow_,
static_cast<size_t>(batch_size),
k,
true,
true,
stream,
metric_type_);
auto queries_view =
raft::make_device_matrix_view<const T, int64_t>(queries, batch_size, this->dim_);

auto neighbors_view = raft::make_device_matrix_view<size_t, int64_t>(neighbors, batch_size, k);
auto distances_view = raft::make_device_matrix_view<float, int64_t>(distances, batch_size, k);

raft::neighbors::brute_force::search<T, size_t>(
handle_, *index_, queries_view, neighbors_view, distances_view);

handle_.sync_stream();
}

} // namespace raft::bench::ann
Loading