Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding fused_l2_nn_argmin wrapper to Pylibraft #924

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ set_target_properties(raft_distance PROPERTIES EXPORT_NAME distance)
if(RAFT_COMPILE_DIST_LIBRARY)
add_library(raft_distance_lib
src/distance/pairwise_distance.cu
src/distance/fused_l2_min_arg.cu
src/distance/specializations/detail/canberra.cu
src/distance/specializations/detail/chebyshev.cu
src/distance/specializations/detail/correlation.cu
Expand Down
58 changes: 58 additions & 0 deletions cpp/include/raft_distance/fused_l2_min_arg.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/core/handle.hpp>
#include <raft/distance/distance_types.hpp>

namespace raft::distance::runtime {

/**
* @brief Wrapper around fusedL2NN with minimum reduction operators.
*
* fusedL2NN cannot be compiled in the distance library due to the lambda
* operators, so this wrapper covers the most common case (minimum).
*
* @param[in] handle raft handle
* @param[out] min will contain the reduced output (Length = `m`)
* (on device)
* @param[in] x first matrix. Row major. Dim = `m x k`.
* (on device).
* @param[in] y second matrix. Row major. Dim = `n x k`.
* (on device).
* @param[in] m gemm m
* @param[in] n gemm n
* @param[in] k gemm k
* @param[in] sqrt Whether the output `minDist` should contain L2-sqrt
*/
void fused_l2_nn_min_arg(raft::handle_t const& handle,
int* min,
const float* x,
const float* y,
int m,
int n,
int k,
bool sqrt);

void fused_l2_nn_min_arg(raft::handle_t const& handle,
int* min,
const double* x,
const double* y,
int m,
int n,
int k,
bool sqrt);

} // end namespace raft::distance::runtime
98 changes: 98 additions & 0 deletions cpp/src/distance/fused_l2_min_arg.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/core/device_mdarray.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/kvp.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/distance/fused_l2_nn.cuh>
#include <raft/distance/specializations.cuh>
#include <thrust/for_each.h>
#include <thrust/tuple.h>

namespace raft::distance::runtime {

template <typename IndexT, typename DataT>
struct KeyValueIndexOp {
__host__ __device__ __forceinline__ IndexT
operator()(const raft::KeyValuePair<IndexT, DataT>& a) const
{
return a.key;
}
};

template <typename value_t, typename idx_t>
void compute_fused_l2_nn_min_arg(raft::handle_t const& handle,
idx_t* min,
const value_t* x,
const value_t* y,
idx_t m,
idx_t n,
idx_t k,
bool sqrt)
{
rmm::device_uvector<int> workspace(m, handle.get_stream());
auto kvp = raft::make_device_vector<raft::KeyValuePair<idx_t, value_t>>(handle, m);

rmm::device_uvector<value_t> x_norms(m, handle.get_stream());
rmm::device_uvector<value_t> y_norms(n, handle.get_stream());
raft::linalg::rowNorm(x_norms.data(), x, k, m, raft::linalg::L2Norm, true, handle.get_stream());
raft::linalg::rowNorm(y_norms.data(), y, k, n, raft::linalg::L2Norm, true, handle.get_stream());

fusedL2NNMinReduce(kvp.data_handle(),
x,
y,
x_norms.data(),
y_norms.data(),
m,
n,
k,
(void*)workspace.data(),
sqrt,
true,
handle.get_stream());

KeyValueIndexOp<idx_t, value_t> conversion_op;
thrust::transform(
handle.get_thrust_policy(), kvp.data_handle(), kvp.data_handle() + m, min, conversion_op);
handle.sync_stream();
}

void fused_l2_nn_min_arg(raft::handle_t const& handle,
int* min,
const float* x,
const float* y,
int m,
int n,
int k,
bool sqrt)
{
compute_fused_l2_nn_min_arg<float, int>(handle, min, x, y, m, n, k, sqrt);
}

void fused_l2_nn_min_arg(raft::handle_t const& handle,
int* min,
const double* x,
const double* y,
int m,
int n,
int k,
bool sqrt)
{
compute_fused_l2_nn_min_arg<double, int>(handle, min, x, y, m, n, k, sqrt);
}

} // end namespace raft::distance::runtime
3 changes: 2 additions & 1 deletion python/pylibraft/pylibraft/distance/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# =============================================================================

# Set the list of Cython files to build
set(cython_sources pairwise_distance.pyx)
set(cython_sources pairwise_distance.pyx
fused_l2_nn.pyx)
set(linked_libraries raft::raft raft::distance)

# Build all of the Cython targets
Expand Down
1 change: 1 addition & 0 deletions python/pylibraft/pylibraft/distance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
# limitations under the License.
#

from .fused_l2_nn import fused_l2_nn_argmin
from .pairwise_distance import distance as pairwise_distance
150 changes: 150 additions & 0 deletions python/pylibraft/pylibraft/distance/fused_l2_nn.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# cython: profile=False
# distutils: language = c++
# cython: embedsignature = True
# cython: language_level = 3

import numpy as np

from libc.stdint cimport uintptr_t
from cython.operator cimport dereference as deref

from libcpp cimport bool
from .distance_type cimport DistanceType
from pylibraft.common.handle cimport handle_t


def is_c_cont(cai, dt):
return "strides" not in cai or \
cai["strides"] is None or \
cai["strides"][1] == dt.itemsize


cdef extern from "raft_distance/fused_l2_min_arg.hpp" \
namespace "raft::distance::runtime":

void fused_l2_nn_min_arg(
const handle_t &handle,
int* min,
const float* x,
const float* y,
int m,
int n,
int k,
bool sqrt)

void fused_l2_nn_min_arg(
const handle_t &handle,
int* min,
const double* x,
const double* y,
int m,
int n,
int k,
bool sqrt)


def fused_l2_nn_argmin(X, Y, output, sqrt=True):
"""
Compute the 1-nearest neighbors between X and Y using the L2 distance

Parameters
----------

X : CUDA array interface compliant matrix shape (m, k)
Y : CUDA array interface compliant matrix shape (n, k)
output : Writable CUDA array interface matrix shape (m, 1)

Examples
--------

.. code-block:: python

import cupy as cp

from pylibraft.distance import fused_l2_nn

n_samples = 5000
n_clusters = 5
n_features = 50

in1 = cp.random.random_sample((n_samples, n_features),
dtype=cp.float32)
in2 = cp.random.random_sample((n_clusters, n_features),
dtype=cp.float32)
output = cp.empty((n_samples, 1), dtype=cp.int32)

fused_l2_nn_argmin(in1, in2, output)
"""

x_cai = X.__cuda_array_interface__
y_cai = Y.__cuda_array_interface__
output_cai = output.__cuda_array_interface__

m = x_cai["shape"][0]
n = y_cai["shape"][0]

x_k = x_cai["shape"][1]
y_k = y_cai["shape"][1]

if x_k != y_k:
raise ValueError("Inputs must have same number of columns. "
"a=%s, b=%s" % (x_k, y_k))

x_ptr = <uintptr_t>x_cai["data"][0]
y_ptr = <uintptr_t>y_cai["data"][0]

d_ptr = <uintptr_t>output_cai["data"][0]

cdef handle_t *h = new handle_t()

x_dt = np.dtype(x_cai["typestr"])
y_dt = np.dtype(y_cai["typestr"])
d_dt = np.dtype(output_cai["typestr"])

x_c_contiguous = is_c_cont(x_cai, x_dt)
y_c_contiguous = is_c_cont(y_cai, y_dt)

if x_c_contiguous != y_c_contiguous:
raise ValueError("Inputs must have matching strides")

print(x_dt)
if x_dt != y_dt:
raise ValueError("Inputs must have the same dtypes")
if d_dt != np.int32:
raise ValueError("Output array must be int32")

if x_dt == np.float32:
fused_l2_nn_min_arg(deref(h),
<int*> d_ptr,
<float*> x_ptr,
<float*> y_ptr,
<int>m,
<int>n,
<int>x_k,
<bool>sqrt)
elif x_dt == np.float64:
fused_l2_nn_min_arg(deref(h),
<int*> d_ptr,
<double*> x_ptr,
<double*> y_ptr,
<int>m,
<int>n,
<int>x_k,
<bool>sqrt)
else:
raise ValueError("dtype %s not supported" % x_dt)
47 changes: 47 additions & 0 deletions python/pylibraft/pylibraft/test/test_fused_l2_argmin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from scipy.spatial.distance import cdist
import pytest
import numpy as np

from pylibraft.distance import fused_l2_nn_argmin
from pylibraft.testing.utils import TestDeviceBuffer


@pytest.mark.parametrize("n_rows", [10, 100])
@pytest.mark.parametrize("n_clusters", [5, 10])
@pytest.mark.parametrize("n_cols", [3, 5])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_fused_l2_nn_minarg(n_rows, n_cols, n_clusters, dtype):
input1 = np.random.random_sample((n_rows, n_cols))
input1 = np.asarray(input1, order="C").astype(dtype)

input2 = np.random.random_sample((n_clusters, n_cols))
input2 = np.asarray(input2, order="C").astype(dtype)

output = np.zeros((n_rows), dtype="int32")
expected = cdist(input1, input2, metric="euclidean")

expected = expected.argmin(axis=1)

input1_device = TestDeviceBuffer(input1, "C")
input2_device = TestDeviceBuffer(input2, "C")
output_device = TestDeviceBuffer(output, "C")

fused_l2_nn_argmin(input1_device, input2_device, output_device, True)
actual = output_device.copy_to_host()

assert np.allclose(expected, actual, rtol=1e-4)
1 change: 1 addition & 0 deletions python/pylibraft/pylibraft/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
class TestDeviceBuffer:

def __init__(self, ndarray, order):

self.ndarray_ = ndarray
self.device_buffer_ = \
rmm.DeviceBuffer.to_device(ndarray.ravel(order=order).tobytes())
Expand Down