Skip to content

Commit cdae2b5

Browse files
committed
[FEA] Add support for SDDMM by wrapping the cusparseSDDMM (#2067)
- Add support for SDDMM by wrapping the `cusparseSDDMM` - This PR also moved some APIs shared with `SpMM` to the `utils.cuh` file. Authors: - James Rong (https://github.com/rhdong) Approvers: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet) URL: #2067
1 parent 6762fe5 commit cdae2b5

File tree

9 files changed

+836
-56
lines changed

9 files changed

+836
-56
lines changed

cpp/include/raft/linalg/linalg_types.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022, NVIDIA CORPORATION.
2+
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -32,4 +32,11 @@ enum class Apply { ALONG_ROWS, ALONG_COLUMNS };
3232
*/
3333
enum class FillMode { UPPER, LOWER };
3434

35+
/**
36+
* @brief Enum for this type indicates which operations is applied to the related input (e.g. sparse
37+
* matrix, or vector).
38+
*
39+
*/
40+
enum class Operation { NON_TRANSPOSE, TRANSPOSE };
41+
3542
} // end namespace raft::linalg

cpp/include/raft/sparse/detail/cusparse_wrappers.h

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
2+
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -571,6 +571,118 @@ inline cusparseStatus_t cusparsespmm(cusparseHandle_t handle,
571571
alg,
572572
static_cast<void*>(externalBuffer));
573573
}
574+
575+
template <typename T>
576+
cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle,
577+
cusparseOperation_t opA,
578+
cusparseOperation_t opB,
579+
const T* alpha,
580+
const cusparseDnMatDescr_t matA,
581+
const cusparseDnMatDescr_t matB,
582+
const T* beta,
583+
cusparseSpMatDescr_t matC,
584+
cusparseSDDMMAlg_t alg,
585+
size_t* bufferSize,
586+
cudaStream_t stream);
587+
template <>
588+
inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle,
589+
cusparseOperation_t opA,
590+
cusparseOperation_t opB,
591+
const float* alpha,
592+
const cusparseDnMatDescr_t matA,
593+
const cusparseDnMatDescr_t matB,
594+
const float* beta,
595+
cusparseSpMatDescr_t matC,
596+
cusparseSDDMMAlg_t alg,
597+
size_t* bufferSize,
598+
cudaStream_t stream)
599+
{
600+
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
601+
return cusparseSDDMM_bufferSize(
602+
handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_32F, alg, bufferSize);
603+
}
604+
template <>
605+
inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle,
606+
cusparseOperation_t opA,
607+
cusparseOperation_t opB,
608+
const double* alpha,
609+
const cusparseDnMatDescr_t matA,
610+
const cusparseDnMatDescr_t matB,
611+
const double* beta,
612+
cusparseSpMatDescr_t matC,
613+
cusparseSDDMMAlg_t alg,
614+
size_t* bufferSize,
615+
cudaStream_t stream)
616+
{
617+
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
618+
return cusparseSDDMM_bufferSize(
619+
handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_64F, alg, bufferSize);
620+
}
621+
template <typename T>
622+
inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle,
623+
cusparseOperation_t opA,
624+
cusparseOperation_t opB,
625+
const T* alpha,
626+
const cusparseDnMatDescr_t matA,
627+
const cusparseDnMatDescr_t matB,
628+
const T* beta,
629+
cusparseSpMatDescr_t matC,
630+
cusparseSDDMMAlg_t alg,
631+
T* externalBuffer,
632+
cudaStream_t stream);
633+
template <>
634+
inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle,
635+
cusparseOperation_t opA,
636+
cusparseOperation_t opB,
637+
const float* alpha,
638+
const cusparseDnMatDescr_t matA,
639+
const cusparseDnMatDescr_t matB,
640+
const float* beta,
641+
cusparseSpMatDescr_t matC,
642+
cusparseSDDMMAlg_t alg,
643+
float* externalBuffer,
644+
cudaStream_t stream)
645+
{
646+
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
647+
return cusparseSDDMM(handle,
648+
opA,
649+
opB,
650+
static_cast<void const*>(alpha),
651+
matA,
652+
matB,
653+
static_cast<void const*>(beta),
654+
matC,
655+
CUDA_R_32F,
656+
alg,
657+
static_cast<void*>(externalBuffer));
658+
}
659+
template <>
660+
inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle,
661+
cusparseOperation_t opA,
662+
cusparseOperation_t opB,
663+
const double* alpha,
664+
const cusparseDnMatDescr_t matA,
665+
const cusparseDnMatDescr_t matB,
666+
const double* beta,
667+
cusparseSpMatDescr_t matC,
668+
cusparseSDDMMAlg_t alg,
669+
double* externalBuffer,
670+
cudaStream_t stream)
671+
{
672+
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
673+
return cusparseSDDMM(handle,
674+
opA,
675+
opB,
676+
static_cast<void const*>(alpha),
677+
matA,
678+
matB,
679+
static_cast<void const*>(beta),
680+
matC,
681+
CUDA_R_64F,
682+
alg,
683+
static_cast<void*>(externalBuffer));
684+
}
685+
574686
/** @} */
575687
#else
576688
/**
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <raft/core/device_mdarray.hpp>
19+
#include <raft/core/device_mdspan.hpp>
20+
#include <raft/core/host_mdspan.hpp>
21+
#include <raft/core/resource/cuda_stream.hpp>
22+
#include <raft/core/resource/cusparse_handle.hpp>
23+
#include <raft/core/resources.hpp>
24+
#include <raft/linalg/linalg_types.hpp>
25+
#include <raft/sparse/detail/cusparse_wrappers.h>
26+
27+
namespace raft {
28+
namespace sparse {
29+
namespace linalg {
30+
namespace detail {
31+
32+
/**
33+
* @brief This function performs the multiplication of dense matrix A and dense matrix B,
34+
* followed by an element-wise multiplication with the sparsity pattern of C.
35+
* It computes the following equation: C = alpha · (op_a(A) * op_b(B) ∘ spy(C)) + beta · C
36+
* where A,B are device matrix views and C is a CSR device matrix view
37+
*
38+
* @tparam ValueType Data type of input/output matrices (float/double)
39+
* @tparam IndexType Type of C
40+
* @tparam LayoutPolicyA layout of A
41+
* @tparam LayoutPolicyB layout of B
42+
* @tparam NZType Type of C
43+
*
44+
* @param[in] handle raft resource handle
45+
* @param[in] descr_a input dense descriptor
46+
* @param[in] descr_b input dense descriptor
47+
* @param[in/out] descr_c output sparse descriptor
48+
* @param[in] op_a input Operation op(A)
49+
* @param[in] op_b input Operation op(B)
50+
* @param[in] alpha scalar pointer
51+
* @param[in] beta scalar pointer
52+
*/
53+
template <typename ValueType>
54+
void sddmm(raft::resources const& handle,
55+
cusparseDnMatDescr_t& descr_a,
56+
cusparseDnMatDescr_t& descr_b,
57+
cusparseSpMatDescr_t& descr_c,
58+
cusparseOperation_t op_a,
59+
cusparseOperation_t op_b,
60+
const ValueType* alpha,
61+
const ValueType* beta)
62+
{
63+
auto alg = CUSPARSE_SDDMM_ALG_DEFAULT;
64+
size_t bufferSize;
65+
66+
RAFT_CUSPARSE_TRY(
67+
raft::sparse::detail::cusparsesddmm_bufferSize(resource::get_cusparse_handle(handle),
68+
op_a,
69+
op_b,
70+
alpha,
71+
descr_a,
72+
descr_b,
73+
beta,
74+
descr_c,
75+
alg,
76+
&bufferSize,
77+
resource::get_cuda_stream(handle)));
78+
79+
raft::interruptible::synchronize(resource::get_cuda_stream(handle));
80+
81+
rmm::device_uvector<ValueType> tmp(bufferSize, resource::get_cuda_stream(handle));
82+
83+
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsesddmm(resource::get_cusparse_handle(handle),
84+
op_a,
85+
op_b,
86+
alpha,
87+
descr_a,
88+
descr_b,
89+
beta,
90+
descr_c,
91+
alg,
92+
tmp.data(),
93+
resource::get_cuda_stream(handle)));
94+
}
95+
96+
} // end namespace detail
97+
} // end namespace linalg
98+
} // end namespace sparse
99+
} // end namespace raft

