Skip to content

Commit

Permalink
Move ANN to RAFT (additional updates) (#270)
Browse files Browse the repository at this point in the history
Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Brad Rees (https://github.com/BradReesWork)

URL: #270
  • Loading branch information
cjnolet authored Jun 10, 2021
1 parent 6c02b59 commit 73417b2
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 34 deletions.
3 changes: 3 additions & 0 deletions cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

#include "../ann_common.h"

#include "common_faiss.h"
#include "processing.hpp"

#include <raft/cudart_utils.h>
#include <raft/cuda_utils.cuh>

Expand Down
67 changes: 67 additions & 0 deletions cpp/include/raft/spatial/knn/detail/common_faiss.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/cudart_utils.h>
#include <raft/cuda_utils.cuh>

#include <faiss/gpu/GpuDistance.h>
#include <raft/linalg/distance_type.h>

namespace raft {
namespace spatial {
namespace knn {
namespace detail {

inline faiss::MetricType build_faiss_metric(
raft::distance::DistanceType metric) {
switch (metric) {
case raft::distance::DistanceType::CosineExpanded:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case raft::distance::DistanceType::CorrelationExpanded:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case raft::distance::DistanceType::L2Expanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L2Unexpanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L2SqrtExpanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L2SqrtUnexpanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L1:
return faiss::MetricType::METRIC_L1;
case raft::distance::DistanceType::InnerProduct:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case raft::distance::DistanceType::LpUnexpanded:
return faiss::MetricType::METRIC_Lp;
case raft::distance::DistanceType::Linf:
return faiss::MetricType::METRIC_Linf;
case raft::distance::DistanceType::Canberra:
return faiss::MetricType::METRIC_Canberra;
case raft::distance::DistanceType::BrayCurtis:
return faiss::MetricType::METRIC_BrayCurtis;
case raft::distance::DistanceType::JensenShannon:
return faiss::MetricType::METRIC_JensenShannon;
default:
THROW("MetricType not supported: %d", metric);
}
}

} // namespace detail
} // namespace knn
} // namespace spatial
} // namespace raft
36 changes: 2 additions & 34 deletions cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
#include "haversine_distance.cuh"
#include "processing.hpp"

#include "common_faiss.h"

namespace raft {
namespace spatial {
namespace knn {
Expand Down Expand Up @@ -167,40 +169,6 @@ inline void knn_merge_parts(value_t *inK, value_idx *inV, value_t *outK,
inK, inV, outK, outV, n_samples, n_parts, k, stream, translations);
}

inline faiss::MetricType build_faiss_metric(
raft::distance::DistanceType metric) {
switch (metric) {
case raft::distance::DistanceType::CosineExpanded:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case raft::distance::DistanceType::CorrelationExpanded:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case raft::distance::DistanceType::L2Expanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L2Unexpanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L2SqrtExpanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L2SqrtUnexpanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L1:
return faiss::MetricType::METRIC_L1;
case raft::distance::DistanceType::InnerProduct:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case raft::distance::DistanceType::LpUnexpanded:
return faiss::MetricType::METRIC_Lp;
case raft::distance::DistanceType::Linf:
return faiss::MetricType::METRIC_Linf;
case raft::distance::DistanceType::Canberra:
return faiss::MetricType::METRIC_Canberra;
case raft::distance::DistanceType::BrayCurtis:
return faiss::MetricType::METRIC_BrayCurtis;
case raft::distance::DistanceType::JensenShannon:
return faiss::MetricType::METRIC_JensenShannon;
default:
THROW("MetricType not supported: %d", metric);
}
}

/**
* Search the kNN for the k-nearest neighbors of a set of query vectors
* @param[in] input vector of device device memory array pointers to search
Expand Down

0 comments on commit 73417b2

Please sign in to comment.