1919#include < raft/cudart_utils.h>
2020#include < raft/cuda_utils.cuh>
2121
22+ #include < raft/sparse/op/sort.h>
2223#include < raft/mr/device/buffer.hpp>
2324#include < raft/sparse/mst/mst.cuh>
2425#include < raft/sparse/selection/connect_components.cuh>
@@ -35,29 +36,6 @@ namespace raft {
3536namespace hierarchy {
3637namespace detail {
3738
38- /* *
39- * Sorts a COO by its weight
40- * @tparam value_idx
41- * @tparam value_t
42- * @param[inout] rows source edges
43- * @param[inout] cols dest edges
44- * @param[inout] data edge weights
45- * @param[in] nnz number of edges in edge list
46- * @param[in] stream cuda stream for which to order cuda operations
47- */
48- template <typename value_idx, typename value_t >
49- void sort_coo_by_data (value_idx *rows, value_idx *cols, value_t *data,
50- value_idx nnz, cudaStream_t stream) {
51- thrust::device_ptr<value_idx> t_rows = thrust::device_pointer_cast (rows);
52- thrust::device_ptr<value_idx> t_cols = thrust::device_pointer_cast (cols);
53- thrust::device_ptr<value_t > t_data = thrust::device_pointer_cast (data);
54-
55- auto first = thrust::make_zip_iterator (thrust::make_tuple (rows, cols));
56-
57- thrust::sort_by_key (thrust::cuda::par.on (stream), t_data, t_data + nnz,
58- first);
59- }
60-
6139template <typename value_idx, typename value_t >
6240void merge_msts (raft::Graph_COO<value_idx, value_idx, value_t > &coo1,
6341 raft::Graph_COO<value_idx, value_idx, value_t > &coo2,
@@ -95,19 +73,20 @@ void merge_msts(raft::Graph_COO<value_idx, value_idx, value_t> &coo1,
9573 * @param[inout] color the color labels array returned from the mst invocation
9674 * @return updated MST edge list
9775 */
98- template <typename value_idx, typename value_t >
76+ template <typename value_idx, typename value_t , typename red_op >
9977void connect_knn_graph (const raft::handle_t &handle, const value_t *X,
10078 raft::Graph_COO<value_idx, value_idx, value_t > &msf,
10179 size_t m, size_t n, value_idx *color,
80+ red_op reduction_op,
10281 raft::distance::DistanceType metric =
10382 raft::distance::DistanceType::L2SqrtExpanded) {
10483 auto d_alloc = handle.get_device_allocator ();
10584 auto stream = handle.get_stream ();
10685
10786 raft::sparse::COO<value_t , value_idx> connected_edges (d_alloc, stream);
10887
109- raft::linkage::connect_components<value_idx, value_t >(handle, connected_edges,
110- X, color, m, n);
88+ raft::linkage::connect_components<value_idx, value_t >(
89+ handle, connected_edges, X, color, m, n, reduction_op );
11190
11291 rmm::device_uvector<value_idx> indptr2 (m + 1 , stream);
11392 raft::sparse::convert::sorted_coo_to_csr (connected_edges.rows (),
@@ -147,38 +126,34 @@ void connect_knn_graph(const raft::handle_t &handle, const value_t *X,
147126 * @param[in] max_iter maximum iterations to run knn graph connection. This
148127 * argument is really just a safeguard against the potential for infinite loops.
149128 */
150- template <typename value_idx, typename value_t >
129+ template <typename value_idx, typename value_t , typename red_op >
151130void build_sorted_mst (const raft::handle_t &handle, const value_t *X,
152131 const value_idx *indptr, const value_idx *indices,
153132 const value_t *pw_dists, size_t m, size_t n,
154- rmm::device_uvector<value_idx> &mst_src,
155- rmm::device_uvector<value_idx> &mst_dst,
156- rmm::device_uvector<value_t > &mst_weight,
157- const size_t nnz,
133+ value_idx *mst_src, value_idx *mst_dst,
134+ value_t *mst_weight, value_idx *color, size_t nnz,
135+ red_op reduction_op,
158136 raft::distance::DistanceType metric =
159137 raft::distance::DistanceType::L2SqrtExpanded,
160138 int max_iter = 10 ) {
161139 auto d_alloc = handle.get_device_allocator ();
162140 auto stream = handle.get_stream ();
163141
164- rmm::device_uvector<value_idx> color (m, stream);
165-
166142 // We want to have MST initialize colors on first call.
167143 auto mst_coo = raft::mst::mst<value_idx, value_idx, value_t , double >(
168- handle, indptr, indices, pw_dists, (value_idx)m, nnz, color. data () , stream,
169- false , true );
144+ handle, indptr, indices, pw_dists, (value_idx)m, nnz, color, stream, false ,
145+ true );
170146
171147 int iters = 1 ;
172- int n_components =
173- linkage::get_n_components (color.data (), m, d_alloc, stream);
148+ int n_components = linkage::get_n_components (color, m, d_alloc, stream);
174149
175150 while (n_components > 1 && iters < max_iter) {
176- connect_knn_graph<value_idx, value_t >(handle, X, mst_coo, m, n,
177- color. data () );
151+ connect_knn_graph<value_idx, value_t >(handle, X, mst_coo, m, n, color,
152+ reduction_op );
178153
179154 iters++;
180155
181- n_components = linkage::get_n_components (color. data () , m, d_alloc, stream);
156+ n_components = linkage::get_n_components (color, m, d_alloc, stream);
182157 }
183158
184159 /* *
@@ -189,7 +164,7 @@ void build_sorted_mst(const raft::handle_t &handle, const value_t *X,
189164 * 1. There is a bug in this code somewhere
190165 * 2. Either the given KNN graph wasn't generated from X or the same metric is not being used
191166 * to generate the 1-nn (currently only L2SqrtExpanded is supported).
192- * 3. max_iter was not large enough to connect the graph.
167+ * 3. max_iter was not large enough to connect the graph (less likely) .
193168 *
194169 * Note that a KNN graph generated from 50 random isotropic balls (with significant overlap)
195170 * was able to be connected in a single iteration.
@@ -201,20 +176,15 @@ void build_sorted_mst(const raft::handle_t &handle, const value_t *X,
201176 " or increase 'max_iter'" ,
202177 max_iter);
203178
204- sort_coo_by_data (mst_coo.src .data (), mst_coo.dst .data (),
205- mst_coo.weights .data (), mst_coo.n_edges , stream);
206-
207- // TODO: be nice if we could pass these directly into the MST
208- mst_src.resize (mst_coo.n_edges , stream);
209- mst_dst.resize (mst_coo.n_edges , stream);
210- mst_weight.resize (mst_coo.n_edges , stream);
179+ raft::sparse::op::coo_sort_by_weight (mst_coo.src .data (), mst_coo.dst .data (),
180+ mst_coo.weights .data (), mst_coo.n_edges ,
181+ stream);
211182
212- raft::copy_async (mst_src.data (), mst_coo.src .data (), mst_coo.n_edges , stream);
213- raft::copy_async (mst_dst.data (), mst_coo.dst .data (), mst_coo.n_edges , stream);
214- raft::copy_async (mst_weight.data (), mst_coo.weights .data (), mst_coo.n_edges ,
215- stream);
183+ raft::copy_async (mst_src, mst_coo.src .data (), mst_coo.n_edges , stream);
184+ raft::copy_async (mst_dst, mst_coo.dst .data (), mst_coo.n_edges , stream);
185+ raft::copy_async (mst_weight, mst_coo.weights .data (), mst_coo.n_edges , stream);
216186}
217187
218188}; // namespace detail
219189}; // namespace hierarchy
220- }; // namespace raft
190+ }; // namespace raft
0 commit comments