Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Support for half-float mixed precision in brute-force #2382

Merged
merged 20 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 45 additions & 31 deletions cpp/include/raft/core/detail/copy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <raft/core/resource/cublas_handle.hpp>
#include <raft/linalg/detail/cublas_wrappers.hpp>
#ifdef __CUDACC__
#include <raft/linalg/transpose.cuh>
#include <raft/util/cuda_dev_essentials.cuh>
#endif
#endif
Expand Down Expand Up @@ -449,38 +450,51 @@ mdspan_copyable_t<DstType, SrcType> copy(resources const& res, DstType&& dst, Sr
#endif
} else if constexpr (config::can_use_cublas) {
#ifndef RAFT_DISABLE_CUDA
auto constexpr const alpha = typename std::remove_reference_t<DstType>::value_type{1};
auto constexpr const beta = typename std::remove_reference_t<DstType>::value_type{0};
if constexpr (std::is_same_v<typename config::dst_layout_type, layout_c_contiguous>) {
CUBLAS_TRY(linalg::detail::cublasgeam(resource::get_cublas_handle(res),
CUBLAS_OP_T,
CUBLAS_OP_N,
dst.extent(1),
dst.extent(0),
&alpha,
src.data_handle(),
src.extent(0),
&beta,
dst.data_handle(),
dst.extent(1),
dst.data_handle(),
dst.extent(1),
resource::get_cuda_stream(res)));
if constexpr (!((std::is_same_v<typename std::remove_reference_t<DstType>::value_type, half>)&&(
std::is_same_v<typename std::remove_reference_t<SrcType>::value_type, half>))) {
auto constexpr const alpha = typename std::remove_reference_t<DstType>::value_type{1};
auto constexpr const beta = typename std::remove_reference_t<DstType>::value_type{0};
if constexpr (std::is_same_v<typename config::dst_layout_type, layout_c_contiguous>) {
CUBLAS_TRY(linalg::detail::cublasgeam(resource::get_cublas_handle(res),
CUBLAS_OP_T,
CUBLAS_OP_N,
dst.extent(1),
dst.extent(0),
&alpha,
src.data_handle(),
src.extent(0),
&beta,
dst.data_handle(),
dst.extent(1),
dst.data_handle(),
dst.extent(1),
resource::get_cuda_stream(res)));
} else {
CUBLAS_TRY(linalg::detail::cublasgeam(resource::get_cublas_handle(res),
CUBLAS_OP_T,
CUBLAS_OP_N,
dst.extent(0),
dst.extent(1),
&alpha,
src.data_handle(),
src.extent(1),
&beta,
dst.data_handle(),
dst.extent(0),
dst.data_handle(),
dst.extent(0),
resource::get_cuda_stream(res)));
}
} else {
CUBLAS_TRY(linalg::detail::cublasgeam(resource::get_cublas_handle(res),
CUBLAS_OP_T,
CUBLAS_OP_N,
dst.extent(0),
dst.extent(1),
&alpha,
src.data_handle(),
src.extent(1),
&beta,
dst.data_handle(),
dst.extent(0),
dst.data_handle(),
dst.extent(0),
resource::get_cuda_stream(res)));
#ifdef __CUDACC__
raft::linalg::transpose(res, dst, src);
#else
// Should never actually reach this because of enable_ifs. Included for
// safety.
RAFT_FAIL(
"raft::copy called in a way that requires custom kernel. Please use "
"raft/core/copy.cuh and include the header in a .cu file");
#endif
}
#else
// Not possible to reach this due to enable_ifs. Included for safety.
Expand Down
14 changes: 13 additions & 1 deletion cpp/include/raft/core/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,13 @@ template <typename T>
RAFT_INLINE_FUNCTION auto asin(T x)
{
#ifdef __CUDA_ARCH__
return ::asin(x);
if constexpr (std::is_same<T, __half>::value) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need half support for the asin function?

I'm wondering if we should either remove the half support for this function, or add half support for all the other trigonometric functions in this file

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, maybe we can. As of now, it's kind of a trade-off. I understand your point: some distance algorithms need this in order to support half, so removing it would cause a compilation error. Adding half support to everything would be ideal, but the workload could get out of control.

float x_float = __half2float(x);
float result_float = ::asin(x_float);
return __float2half(result_float);
} else {
return ::asin(x);
}
#else
return std::asin(x);
#endif
Expand Down Expand Up @@ -333,6 +339,12 @@ RAFT_INLINE_FUNCTION auto max(const T1& x, const T2& y)
((std::is_same_v<T1, float> || std::is_same_v<T1, double>)&&(
std::is_same_v<T2, float> || std::is_same_v<T2, double>))) {
return ::max(x, y);
} else if constexpr (std::is_same_v<T1, float> && std::is_same_v<T2, __half>) {
const float f_y = __half2float(y);
return (x < f_y) ? f_y : x;
} else if constexpr (std::is_same_v<T1, __half> && std::is_same_v<T2, float>) {
const float f_x = __half2float(x);
return (f_x < y) ? y : f_x;
}
// Else, check that the types are the same and provide a generic implementation
else {
Expand Down
18 changes: 17 additions & 1 deletion cpp/include/raft/core/operators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <raft/core/detail/macros.hpp>
#include <raft/core/math.hpp>

#include <cuda_fp16.h>

#include <algorithm>
#include <cmath>
#include <tuple>
Expand Down Expand Up @@ -104,13 +106,27 @@ struct sq_op {
{
return in * in;
}

/**
 * @brief Square a `half` input.
 *
 * The input is converted with `__half2float` and the converted value is
 * multiplied by itself. Any additional arguments are accepted and ignored.
 */
template <typename... UnusedArgs>
constexpr RAFT_INLINE_FUNCTION auto operator()(const half& in, UnusedArgs...) const
{
  const auto as_float = __half2float(in);
  return as_float * as_float;
}
};

/**
 * @brief Binary addition functor.
 *
 * Returns `a + b`. When an operand is of type `half`, that operand is first
 * converted with `__half2float`, so the addition is performed on the
 * converted value instead of on the `half` itself.
 */
struct add_op {
  template <typename T1, typename T2>
  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
  {
    // Exactly one branch survives instantiation, so the deduced return type
    // is taken from that branch alone. No unconditional return may precede
    // this chain: it would make the half-aware branches unreachable and give
    // `auto` two conflicting return types for half operands.
    if constexpr (std::is_same_v<T1, half> && std::is_same_v<T2, half>) {
      return __half2float(a) + __half2float(b);
    } else if constexpr (std::is_same_v<T1, half>) {
      return __half2float(a) + b;
    } else if constexpr (std::is_same_v<T2, half>) {
      return a + __half2float(b);
    } else {
      return a + b;
    }
  }
};

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/distance/detail/masked_distance_base.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -266,7 +266,7 @@ struct MaskedDistances : public BaseClass {
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
acc[i][j] = BaseClass::Zero;
acc[i][j] = BaseClass::Zero();
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/distance/detail/pairwise_distance_base.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -200,7 +200,7 @@ struct PairwiseDistances : public BaseClass {
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
acc[i][j] = BaseClass::Zero;
acc[i][j] = BaseClass::Zero();
}
}
}
Expand Down
14 changes: 13 additions & 1 deletion cpp/include/raft/linalg/contractions.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -164,6 +164,12 @@ struct Policy4x4<double, _veclen> {
typedef KernelPolicy<double, _veclen, 16, 4, 4, 16, 16> Policy;
typedef ColKernelPolicy<double, _veclen, 16, 4, 4, 16, 16> ColPolicy;
};

/**
 * Policy4x4 specialization for `half` data: exposes the row-major kernel
 * policy as `Policy` and the column-major kernel policy as `ColPolicy`.
 */
template <int _veclen>
struct Policy4x4<half, _veclen> {
  using Policy    = KernelPolicy<half, _veclen, 64, 4, 4, 16, 16>;
  using ColPolicy = ColKernelPolicy<half, _veclen, 64, 4, 4, 16, 16>;
};
/** @} */

/**
Expand Down Expand Up @@ -204,6 +210,12 @@ struct Policy2x8<double, _veclen> {
// This is not used; it is here just to keep the compiler happy.
typedef KernelPolicy<double, _veclen, 32, 1, 2, 8, 32> Policy;
};

/** Policy2x8 specialization for `half` data; exposes the kernel policy as `Policy`. */
template <int _veclen>
struct Policy2x8<half, _veclen> {
  using Policy = KernelPolicy<half, _veclen, 16, 2, 8, 8, 32>;
};

/** @} */

/**
Expand Down
14 changes: 8 additions & 6 deletions cpp/include/raft/linalg/detail/contractions.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -72,7 +72,9 @@ struct Contractions_NT {
/** block of Y data loaded from global mem after `ldgXY()` */
DataT ldgDataY[P::LdgPerThY][P::Veclen];

static constexpr DataT Zero = (DataT)0;
// static constexpr DataT Zero = DataT{0};

/** Returns a `DataT` brace-initialized from 0, built on each call. */
static constexpr DataT Zero()
{
  return DataT{0};
}

public:
/**
Expand Down Expand Up @@ -197,7 +199,7 @@ struct Contractions_NT {
} else {
#pragma unroll
for (int j = 0; j < P::Veclen; ++j) {
ldgDataX[i][j] = Zero;
ldgDataX[i][j] = Zero();
}
}
}
Expand All @@ -211,7 +213,7 @@ struct Contractions_NT {
} else {
#pragma unroll
for (int j = 0; j < P::Veclen; ++j) {
ldgDataX[i][j] = Zero;
ldgDataX[i][j] = Zero();
}
}
}
Expand All @@ -235,7 +237,7 @@ struct Contractions_NT {
} else {
#pragma unroll
for (int j = 0; j < P::Veclen; ++j) {
ldgDataY[i][j] = Zero;
ldgDataY[i][j] = Zero();
}
}
}
Expand All @@ -249,7 +251,7 @@ struct Contractions_NT {
} else {
#pragma unroll
for (int j = 0; j < P::Veclen; ++j) {
ldgDataY[i][j] = Zero;
ldgDataY[i][j] = Zero();
}
}
}
Expand Down
Loading
Loading