Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable distance return for NN Descent #2345

Merged
merged 31 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
242d725
enable nn descent dist return
jinsolp May 29, 2024
0977947
change bool to int
jinsolp May 29, 2024
ea30245
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp May 29, 2024
a7749ac
add test for distance
jinsolp May 30, 2024
a753e45
test for indices and distances with one func
jinsolp May 30, 2024
1d96e17
fix styling
jinsolp May 30, 2024
66a7678
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp May 30, 2024
11379f7
change return_distances to bool
jinsolp May 31, 2024
e993d7c
change distances to optional
jinsolp May 31, 2024
4578dc0
fix styling:
jinsolp May 31, 2024
c9048a5
handle bad access error
jinsolp Jun 5, 2024
bdb74a5
remove unnecessary dist allocation
jinsolp Jun 5, 2024
c71e97a
change to device matrix
jinsolp Jun 5, 2024
0327ce5
change template param for index
jinsolp Jun 5, 2024
3f49752
update test
jinsolp Jun 6, 2024
50e2cf8
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp Jun 6, 2024
a6a8ad2
remove comment
jinsolp Jun 6, 2024
2cfda3c
fix styling
jinsolp Jun 6, 2024
2efae1a
fix header
jinsolp Jun 6, 2024
931158c
add documentation for return_distances
jinsolp Jun 6, 2024
9f17b5c
add documentation
jinsolp Jun 7, 2024
e33019c
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp Jun 7, 2024
e28c2f9
add tparam doc
jinsolp Jun 10, 2024
2d74f89
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp Jun 11, 2024
fd65442
remove redundancy
jinsolp Jun 12, 2024
3ce8cdf
return optional for distances()
jinsolp Jun 13, 2024
4ec3ab0
fix styling
jinsolp Jun 13, 2024
b0347a7
remove raft_expects
jinsolp Jun 13, 2024
84abd96
fix type template and bring back raft_expects
jinsolp Jun 13, 2024
75b82df
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp Jun 13, 2024
ec44b4d
fix raft_expects
jinsolp Jun 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
change template param for index
  • Loading branch information
