Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cutlass 3xTF32,DMMA based L2/cosine distance kernels for SM 8.0 or higher #939

Merged
merged 28 commits into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
20648c5
cutlass based euclidean expanded, cosine kernels
mdoijade Oct 20, 2022
a9dabc8
add prior ampere pairwisedistmat kernel to prevent redundant kernel c…
mdoijade Oct 21, 2022
1a45bfa
add noexcept to the functor methods
mdoijade Oct 21, 2022
c6f091b
merge branch 22.12 and resolve conflicts
mdoijade Oct 21, 2022
7786fcb
fix comments, remove redundant code and fix formatting issues
mdoijade Oct 27, 2022
181fc40
add cutlass cmake support for raft with custom namespace, fix formati…
mdoijade Oct 28, 2022
3d34545
fix formatting issues
mdoijade Oct 28, 2022
02c23ed
fix the cutlass_include_dir path in cmake
mdoijade Nov 3, 2022
7933436
fix bugs in get_cutlass cmake to use cutlass provided properties corr…
mdoijade Nov 4, 2022
d4bdec5
remove the cutlass namespace setting in test cmakefiles as it is not …
mdoijade Nov 4, 2022
d26bcef
temp remove dist dependency from cutlass to check if it works in ci/cd
mdoijade Nov 4, 2022
4df4185
merge branch-22.12 latest changes
mdoijade Nov 7, 2022
451c3c0
fix get_cutlass.cmake to work with pylibraft by using NvidiaCutlass i…
mdoijade Nov 10, 2022
7b512f9
fix get_cutlass install path, make changes as per review comments
mdoijade Nov 10, 2022
a05e1e2
merge branch-22.12
mdoijade Nov 10, 2022
d32b4c0
fix clang format issues
mdoijade Nov 10, 2022
f7c440a
temp fix to check if python build works
mdoijade Nov 11, 2022
b1a1fd7
add raft-exports instead of raft-distance-exports as other raft compo…
mdoijade Nov 15, 2022
4ef44e7
make cutlass to depend only on raft_distance and add raft_distance de…
mdoijade Nov 16, 2022
186fcc7
fix cmake formatting issues
mdoijade Nov 16, 2022
8aa8909
prevent cutlass based pairwise dist kernels to be disabled on cuda 12…
mdoijade Nov 16, 2022
abfd493
Moving cutlass dependency to distance and nn to keep them separate.
cjnolet Nov 16, 2022
f1b1239
Adding CUTLASS to build docs as dependency
cjnolet Nov 16, 2022
32e6052
Updating to export to both distance and nn
cjnolet Nov 16, 2022
f6de9ee
Adding cutlass as private dependency
cjnolet Nov 16, 2022
9bf0647
Making cutlass INTERFACE in raft::nn and raft::distance
cjnolet Nov 16, 2022
8f0119a
Using proper exports per Robert Maynard's suggestion.
cjnolet Nov 16, 2022
6ad4fd1
Adding cutlass as private dependency of lib targets
cjnolet Nov 16, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix get_cutlass install path, make changes as per review comments
  • Loading branch information
mdoijade committed Nov 10, 2022
commit 7b512f99bbfffee488ad329062787b017ba2f087
12 changes: 7 additions & 5 deletions cpp/cmake/thirdparty/get_cutlass.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ function(find_and_configure_cutlass)
set(CUTLASS_ENABLE_CUBLAS OFF CACHE BOOL "Disable CUTLASS to build with cuBLAS library.")

rapids_cpm_find(NvidiaCutlass ${PKG_VERSION}
GLOBAL_TARGETS nvidia::cutlass::cutlass
GLOBAL_TARGETS nvidia::cutlass::cutlass
CPM_ARGS
GIT_REPOSITORY ${PKG_REPOSITORY}
GIT_TAG ${PKG_PINNED_TAG}
OPTIONS
GIT_REPOSITORY ${PKG_REPOSITORY}
GIT_TAG ${PKG_PINNED_TAG}
GIT_SHALLOW TRUE
OPTIONS
"CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}"
)

Expand All @@ -51,7 +52,8 @@ function(find_and_configure_cutlass)

# Tell cmake where it can find the generated NvidiaCutlass-config.cmake we wrote.
include("${rapids-cmake-dir}/export/find_package_root.cmake")
rapids_export_find_package_root(INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports)
rapids_export_find_package_root(INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-distance-exports)
rapids_export_find_package_root(BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports)
endfunction()

if(NOT RAFT_CUTLASS_GIT_TAG)
Expand Down
6 changes: 5 additions & 1 deletion cpp/include/raft/distance/detail/cosine.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ void cosineImpl(const DataT* x,
FinalLambda fin_op,
cudaStream_t stream)
{
const auto deviceVersion = getMajorMinorVersion();
const auto deviceVersion = getComputeCapability();
if (deviceVersion.first >= 8) {
using CosineOp_ = CosineOp<DataT, AccT>;
CosineOp_ cosine_dist_op;
Expand Down Expand Up @@ -228,6 +228,10 @@ void cosineAlgo1(Index_ m,
{
auto norm_op = [] __device__(AccType in) { return raft::mySqrt(in); };

// raft distance supports inputs as float/double and outputs as uint8_t/float/double.
static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))),
"OutType can be uint8_t, float, double,"
"if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType).");
typedef typename std::conditional<sizeof(OutType) == 1, OutType, AccType>::type CosOutType;
mdoijade marked this conversation as resolved.
Show resolved Hide resolved
CosOutType* pDcast = reinterpret_cast<CosOutType*>(pD);

