Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cutlass 3xTF32,DMMA based L2/cosine distance kernels for SM 8.0 or higher #939

Merged
merged 28 commits into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
20648c5
cutlass based euclidean expanded, cosine kernels
mdoijade Oct 20, 2022
a9dabc8
add prior ampere pairwisedistmat kernel to prevent redundant kernel c…
mdoijade Oct 21, 2022
1a45bfa
add noexcept to the functor methods
mdoijade Oct 21, 2022
c6f091b
merge branch 22.12 and resolve conflicts
mdoijade Oct 21, 2022
7786fcb
fix comments, remove redundant code and fix formatting issues
mdoijade Oct 27, 2022
181fc40
add cutlass cmake support for raft with custom namespace, fix formati…
mdoijade Oct 28, 2022
3d34545
fix formatting issues
mdoijade Oct 28, 2022
02c23ed
fix the cutlass_include_dir path in cmake
mdoijade Nov 3, 2022
7933436
fix bugs in get_cutlass cmake to use cutlass provided properties corr…
mdoijade Nov 4, 2022
d4bdec5
remove the cutlass namespace setting in test cmakefiles as it is not …
mdoijade Nov 4, 2022
d26bcef
temp remove dist dependency from cutlass to check if it works in ci/cd
mdoijade Nov 4, 2022
4df4185
merge branch-22.12 latest changes
mdoijade Nov 7, 2022
451c3c0
fix get_cutlass.cmake to work with pylibraft by using NvidiaCutlass i…
mdoijade Nov 10, 2022
7b512f9
fix get_cutlass install path, make changes as per review comments
mdoijade Nov 10, 2022
a05e1e2
merge branch-22.12
mdoijade Nov 10, 2022
d32b4c0
fix clang format issues
mdoijade Nov 10, 2022
f7c440a
temp fix to check if python build works
mdoijade Nov 11, 2022
b1a1fd7
add raft-exports instead of raft-distance-exports as other raft compo…
mdoijade Nov 15, 2022
4ef44e7
make cutlass to depend only on raft_distance and add raft_distance de…
mdoijade Nov 16, 2022
186fcc7
fix cmake formatting issues
mdoijade Nov 16, 2022
8aa8909
prevent cutlass based pairwise dist kernels to be disabled on cuda 12…
mdoijade Nov 16, 2022
abfd493
Moving cutlass dependency to distance and nn to keep them separate.
cjnolet Nov 16, 2022
f1b1239
Adding CUTLASS to build docs as dependency
cjnolet Nov 16, 2022
32e6052
Updating to export to both distance and nn
cjnolet Nov 16, 2022
f6de9ee
Adding cutlass as private dependency
cjnolet Nov 16, 2022
9bf0647
Making cutlass INTERFACE in raft::nn and raft::distance
cjnolet Nov 16, 2022
8f0119a
Using proper exports per Robert Maynard's suggestion.
cjnolet Nov 16, 2022
6ad4fd1
Adding cutlass as private dependency of lib targets
cjnolet Nov 16, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ rapids_cpm_init()
include(cmake/thirdparty/get_thrust.cmake)
include(cmake/thirdparty/get_rmm.cmake)
include(cmake/thirdparty/get_faiss.cmake)
include(cmake/thirdparty/get_cutlass.cmake)

if(RAFT_ENABLE_cuco_DEPENDENCY)
include(${rapids-cmake-dir}/cpm/cuco.cmake)
Expand Down Expand Up @@ -217,6 +218,7 @@ target_link_libraries(
CUDA::cusolver${_ctk_static_suffix}
CUDA::cusparse${_ctk_static_suffix}
$<$<BOOL:${RAFT_ENABLE_thrust_DEPENDENCY}>:raft::Thrust>
nvidia::cutlass::cutlass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This dependency needs to be on the raft-distance target and not the raft target

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@robertmaynard I see that several tests and core algorithms depend on cosine/euclidean distance headers where we are using cutlass, hence I think it needs to be dependency on raft target. without it I am seeing several build failures when those sources are built.
I've modified get_cutlass.cmake from raft-distance-exports to raft-exports. Can this resolve the build issue in pylibraft?
I've submitted this change and waiting to see if CI passes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cjnolet Are you okay with cutlass being a hard requirement for raft?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer if we could make CUTLASS a dependency only of raft::distance (which pylibraft uses).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mdoijade any tests/benchmarks and downstream projects which depend on distances also specify and use the raft::distance target.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, I've tried to add dependency on raft_distance wherever required and got the build working locally, please review if it looks good now.
also I've enabled the cutlass path only till cuda 11.x.

)