cpp/include/raft/sparse/linalg/detail/spmm.hpp

Lines changed: 1 addition & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023, NVIDIA CORPORATION.
2+
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -48,58 +48,6 @@ bool is_row_major(raft::device_matrix_view<const ValueType, IndexType, LayoutPol
4848
return is_row_major;
4949
}
5050

51-
/**
52-
* @brief create a cuSparse dense descriptor
53-
* @tparam ValueType Data type of dense_view (float/double)
54-
* @tparam IndexType Type of dense_view
55-
* @tparam LayoutPolicy layout of dense_view
56-
* @param[in] dense_view input raft::device_matrix_view
57-
* @param[in] is_row_major data layout of raft::device_matrix_view
58-
* @returns dense matrix descriptor to be used by cuSparse API
59-
*/
60-
template <typename ValueType, typename IndexType, typename LayoutPolicy>
61-
cusparseDnMatDescr_t create_descriptor(
62-
raft::device_matrix_view<ValueType, IndexType, LayoutPolicy>& dense_view, const bool is_row_major)
63-
{
64-
auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
65-
IndexType ld = is_row_major ? dense_view.stride(0) : dense_view.stride(1);
66-
cusparseDnMatDescr_t descr;
67-
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(
68-
&descr,
69-
dense_view.extent(0),
70-
dense_view.extent(1),
71-
ld,
72-
const_cast<std::remove_const_t<ValueType>*>(dense_view.data_handle()),
73-
order));
74-
return descr;
75-
}
76-
77-
/**
78-
* @brief create a cuSparse sparse descriptor
79-
* @tparam ValueType Data type of sparse_view (float/double)
80-
* @tparam IndptrType Data type of csr_matrix_view index pointers
81-
* @tparam IndicesType Data type of csr_matrix_view indices
82-
* @tparam NZType Type of sparse_view
83-
* @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns
84-
* @returns sparse matrix descriptor to be used by cuSparse API
85-
*/
86-
template <typename ValueType, typename IndptrType, typename IndicesType, typename NZType>
87-
cusparseSpMatDescr_t create_descriptor(
88-
raft::device_csr_matrix_view<ValueType, IndptrType, IndicesType, NZType>& sparse_view)
89-
{
90-
cusparseSpMatDescr_t descr;
91-
auto csr_structure = sparse_view.structure_view();
92-
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr(
93-
&descr,
94-
static_cast<int64_t>(csr_structure.get_n_rows()),
95-
static_cast<int64_t>(csr_structure.get_n_cols()),
96-
static_cast<int64_t>(csr_structure.get_nnz()),
97-
const_cast<IndptrType*>(csr_structure.get_indptr().data()),
98-
const_cast<IndicesType*>(csr_structure.get_indices().data()),
99-
const_cast<std::remove_const_t<ValueType>*>(sparse_view.get_elements().data())));
100-
return descr;
101-
}
102-
10351
/**
10452
* @brief SPMM function designed for handling all CSR * DENSE
10553
* combinations of operand layouts for cuSparse.

0 commit comments

Comments
 (0)