Expand Down
4 changes: 4 additions & 0 deletions cpp/include/raft/distance/detail/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,10 @@ void distance(const InType* x,
using final_op_type = default_fin_op<AccType, OutType, Index_>;
final_op_type fin_op;

// raft distance supports inputs as float/double and outputs as uint8_t/float/double.
static_assert(! ((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))),
"OutType can be uint8_t, float, double,"
"if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType).");
distance<distanceType, InType, AccType, OutType, final_op_type, Index_>(
x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, metric_arg);
RAFT_CUDA_TRY(cudaPeekAtLastError());
Expand Down
6 changes: 5 additions & 1 deletion cpp/include/raft/distance/detail/euclidean.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void euclideanExpImpl(const DataT* x,
FinalLambda fin_op,
cudaStream_t stream)
{
const auto deviceVersion = getMajorMinorVersion();
const auto deviceVersion = getComputeCapability();
if (deviceVersion.first >= 8) {
using L2Op = L2ExpandedOp<DataT, AccT>;
L2Op L2_dist_op(sqrt);
Expand Down Expand Up @@ -245,6 +245,10 @@ void euclideanAlgo1(Index_ m,
{
auto norm_op = [] __device__(InType in) { return in; };

// raft distance supports inputs as float/double and outputs as uint8_t/float/double.
static_assert(! ((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))),
"OutType can be uint8_t, float, double,"
"if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType).");
typedef typename std::conditional<sizeof(OutType) == 1, OutType, AccType>::type ExpOutType;
mdoijade marked this conversation as resolved.
Show resolved Hide resolved
ExpOutType* pDcast = reinterpret_cast<ExpOutType*>(pD);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"

// We define `cutlass` as `raft_cutlass` to provide a custom namespace
// in case CUTLASS_NAMESPACE was not set by the RAFT CMake build
#ifndef CUTLASS_NAMESPACE
#define cutlass raft_cutlass
#endif


#include <rmm/device_uvector.hpp>

#include <cutlass/cutlass.h>
Expand Down
11 changes: 8 additions & 3 deletions cpp/include/raft/distance/detail/pairwise_distance_epilogue.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.

This is adapted from DefaultEpilogueWithBroadcastTensorOp from CUTLASS 2.9.0
(https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h#L75)

This epilogue allows us to load norm buffers using PredicatedTileIteratorNormVec
and EpilogueWithBroadcast, used for the L2/cosine distances, as well as applying a user-defined elementwise operation.
-- A norm load is provided PredicatedTileIteratorNormVec
-- B norm load is provided by EpilogueWithBroadcast
-- elementwise operation is provided by OutputOp
*/

#pragma once
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
/*! \file
\brief Functor performing distance operations used by epilogues of pairwise distance
* kernels.
* This is adapted from LinearCombinationBiasElementwise from CUTLASS 2.9.0,
* customized for applying the elementwise distance formula to the accumulated GEMM value
* and applying a user-defined final custom operation on the distance value.
*/

#pragma once
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/distance/detail/pairwise_distance_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ struct PairwiseDistanceGemm {

// This code section describes how threadblocks are scheduled on GPU
/// Threadblock-level swizzling operator
using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

/// data layout for final output matrix.
// we keep this same layout even for column major inputs
Expand Down Expand Up @@ -179,7 +179,7 @@ struct PairwiseDistanceGemm<double,

// This code section describes how threadblocks are scheduled on GPU
/// Threadblock-level swizzling operator
using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ??
using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

/// data layout for final output matrix.
// we keep this same layout even for column major inputs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

mdoijade marked this conversation as resolved.
Show resolved Hide resolved
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
This file contains a customized version of PredicatedTileIterator from CUTLASS 2.9.0
(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L75)

Changes:
- added `Layout_` template param
- Only the row index is used to load the data in load_with_byte_offset().
This way the same normalization data is used across all columns in a row.

*/

Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/util/cudart_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ inline int getMultiProcessorCount()
}

/** helper method to get major minor compute capability version */
inline std::pair<int, int> getMajorMinorVersion()
inline std::pair<int, int> getComputeCapability()
{
int devId;
RAFT_CUDA_TRY(cudaGetDevice(&devId));
Expand Down
4 changes: 4 additions & 0 deletions cpp/test/distance/dist_adj.cu
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ class DistanceAdjTest : public ::testing::TestWithParam<DistanceAdjInputs<DataTy

protected:
DistanceAdjInputs<DataType> params;
// We use uint8_t even if the output in this test is a bool because
// cutlass doesn't support bool as an output buffer yet. In CUDA,
// sizeof(bool) is 1 byte, hence it doesn't increase
// memory consumption if we use uint8_t instead of bool.
rmm::device_uvector<uint8_t> dist_ref;
mdoijade marked this conversation as resolved.
Show resolved Hide resolved
rmm::device_uvector<uint8_t> dist;
raft::handle_t handle;
Expand Down