Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to enable HDBSCAN #208

Merged
merged 40 commits into from
May 27, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
046f703
Allowing epilogue in knn graph connection function
cjnolet Apr 20, 2021
086b08a
Fixing style
cjnolet Apr 20, 2021
37d2e0d
Adding missing argument
cjnolet Apr 23, 2021
709c040
Fixing typename
cjnolet Apr 23, 2021
b1fbc63
Updates
cjnolet Apr 23, 2021
81b630a
Fixing style
cjnolet Apr 23, 2021
6c52542
Updating to get hdbscan to compile
cjnolet Apr 23, 2021
4be3d24
Some updates
cjnolet Apr 26, 2021
1aaa7d5
changes
cjnolet Apr 29, 2021
dd6e537
agglomerative labeling to accept device arrays directly
cjnolet Apr 30, 2021
e3085cb
Cleaning up inputs to some of the single linkage prims
cjnolet Apr 30, 2021
ec1cca5
removing deprecated rmm policy usage
divyegala May 4, 2021
d32677d
Changes
cjnolet May 5, 2021
43f7cf8
removing deprecated rmm policy usage
divyegala May 4, 2021
e0ef5b3
alpha to weight alteration for precision
divyegala May 10, 2021
d809630
merge
divyegala May 10, 2021
1678c66
resolving errors
divyegala May 10, 2021
8018ea5
merge again
divyegala May 10, 2021
404e5ce
Merge branch 'fea-020-hdbscan' of github.com:cjnolet/raft into fea-02…
divyegala May 10, 2021
5eda4aa
trying alpha for all
divyegala May 11, 2021
030b3a9
double precision weight alteration
divyegala May 12, 2021
635d018
Checking in
cjnolet May 14, 2021
26dff4d
Fixing style
cjnolet May 19, 2021
55f274c
Removing unused epilogue from mst
cjnolet May 19, 2021
eb69413
Removing mst epilogue functor
cjnolet May 19, 2021
9f39d69
Merge branch 'branch-0.20' into fea-020-hdbscan
cjnolet May 19, 2021
dfebf10
Merge branch 'branch-21.06' into fea-020-hdbscan
cjnolet May 19, 2021
50d1cdc
Getting test to build
cjnolet May 20, 2021
c654156
Removing mstepiloguenoop since it's no longer being used
cjnolet May 22, 2021
136529a
another template param for weight alteration
divyegala May 24, 2021
e4b0f91
merge upstream
divyegala May 24, 2021
d7e93d9
renaming confusing variable name
divyegala May 24, 2021
e51d08b
working through merge
divyegala May 26, 2021
beea020
merging mst template PR
divyegala May 26, 2021
86cbf42
removing unnecessary comments
divyegala May 26, 2021
eb92e26
Review feedback
cjnolet May 27, 2021
e337c0d
Merge branch 'branch-21.06' into fea-020-hdbscan
cjnolet May 27, 2021
8b1e344
Update cpp/include/raft/sparse/hierarchy/detail/agglomerative.cuh
cjnolet May 27, 2021
1933520
Fixing bad merge
cjnolet May 27, 2021
1bc3e68
Merge branch 'fea-020-hdbscan' of github.com:cjnolet/raft into fea-02…
cjnolet May 27, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
alpha to weight alteration for precision
  • Loading branch information
