Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable distance return for NN Descent #2345

Merged
merged 31 commits into from
Jun 14, 2024
Merged
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
242d725
enable nn descent dist return
jinsolp May 29, 2024
0977947
change bool to int
jinsolp May 29, 2024
ea30245
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp May 29, 2024
a7749ac
add test for distance
jinsolp May 30, 2024
a753e45
test for indices and distances with one func
jinsolp May 30, 2024
1d96e17
fix styling
jinsolp May 30, 2024
66a7678
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp May 30, 2024
11379f7
change return_distances to bool
jinsolp May 31, 2024
e993d7c
change distances to optional
jinsolp May 31, 2024
4578dc0
fix styling:
jinsolp May 31, 2024
c9048a5
handle bad access error
jinsolp Jun 5, 2024
bdb74a5
remove unnecessary dist allocation
jinsolp Jun 5, 2024
c71e97a
change to device matrix
jinsolp Jun 5, 2024
0327ce5
change template param for index
jinsolp Jun 5, 2024
3f49752
update test
jinsolp Jun 6, 2024
50e2cf8
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp Jun 6, 2024
a6a8ad2
remove comment
jinsolp Jun 6, 2024
2cfda3c
fix styling
jinsolp Jun 6, 2024
2efae1a
fix header
jinsolp Jun 6, 2024
931158c
add documentation for return_distances
jinsolp Jun 6, 2024
9f17b5c
add documentation
jinsolp Jun 7, 2024
e33019c
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp Jun 7, 2024
e28c2f9
add tparam doc
jinsolp Jun 10, 2024
2d74f89
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp Jun 11, 2024
fd65442
remove redundancy
jinsolp Jun 12, 2024
3ce8cdf
return optional for distances()
jinsolp Jun 13, 2024
4ec3ab0
fix styling
jinsolp Jun 13, 2024
b0347a7
remove raft_expects
jinsolp Jun 13, 2024
84abd96
fix type template and bring back raft_expects
jinsolp Jun 13, 2024
75b82df
Merge branch 'rapidsai:branch-24.08' into fea-nndescent-api
jinsolp Jun 13, 2024
ec44b4d
fix raft_expects
jinsolp Jun 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
change to device matrix
  • Loading branch information
jinsolp committed Jun 5, 2024
commit c71e97a646bc6916a83cfe0fd25a6d2fee215db8
16 changes: 9 additions & 7 deletions cpp/include/raft/neighbors/nn_descent_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
Expand Down Expand Up @@ -93,7 +95,7 @@ struct index : ann::index {
return_distances_(return_distances)
{
if (return_distances) {
distances_ = raft::make_host_matrix<DistData_t, int64_t, row_major>(n_rows, n_cols);
distances_ = raft::make_device_matrix<DistData_t, int64_t>(res_, n_rows, n_cols);
distances_view_ = distances_.value().view();
}
}
Expand All @@ -110,14 +112,14 @@ struct index : ann::index {
*/
index(raft::resources const& res,
raft::host_matrix_view<IdxT, int64_t, raft::row_major> graph_view,
std::optional<raft::host_matrix_view<DistData_t, int64_t, raft::row_major>> distances_view =
std::optional<raft::device_matrix_view<DistData_t, int64_t, row_major>> distances_view =
std::nullopt,
bool return_distances = false)
: ann::index(),
res_{res},
metric_{raft::distance::DistanceType::L2Expanded},
graph_{raft::make_host_matrix<IdxT, int64_t, row_major>(0, 0)},
distances_{raft::make_host_matrix<DistData_t, int64_t, row_major>(0, 0)},
distances_{raft::make_device_matrix<DistData_t, int64_t>(res_, 0, 0)},
graph_view_{graph_view},
distances_view_(distances_view),
return_distances_(return_distances)
Expand Down Expand Up @@ -151,12 +153,12 @@ struct index : ann::index {
return graph_view_;
}

[[nodiscard]] inline auto distances() noexcept -> host_matrix_view<DistData_t, int64_t, row_major>
[[nodiscard]] inline auto distances() noexcept -> device_matrix_view<DistData_t, int64_t, row_major>
{
if (distances_view_.has_value()) {
return distances_view_.value();
} else {
return raft::make_host_matrix<DistData_t, int64_t, row_major>(0, 0).view();
return raft::make_device_matrix<DistData_t, int64_t>(res_, 0, 0).view();
}
}

Expand All @@ -171,10 +173,10 @@ struct index : ann::index {
raft::resources const& res_;
raft::distance::DistanceType metric_;
raft::host_matrix<IdxT, int64_t, row_major> graph_; // graph to return for non-int IdxT
std::optional<raft::host_matrix<DistData_t, int64_t, row_major>> distances_;
std::optional<raft::device_matrix<DistData_t, int64_t, row_major>> distances_;
raft::host_matrix_view<IdxT, int64_t, row_major>
graph_view_; // view of graph for user provided matrix
std::optional<raft::host_matrix_view<DistData_t, int64_t, row_major>> distances_view_;
std::optional<raft::device_matrix_view<DistData_t, int64_t, row_major>> distances_view_;
bool return_distances_;
};

Expand Down