jinsolp committed Jun 5, 2024
commit 0327ce591a246499676427d03abbc9b42528b5a1
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ void build_knn_graph(raft::resources const& res,
raft::host_matrix_view<IdxT, int64_t, row_major> knn_graph,
experimental::nn_descent::index_params build_params)
{
auto nn_descent_idx = experimental::nn_descent::index<IdxT>(res, knn_graph);
auto nn_descent_idx = experimental::nn_descent::index<float, IdxT>(res, knn_graph);
experimental::nn_descent::build<DataT, IdxT>(res, build_params, dataset, nn_descent_idx);

using internal_IdxT = typename std::make_unsigned<IdxT>::type;
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/raft/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1391,7 +1391,7 @@ template <typename T,
void build(raft::resources const& res,
const index_params& params,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset,
index<IdxT>& idx)
index<DistData_t, IdxT>& idx)
{
RAFT_EXPECTS(dataset.extent(0) < std::numeric_limits<int>::max() - 1,
"The dataset size for GNND should be less than %d",
Expand Down Expand Up @@ -1453,7 +1453,7 @@ template <typename T,
typename IdxT = uint32_t,
typename Accessor =
host_device_accessor<std::experimental::default_accessor<T>, memory_type::host>>
index<IdxT> build(raft::resources const& res,
index<DistData_t, IdxT> build(raft::resources const& res,
const index_params& params,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset)
{
Expand All @@ -1469,7 +1469,7 @@ index<IdxT> build(raft::resources const& res,
graph_degree = intermediate_degree;
}

index<IdxT> idx{
index<DistData_t, IdxT> idx{
res, dataset.extent(0), static_cast<int64_t>(graph_degree), params.return_distances};

build(res, params, dataset, idx);
Expand Down
8 changes: 4 additions & 4 deletions cpp/include/raft/neighbors/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ namespace raft::neighbors::experimental::nn_descent {
* @return index<IdxT> index containing all-neighbors knn graph in host memory
*/
template <typename T, typename IdxT = uint32_t>
index<IdxT> build(raft::resources const& res,
index<detail::DistData_t, IdxT> build(raft::resources const& res,
jinsolp marked this conversation as resolved.
Show resolved Hide resolved
index_params const& params,
raft::device_matrix_view<const T, int64_t, row_major> dataset)
{
Expand Down Expand Up @@ -97,7 +97,7 @@ template <typename T, typename IdxT = uint32_t>
void build(raft::resources const& res,
index_params const& params,
raft::device_matrix_view<const T, int64_t, row_major> dataset,
index<IdxT>& idx)
index<detail::DistData_t, IdxT>& idx)
jinsolp marked this conversation as resolved.
Show resolved Hide resolved
{
detail::build<T, IdxT>(res, params, dataset, idx);
}
Expand Down Expand Up @@ -130,7 +130,7 @@ void build(raft::resources const& res,
* @return index<IdxT> index containing all-neighbors knn graph in host memory
*/
template <typename T, typename IdxT = uint32_t>
index<IdxT> build(raft::resources const& res,
index<detail::DistData_t, IdxT> build(raft::resources const& res,
jinsolp marked this conversation as resolved.
Show resolved Hide resolved
index_params const& params,
raft::host_matrix_view<const T, int64_t, row_major> dataset)
{
Expand Down Expand Up @@ -171,7 +171,7 @@ template <typename T, typename IdxT = uint32_t>
void build(raft::resources const& res,
index_params const& params,
raft::host_matrix_view<const T, int64_t, row_major> dataset,
index<IdxT>& idx)
index<detail::DistData_t, IdxT>& idx)
jinsolp marked this conversation as resolved.
Show resolved Hide resolved
{
detail::build<T, IdxT>(res, params, dataset, idx);
}
Expand Down
18 changes: 9 additions & 9 deletions cpp/include/raft/neighbors/nn_descent_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#include <optional>

namespace raft::neighbors::experimental::nn_descent {
using DistData_t = float;
// using DistData_t = float;
/**
* @ingroup nn-descent
* @{
Expand Down Expand Up @@ -72,7 +72,7 @@ struct index_params : ann::index_params {
*
* @tparam IdxT dtype to be used for constructing knn-graph
*/
template <typename IdxT>
template <typename T, typename IdxT>
struct index : ann::index {
public:
/**
Expand All @@ -95,7 +95,7 @@ struct index : ann::index {
return_distances_(return_distances)
{
if (return_distances) {
distances_ = raft::make_device_matrix<DistData_t, int64_t>(res_, n_rows, n_cols);
distances_ = raft::make_device_matrix<T, int64_t>(res_, n_rows, n_cols);
distances_view_ = distances_.value().view();
}
}
Expand All @@ -112,14 +112,14 @@ struct index : ann::index {
*/
index(raft::resources const& res,
raft::host_matrix_view<IdxT, int64_t, raft::row_major> graph_view,
std::optional<raft::device_matrix_view<DistData_t, int64_t, row_major>> distances_view =
std::optional<raft::device_matrix_view<T, int64_t, row_major>> distances_view =
std::nullopt,
bool return_distances = false)
: ann::index(),
res_{res},
metric_{raft::distance::DistanceType::L2Expanded},
graph_{raft::make_host_matrix<IdxT, int64_t, row_major>(0, 0)},
distances_{raft::make_device_matrix<DistData_t, int64_t>(res_, 0, 0)},
distances_{raft::make_device_matrix<T, int64_t>(res_, 0, 0)},
graph_view_{graph_view},
distances_view_(distances_view),
return_distances_(return_distances)
Expand Down Expand Up @@ -153,12 +153,12 @@ struct index : ann::index {
return graph_view_;
}

[[nodiscard]] inline auto distances() noexcept -> device_matrix_view<DistData_t, int64_t, row_major>
[[nodiscard]] inline auto distances() noexcept -> device_matrix_view<T, int64_t, row_major>
{
if (distances_view_.has_value()) {
return distances_view_.value();
} else {
return raft::make_device_matrix<DistData_t, int64_t>(res_, 0, 0).view();
return raft::make_device_matrix<T, int64_t>(res_, 0, 0).view();
}
}

Expand All @@ -173,10 +173,10 @@ struct index : ann::index {
raft::resources const& res_;
raft::distance::DistanceType metric_;
raft::host_matrix<IdxT, int64_t, row_major> graph_; // graph to return for non-int IdxT
std::optional<raft::device_matrix<DistData_t, int64_t, row_major>> distances_;
std::optional<raft::device_matrix<T, int64_t, row_major>> distances_;
raft::host_matrix_view<IdxT, int64_t, row_major>
graph_view_; // view of graph for user provided matrix
std::optional<raft::device_matrix_view<DistData_t, int64_t, row_major>> distances_view_;
std::optional<raft::device_matrix_view<T, int64_t, row_major>> distances_view_;
bool return_distances_;
};

Expand Down