divyegala committed May 10, 2021
commit e0ef5b33ca647bb9337e48ee4568f3759b0fe017
18 changes: 14 additions & 4 deletions cpp/include/raft/sparse/mst/detail/mst_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -285,17 +285,27 @@ template <typename vertex_t, typename edge_t, typename weight_t>
__global__ void alteration_kernel(const vertex_t v, const edge_t e,
const edge_t* offsets,
const vertex_t* indices,
const weight_t* weights, weight_t max,
const weight_t* weights, double max,
weight_t* random_values,
weight_t* altered_weights) {
weight_t* altered_weights, int alpha,
bool use_alpha) {
auto row = get_1D_idx<vertex_t>();
if (row < v) {
auto row_begin = offsets[row];
auto row_end = offsets[row + 1];
for (auto i = row_begin; i < row_end; i++) {
auto column = indices[i];
altered_weights[i] =
weights[i] + max * (random_values[row] + random_values[column]);
// doing the later step explicity in double for precision
if (use_alpha) {
altered_weights[i] =
alpha * weights[i] + alpha * max *
(static_cast<double>(random_values[row]) +
static_cast<double>(random_values[column]));
} else {
altered_weights[i] =
weights[i] + max * (static_cast<double>(random_values[row]) +
static_cast<double>(random_values[column]));
}
}
}
}
Expand Down
23 changes: 16 additions & 7 deletions cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ MST_solver<vertex_t, edge_t, weight_t>::MST_solver(
const raft::handle_t& handle_, const edge_t* offsets_,
const vertex_t* indices_, const weight_t* weights_, const vertex_t v_,
const edge_t e_, vertex_t* color_, cudaStream_t stream_,
bool symmetrize_output_, bool initialize_colors_, int iterations_)
bool symmetrize_output_, bool initialize_colors_, int iterations_, int alpha_)
: handle(handle_),
offsets(offsets_),
indices(indices_),
Expand All @@ -75,7 +75,8 @@ MST_solver<vertex_t, edge_t, weight_t>::MST_solver(
stream(stream_),
symmetrize_output(symmetrize_output_),
initialize_colors(initialize_colors_),
iterations(iterations_) {
iterations(iterations_),
alpha(alpha_) {
max_blocks = handle_.get_device_properties().maxGridSize[0];
max_threads = handle_.get_device_properties().maxThreadsPerBlock;
sm_count = handle_.get_device_properties().multiProcessorCount;
Expand Down Expand Up @@ -114,7 +115,9 @@ MST_solver<vertex_t, edge_t, weight_t>::solve() {
timer0 = duration_us(stop - start);
#endif

Graph_COO<vertex_t, edge_t, weight_t> mst_result(2 * v - 2, stream);
auto n_expected_edges = symmetrize_output ? 2 * v - 2 : v - 1;

Graph_COO<vertex_t, edge_t, weight_t> mst_result(n_expected_edges, stream);

// Boruvka original formulation says "while more than 1 supervertex remains"
// Here we adjust it to support disconnected components (spanning forest)
Expand Down Expand Up @@ -150,6 +153,10 @@ MST_solver<vertex_t, edge_t, weight_t>::solve() {
timer3 += duration_us(stop - start);
#endif

RAFT_EXPECTS(mst_edge_count[0] == n_expected_edges,
"Number of edges found by MST is invalid. This may be due to "
"loss in precision. Try increasing precision of weights.")

if (prev_mst_edge_count[0] == mst_edge_count[0]) {
#ifdef MST_TIME
std::cout << "Iterations: " << i << std::endl;
Expand Down Expand Up @@ -209,7 +216,7 @@ struct alteration_functor {

// Compute the uper bound for the alteration
template <typename vertex_t, typename edge_t, typename weight_t>
weight_t MST_solver<vertex_t, edge_t, weight_t>::alteration_max() {
double MST_solver<vertex_t, edge_t, weight_t>::alteration_max() {
auto policy = rmm::exec_policy(stream);
rmm::device_vector<weight_t> tmp(e);
thrust::device_ptr<const weight_t> weights_ptr(weights);
Expand All @@ -229,7 +236,7 @@ weight_t MST_solver<vertex_t, edge_t, weight_t>::alteration_max() {
auto max =
thrust::transform_reduce(policy, begin, end, alteration_functor<weight_t>(),
init, thrust::minimum<weight_t>());
return max / static_cast<weight_t>(2);
return max / static_cast<double>(2);
}

// Compute the alteration to make all undirected edge weight unique
Expand All @@ -240,7 +247,7 @@ void MST_solver<vertex_t, edge_t, weight_t>::alteration() {
auto nblocks = std::min((v + nthreads - 1) / nthreads, max_blocks);

// maximum alteration that does not change realtive weights order
weight_t max = alteration_max();
double max = alteration_max();

// pool of rand values
rmm::device_vector<weight_t> rand_values(v);
Expand All @@ -258,10 +265,12 @@ void MST_solver<vertex_t, edge_t, weight_t>::alteration() {
RAFT_EXPECTS(curand_status == CURAND_STATUS_SUCCESS,
"MST: CURAND cleanup failed");

bool use_alpha = max < 1e-3 && sizeof(weight_t) == 4;

//Alterate the weights, make all undirected edge weight unique while keeping Wuv == Wvu
detail::alteration_kernel<<<nblocks, nthreads, 0, stream>>>(
v, e, offsets, indices, weights, max, rand_values.data().get(),
altered_weights.data().get());
altered_weights.data().get(), alpha, use_alpha);
}

// updates colors of vertices by propagating the lower color to the higher
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/sparse/mst/mst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ raft::Graph_COO<vertex_t, edge_t, weight_t> mst(
const raft::handle_t& handle, edge_t const* offsets, vertex_t const* indices,
weight_t const* weights, vertex_t const v, edge_t const e, vertex_t* color,
cudaStream_t stream, bool symmetrize_output = true,
bool initialize_colors = true, int iterations = 0) {
bool initialize_colors = true, int iterations = 0, int alpha = 1e6) {
MST_solver<vertex_t, edge_t, weight_t> mst_solver(
handle, offsets, indices, weights, v, e, color, stream, symmetrize_output,
initialize_colors, iterations);
Expand Down
5 changes: 3 additions & 2 deletions cpp/include/raft/sparse/mst/mst_solver.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class MST_solver {
const vertex_t* indices_, const weight_t* weights_,
const vertex_t v_, const edge_t e_, vertex_t* color_,
cudaStream_t stream_, bool symmetrize_output_,
bool initialize_colors_, int iterations_);
bool initialize_colors_, int iterations_ int alpha_);

raft::Graph_COO<vertex_t, edge_t, weight_t> solve();

Expand All @@ -54,6 +54,7 @@ class MST_solver {
cudaStream_t stream;
bool symmetrize_output, initialize_colors;
int iterations;
int alpha;

//CSR
const edge_t* offsets;
Expand Down Expand Up @@ -90,7 +91,7 @@ class MST_solver {
void min_edge_per_supervertex();
void check_termination();
void alteration();
weight_t alteration_max();
double alteration_max();
void append_src_dst_pair(vertex_t* mst_src, vertex_t* mst_dst,
weight_t* mst_weights);
};
Expand Down