target_compile_features(raft INTERFACE cxx_std_17 $<BUILD_INTERFACE:cuda_std_17>)
Expand Down Expand Up @@ -588,6 +590,9 @@ string(
[=[
if(distance IN_LIST raft_FIND_COMPONENTS)
enable_language(CUDA)
if(TARGET CUTLASS AND NOT TARGET nvidia::cutlass::cutlass)
add_library(nvidia::cutlass::cutlass ALIAS CUTLASS)
endif()
endif()

if(nn IN_LIST raft_FIND_COMPONENTS)
Expand Down
69 changes: 69 additions & 0 deletions cpp/cmake/thirdparty/get_cutlass.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#=============================================================================
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#=============================================================================

function(find_and_configure_cutlass)
set(oneValueArgs VERSION REPOSITORY PINNED_TAG)
cmake_parse_arguments(PKG "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN} )

#if(RAFT_ENABLE_DIST_DEPENDENCIES OR RAFT_COMPILE_LIBRARIES)
set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
set(CUTLASS_NAMESPACE "raft_cutlass" CACHE STRING "Top level namespace of CUTLASS")
set(CUTLASS_ENABLE_CUBLAS OFF CACHE BOOL "Disable CUTLASS to build with cuBLAS library.")

rapids_cpm_find(NvidiaCutlass ${PKG_VERSION}
GLOBAL_TARGETS nvidia::cutlass::cutlass
CPM_ARGS
GIT_REPOSITORY ${PKG_REPOSITORY}
GIT_TAG ${PKG_PINNED_TAG}
GIT_SHALLOW TRUE
OPTIONS
"CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}"
)

if(TARGET CUTLASS AND NOT TARGET nvidia::cutlass::cutlass)
add_library(nvidia::cutlass::cutlass ALIAS CUTLASS)
endif()

if(NvidiaCutlass_ADDED)
rapids_export(BUILD NvidiaCutlass
EXPORT_SET NvidiaCutlass
GLOBAL_TARGETS nvidia::cutlass::cutlass
NAMESPACE nvidia::cutlass::)
endif()
#endif()

# We generate the cutlass-config files when we built cutlass locally, so always do `find_dependency`
rapids_export_package(BUILD NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass)
rapids_export_package(INSTALL NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass)

# Tell cmake where it can find the generated NvidiaCutlass-config.cmake we wrote.
include("${rapids-cmake-dir}/export/find_package_root.cmake")
rapids_export_find_package_root(INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-distance-exports)
rapids_export_find_package_root(BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports)
endfunction()

if(NOT RAFT_CUTLASS_GIT_TAG)
set(RAFT_CUTLASS_GIT_TAG v2.9.1)
endif()

if(NOT RAFT_CUTLASS_GIT_REPOSITORY)
set(RAFT_CUTLASS_GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git)
endif()

find_and_configure_cutlass(VERSION 2.9.1
REPOSITORY ${RAFT_CUTLASS_GIT_REPOSITORY}
PINNED_TAG ${RAFT_CUTLASS_GIT_TAG})
135 changes: 77 additions & 58 deletions cpp/include/raft/distance/detail/cosine.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,23 @@
#pragma once

#include <raft/distance/detail/pairwise_distance_base.cuh>
#include <raft/distance/detail/pairwise_distance_cutlass_base.cuh>
#include <raft/linalg/norm.cuh>

namespace raft {
namespace distance {
namespace detail {

template <typename DataT, typename AccT>
struct CosineOp {
__device__ CosineOp() noexcept {}
__device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept
{
return static_cast<AccT>(1.0) - (AccT)(accVal / (aNorm * bNorm));
}
__device__ AccT operator()(DataT aData) const noexcept { return aData; }
};

/**
* @brief the cosine distance matrix calculation implementer
* It computes the following equation:
Expand Down Expand Up @@ -71,61 +82,71 @@ void cosineImpl(const DataT* x,
FinalLambda fin_op,
cudaStream_t stream)
{
typedef typename raft::linalg::Policy4x4<DataT, VecLen>::Policy RowPolicy;
typedef typename raft::linalg::Policy4x4<DataT, VecLen>::ColPolicy ColPolicy;
const auto deviceVersion = getComputeCapability();
if (deviceVersion.first >= 8) {
using CosineOp_ = CosineOp<DataT, AccT>;
CosineOp_ cosine_dist_op;

cutlassDistanceKernel<DataT, AccT, OutT, IdxT, VecLen, FinalLambda, CosineOp_, isRowMajor>(
x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, cosine_dist_op, stream);

} else {
typedef typename raft::linalg::Policy4x4<DataT, VecLen>::Policy RowPolicy;
typedef typename raft::linalg::Policy4x4<DataT, VecLen>::ColPolicy ColPolicy;

typedef typename std::conditional<isRowMajor, RowPolicy, ColPolicy>::type KPolicy;
typedef typename std::conditional<isRowMajor, RowPolicy, ColPolicy>::type KPolicy;

dim3 blk(KPolicy::Nthreads);
dim3 blk(KPolicy::Nthreads);

// Accumulation operation lambda
auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; };
// Accumulation operation lambda
auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; };

// epilogue operation lambda for final value calculation
auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh],
DataT * regxn,
DataT * regyn,
IdxT gridStrideX,
IdxT gridStrideY) {
// epilogue operation lambda for final value calculation
auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh],
DataT * regxn,
DataT * regyn,
IdxT gridStrideX,
IdxT gridStrideY) {
#pragma unroll
for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) {
for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < KPolicy::AccColsPerTh; ++j) {
acc[i][j] = acc[i][j] / (regxn[i] * regyn[j]);
for (int j = 0; j < KPolicy::AccColsPerTh; ++j) {
acc[i][j] = 1.0 - (acc[i][j] / (regxn[i] * regyn[j]));
}
}
}
};
};

constexpr size_t shmemSize =
KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT));
if (isRowMajor) {
auto cosineRowMajor = pairwiseDistanceMatKernel<true,
DataT,
AccT,
OutT,
IdxT,
KPolicy,
decltype(core_lambda),
decltype(epilog_lambda),
FinalLambda,
true>;
dim3 grid = launchConfigGenerator<KPolicy>(m, n, shmemSize, cosineRowMajor);
cosineRowMajor<<<grid, blk, shmemSize, stream>>>(
x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op);
} else {
auto cosineColMajor = pairwiseDistanceMatKernel<true,
DataT,
AccT,
OutT,
IdxT,
KPolicy,
decltype(core_lambda),
decltype(epilog_lambda),
FinalLambda,
false>;
dim3 grid = launchConfigGenerator<KPolicy>(m, n, shmemSize, cosineColMajor);
cosineColMajor<<<grid, blk, shmemSize, stream>>>(
x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op);
constexpr size_t shmemSize =
KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT));
if (isRowMajor) {
auto cosineRowMajor = pairwiseDistanceMatKernelPriorToAmpere<true,
DataT,
AccT,
OutT,
IdxT,
KPolicy,
decltype(core_lambda),
decltype(epilog_lambda),
FinalLambda,
true>;
dim3 grid = launchConfigGenerator<KPolicy>(m, n, shmemSize, cosineRowMajor);
cosineRowMajor<<<grid, blk, shmemSize, stream>>>(
x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op);
} else {
auto cosineColMajor = pairwiseDistanceMatKernelPriorToAmpere<true,
DataT,
AccT,
OutT,
IdxT,
KPolicy,
decltype(core_lambda),
decltype(epilog_lambda),
FinalLambda,
false>;
dim3 grid = launchConfigGenerator<KPolicy>(m, n, shmemSize, cosineColMajor);
cosineColMajor<<<grid, blk, shmemSize, stream>>>(
x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op);
}
}

RAFT_CUDA_TRY(cudaGetLastError());
Expand Down Expand Up @@ -207,13 +228,11 @@ void cosineAlgo1(Index_ m,
{
auto norm_op = [] __device__(AccType in) { return raft::mySqrt(in); };

// Wrap fin_op to allow computing 1 - pA before calling fin_op
auto wrapped_fin_op = [fin_op] __device__(AccType d_val, Index_ g_d_idx) {
return fin_op(static_cast<AccType>(1.0) - d_val, g_d_idx);
};

typedef std::is_same<OutType, bool> is_bool;
typedef typename std::conditional<is_bool::value, OutType, AccType>::type CosOutType;
// raft distance support inputs as float/double and output as uint8_t/float/double.
static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))),
"OutType can be uint8_t, float, double,"
"if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType).");
typedef typename std::conditional<sizeof(OutType) == 1, OutType, AccType>::type CosOutType;
mdoijade marked this conversation as resolved.
Show resolved Hide resolved
CosOutType* pDcast = reinterpret_cast<CosOutType*>(pD);

ASSERT(
Expand All @@ -234,12 +253,12 @@ void cosineAlgo1(Index_ m,

if (isRowMajor) {
lda = k, ldb = k, ldd = n;
cosine<InType, AccType, CosOutType, Index_, decltype(wrapped_fin_op), true>(
m, n, k, lda, ldb, ldd, pA, pB, col_vec, row_vec, pDcast, wrapped_fin_op, stream);
cosine<InType, AccType, CosOutType, Index_, FinalLambda, true>(
m, n, k, lda, ldb, ldd, pA, pB, col_vec, row_vec, pDcast, fin_op, stream);
} else {
lda = n, ldb = m, ldd = m;
cosine<InType, AccType, CosOutType, Index_, decltype(wrapped_fin_op), false>(
n, m, k, lda, ldb, ldd, pB, pA, row_vec, col_vec, pDcast, wrapped_fin_op, stream);
cosine<InType, AccType, CosOutType, Index_, FinalLambda, false>(
n, m, k, lda, ldb, ldd, pB, pA, row_vec, col_vec, pDcast, fin_op, stream);
}
}

Expand Down
25 changes: 22 additions & 3 deletions cpp/include/raft/distance/detail/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,19 @@ void distance(const InType* x,
* @note if workspace is passed as nullptr, this will return in
* worksize, the number of bytes of workspace required
*/

// Default final op functor which facilitates elementwise operation on
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, then even the previous overload of distance requires its FinalLambda to be such a functor. Could you amend the docstring at line 568 accordingly?

// final distance value if any.
template <typename AccType, typename OutType, typename Index>
struct default_fin_op {
__host__ __device__ default_fin_op() noexcept {};
// functor signature.
__host__ __device__ OutType operator()(AccType d_val, Index g_d_idx) const noexcept
{
return d_val;
}
};

template <raft::distance::DistanceType distanceType,
typename InType,
typename AccType,
Expand All @@ -632,9 +645,15 @@ void distance(const InType* x,
bool isRowMajor = true,
InType metric_arg = 2.0f)
{
auto default_fin_op = [] __device__(AccType d_val, Index_ g_d_idx) { return d_val; };
distance<distanceType, InType, AccType, OutType, decltype(default_fin_op), Index_>(
x, y, dist, m, n, k, workspace, worksize, default_fin_op, stream, isRowMajor, metric_arg);
using final_op_type = default_fin_op<AccType, OutType, Index_>;
final_op_type fin_op;

// raft distance support inputs as float/double and output as uint8_t/float/double.
static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))),
"OutType can be uint8_t, float, double,"
"if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType).");
distance<distanceType, InType, AccType, OutType, final_op_type, Index_>(
x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, metric_arg);
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

Expand Down
Loading