Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding fused_l2_nn_argmin wrapper to Pylibraft #924

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Cleaning up style and prints
  • Loading branch information
cjnolet committed Oct 18, 2022
commit b7a58355d8b28f0302f38725fc61cf038ee35a53
39 changes: 18 additions & 21 deletions cpp/include/raft_distance/fused_l2_min_arg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include <raft/distance/distance_types.hpp>

namespace raft::distance::runtime {
/**
/**
* @brief Wrapper around fusedL2NN with minimum reduction operators.
*
* fusedL2NN cannot be compiled in the distance library due to the lambda
Expand All @@ -36,25 +36,22 @@ namespace raft::distance::runtime {
* @param[in] n gemm n
* @param[in] k gemm k
*/
void fused_l2_nn_min_arg(
raft::handle_t const& handle,
int* min,
const float* x,
const float* y,
int m,
int n,
int k,
bool sqrt);
void fused_l2_nn_min_arg(raft::handle_t const& handle,
int* min,
const float* x,
const float* y,
int m,
int n,
int k,
bool sqrt);

void fused_l2_nn_min_arg(
raft::handle_t const& handle,
int* min,
const double* x,
const double* y,
int m,
int n,
int k,
bool sqrt);
void fused_l2_nn_min_arg(raft::handle_t const& handle,
int* min,
const double* x,
const double* y,
int m,
int n,
int k,
bool sqrt);


} // end namespace raft::distance::runtime
} // end namespace raft::distance::runtime
135 changes: 68 additions & 67 deletions cpp/src/distance/fused_l2_min_arg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,60 +14,64 @@
* limitations under the License.
*/

#include <raft/distance/fused_l2_nn.cuh>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/kvp.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/distance/fused_l2_nn.cuh>
#include <raft/distance/specializations.cuh>
#include <raft/core/device_mdarray.hpp>
#include <thrust/for_each.h>
#include <thrust/tuple.h>
#include <raft/core/kvp.hpp>
#include <raft/core/handle.hpp>

namespace raft::distance::runtime {

template <typename IndexT, typename DataT>
struct KeyValueIndexOp {
__host__ __device__ __forceinline__ IndexT
operator()(const raft::KeyValuePair<IndexT, DataT>& a) const
{
printf("%d, %f\n", a.key, a.value);
return a.key;
}
};
template <typename IndexT, typename DataT>
struct KeyValueIndexOp {
__host__ __device__ __forceinline__ IndexT
operator()(const raft::KeyValuePair<IndexT, DataT>& a) const
{
return a.key;
}
};

template<typename value_t, typename idx_t>
void compute_fused_l2_nn_min_arg(
raft::handle_t const& handle,
idx_t* min,
const value_t* x,
const value_t* y,
idx_t m,
idx_t n,
idx_t k,
bool sqrt) {
rmm::device_uvector<int> workspace(m, handle.get_stream());
auto kvp = raft::make_device_vector<raft::KeyValuePair<idx_t, value_t>>(handle, m);
template <typename value_t, typename idx_t>
void compute_fused_l2_nn_min_arg(raft::handle_t const& handle,
idx_t* min,
const value_t* x,
const value_t* y,
idx_t m,
idx_t n,
idx_t k,
bool sqrt)
{
rmm::device_uvector<int> workspace(m, handle.get_stream());
auto kvp = raft::make_device_vector<raft::KeyValuePair<idx_t, value_t>>(handle, m);

rmm::device_uvector<value_t> x_norms(m, handle.get_stream());
rmm::device_uvector<value_t> y_norms(n, handle.get_stream());
raft::linalg::rowNorm(x_norms.data(), x, k, m, raft::linalg::L2Norm, true, handle.get_stream());
raft::linalg::rowNorm(y_norms.data(), y, k, n, raft::linalg::L2Norm, true, handle.get_stream());
rmm::device_uvector<value_t> x_norms(m, handle.get_stream());
rmm::device_uvector<value_t> y_norms(n, handle.get_stream());
raft::linalg::rowNorm(x_norms.data(), x, k, m, raft::linalg::L2Norm, true, handle.get_stream());
raft::linalg::rowNorm(y_norms.data(), y, k, n, raft::linalg::L2Norm, true, handle.get_stream());

fusedL2NNMinReduce(kvp.data_handle(), x, y, x_norms.data(), y_norms.data(), m, n, k, (void*)workspace.data(), sqrt, true, handle.get_stream());
fusedL2NNMinReduce(kvp.data_handle(),
x,
y,
x_norms.data(),
y_norms.data(),
m,
n,
k,
(void*)workspace.data(),
sqrt,
true,
handle.get_stream());

raft::print_device_vector("x", x, m*k, std::cout);
raft::print_device_vector("y", y, n*k, std::cout);

raft::print_device_vector("x_norms", x_norms.data(), m, std::cout);
raft::print_device_vector("y_norms", y_norms.data(), n, std::cout);

KeyValueIndexOp<idx_t, value_t> conversion_op;
thrust::transform(handle.get_thrust_policy(), kvp.data_handle(), kvp.data_handle()+m, min, conversion_op);
handle.sync_stream();
raft::print_device_vector("min", min, m, std::cout);
}
KeyValueIndexOp<idx_t, value_t> conversion_op;
thrust::transform(
handle.get_thrust_policy(), kvp.data_handle(), kvp.data_handle() + m, min, conversion_op);
handle.sync_stream();
}

/**
/**
* @brief Wrapper around fusedL2NN with minimum reduction operators.
*
* fusedL2NN cannot be compiled in the distance library due to the lambda
Expand All @@ -87,31 +91,28 @@ template<typename value_t, typename idx_t>
* @param[in] n gemm n
* @param[in] k gemm k
*/
void fused_l2_nn_min_arg(
raft::handle_t const& handle,
int* min,
const float* x,
const float* y,
int m,
int n,
int k,
bool sqrt) {

compute_fused_l2_nn_min_arg<float, int>(handle, min, x, y, m, n, k, sqrt);
}

void fused_l2_nn_min_arg(
raft::handle_t const& handle,
int* min,
const double* x,
const double* y,
int m,
int n,
int k,
bool sqrt) {

compute_fused_l2_nn_min_arg<double, int>(handle, min, x, y, m, n, k, sqrt);
void fused_l2_nn_min_arg(raft::handle_t const& handle,
int* min,
const float* x,
const float* y,
int m,
int n,
int k,
bool sqrt)
{
compute_fused_l2_nn_min_arg<float, int>(handle, min, x, y, m, n, k, sqrt);
}

void fused_l2_nn_min_arg(raft::handle_t const& handle,
int* min,
const double* x,
const double* y,
int m,
int n,
int k,
bool sqrt)
{
compute_fused_l2_nn_min_arg<double, int>(handle, min, x, y, m, n, k, sqrt);
}

} // end namespace raft::distance::runtime
} // end namespace raft::distance::runtime
4 changes: 0 additions & 4 deletions python/pylibraft/pylibraft/test/test_fused_l2_argmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ def test_fused_l2_nn_minarg(n_rows, n_cols, n_clusters, dtype):
output_device = TestDeviceBuffer(output, "C")

fused_l2_nn_argmin(input1_device, input2_device, output_device, True)

actual = output_device.copy_to_host()

print(str(expected))

print(str(actual))
assert np.allclose(expected, actual, rtol=1e-4)