Skip to content

Commit 182c3fd

Browse files
committed
[FEA] Add support for SDDMM by wrapping the cusparseSDDMM (#2067)
- Add support for SDDMM by wrapping the `cusparseSDDMM` - This PR also moved some APIs shared with `SpMM` to the `utils.cuh` file. Authors: - James Rong (https://github.com/rhdong) Approvers: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet) URL: #2067
1 parent 1beb556 commit 182c3fd

File tree

8 files changed

+629
-72
lines changed

8 files changed

+629
-72
lines changed

cpp/include/raft/sparse/detail/cusparse_wrappers.h

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,118 @@ inline cusparseStatus_t cusparsespmm(cusparseHandle_t handle,
571571
alg,
572572
static_cast<void*>(externalBuffer));
573573
}
574+
575+
template <typename T>
576+
cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle,
577+
cusparseOperation_t opA,
578+
cusparseOperation_t opB,
579+
const T* alpha,
580+
const cusparseDnMatDescr_t matA,
581+
const cusparseDnMatDescr_t matB,
582+
const T* beta,
583+
cusparseSpMatDescr_t matC,
584+
cusparseSDDMMAlg_t alg,
585+
size_t* bufferSize,
586+
cudaStream_t stream);
587+
template <>
588+
inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle,
589+
cusparseOperation_t opA,
590+
cusparseOperation_t opB,
591+
const float* alpha,
592+
const cusparseDnMatDescr_t matA,
593+
const cusparseDnMatDescr_t matB,
594+
const float* beta,
595+
cusparseSpMatDescr_t matC,
596+
cusparseSDDMMAlg_t alg,
597+
size_t* bufferSize,
598+
cudaStream_t stream)
599+
{
600+
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
601+
return cusparseSDDMM_bufferSize(
602+
handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_32F, alg, bufferSize);
603+
}
604+
template <>
605+
inline cusparseStatus_t cusparsesddmm_bufferSize(cusparseHandle_t handle,
606+
cusparseOperation_t opA,
607+
cusparseOperation_t opB,
608+
const double* alpha,
609+
const cusparseDnMatDescr_t matA,
610+
const cusparseDnMatDescr_t matB,
611+
const double* beta,
612+
cusparseSpMatDescr_t matC,
613+
cusparseSDDMMAlg_t alg,
614+
size_t* bufferSize,
615+
cudaStream_t stream)
616+
{
617+
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
618+
return cusparseSDDMM_bufferSize(
619+
handle, opA, opB, alpha, matA, matB, beta, matC, CUDA_R_64F, alg, bufferSize);
620+
}
621+
template <typename T>
622+
inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle,
623+
cusparseOperation_t opA,
624+
cusparseOperation_t opB,
625+
const T* alpha,
626+
const cusparseDnMatDescr_t matA,
627+
const cusparseDnMatDescr_t matB,
628+
const T* beta,
629+
cusparseSpMatDescr_t matC,
630+
cusparseSDDMMAlg_t alg,
631+
T* externalBuffer,
632+
cudaStream_t stream);
633+
template <>
634+
inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle,
635+
cusparseOperation_t opA,
636+
cusparseOperation_t opB,
637+
const float* alpha,
638+
const cusparseDnMatDescr_t matA,
639+
const cusparseDnMatDescr_t matB,
640+
const float* beta,
641+
cusparseSpMatDescr_t matC,
642+
cusparseSDDMMAlg_t alg,
643+
float* externalBuffer,
644+
cudaStream_t stream)
645+
{
646+
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
647+
return cusparseSDDMM(handle,
648+
opA,
649+
opB,
650+
static_cast<void const*>(alpha),
651+
matA,
652+
matB,
653+
static_cast<void const*>(beta),
654+
matC,
655+
CUDA_R_32F,
656+
alg,
657+
static_cast<void*>(externalBuffer));
658+
}
659+
template <>
660+
inline cusparseStatus_t cusparsesddmm(cusparseHandle_t handle,
661+
cusparseOperation_t opA,
662+
cusparseOperation_t opB,
663+
const double* alpha,
664+
const cusparseDnMatDescr_t matA,
665+
const cusparseDnMatDescr_t matB,
666+
const double* beta,
667+
cusparseSpMatDescr_t matC,
668+
cusparseSDDMMAlg_t alg,
669+
double* externalBuffer,
670+
cudaStream_t stream)
671+
{
672+
CUSPARSE_CHECK(cusparseSetStream(handle, stream));
673+
return cusparseSDDMM(handle,
674+
opA,
675+
opB,
676+
static_cast<void const*>(alpha),
677+
matA,
678+
matB,
679+
static_cast<void const*>(beta),
680+
matC,
681+
CUDA_R_64F,
682+
alg,
683+
static_cast<void*>(externalBuffer));
684+
}
685+
574686
/** @} */
575687
#else
576688
/**
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright (c) 2023, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <raft/core/device_mdarray.hpp>
19+
#include <raft/core/device_mdspan.hpp>
20+
#include <raft/core/host_mdspan.hpp>
21+
#include <raft/core/resource/cuda_stream.hpp>
22+
#include <raft/core/resource/cusparse_handle.hpp>
23+
#include <raft/core/resources.hpp>
24+
#include <raft/sparse/detail/cusparse_wrappers.h>
25+
26+
namespace raft {
27+
namespace sparse {
28+
namespace linalg {
29+
namespace detail {
30+
31+
/**
32+
* @brief This function performs the multiplication of dense matrix A and dense matrix B,
33+
* followed by an element-wise multiplication with the sparsity pattern of C.
34+
* It computes the following equation: C = alpha · (A * B ∘ spy(C)) + beta · C
35+
* where A,B are device matrix views and C is a CSR device matrix view
36+
*
37+
* @tparam ValueType Data type of input/output matrices (float/double)
38+
* @tparam IndexType Type of C
39+
* @tparam LayoutPolicyA layout of A
40+
* @tparam LayoutPolicyB layout of B
41+
* @tparam NZType Type of C
42+
*
43+
* @param[in] handle raft resource handle
44+
* @param[in] trans_a transpose operation for A
45+
* @param[in] trans_b transpose operation for B
46+
* @param[in] is_row_major data layout of A,B
47+
* @param[in] alpha scalar pointer
48+
* @param[in] descr_a input dense descriptor
49+
* @param[in] descr_b input dense descriptor
50+
* @param[in] beta scalar pointer
51+
* @param[out] descr_c output sparse descriptor
52+
*/
53+
template <typename ValueType>
54+
void sddmm(raft::resources const& handle,
55+
const bool trans_a,
56+
const bool trans_b,
57+
const bool is_row_major,
58+
const ValueType* alpha,
59+
cusparseDnMatDescr_t& descr_a,
60+
cusparseDnMatDescr_t& descr_b,
61+
const ValueType* beta,
62+
cusparseSpMatDescr_t& descr_c)
63+
{
64+
auto opA = trans_a ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE;
65+
auto opB = trans_b ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE;
66+
auto alg = CUSPARSE_SDDMM_ALG_DEFAULT;
67+
size_t bufferSize;
68+
RAFT_CUSPARSE_TRY(
69+
raft::sparse::detail::cusparsesddmm_bufferSize(resource::get_cusparse_handle(handle),
70+
opA,
71+
opB,
72+
alpha,
73+
descr_a,
74+
descr_b,
75+
beta,
76+
descr_c,
77+
alg,
78+
&bufferSize,
79+
resource::get_cuda_stream(handle)));
80+
81+
raft::interruptible::synchronize(resource::get_cuda_stream(handle));
82+
83+
rmm::device_uvector<ValueType> tmp(bufferSize, resource::get_cuda_stream(handle));
84+
85+
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsesddmm(resource::get_cusparse_handle(handle),
86+
opA,
87+
opB,
88+
alpha,
89+
descr_a,
90+
descr_b,
91+
beta,
92+
descr_c,
93+
alg,
94+
tmp.data(),
95+
resource::get_cuda_stream(handle)));
96+
}
97+
98+
} // end namespace detail
99+
} // end namespace linalg
100+
} // end namespace sparse
101+
} // end namespace raft

cpp/include/raft/sparse/linalg/detail/spmm.hpp

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -28,78 +28,6 @@ namespace sparse {
2828
namespace linalg {
2929
namespace detail {
3030

31-
/**
32-
* @brief determine common data layout for both dense matrices
33-
* @tparam ValueType Data type of Y,Z (float/double)
34-
* @tparam IndexType Type of Y,Z
35-
* @tparam LayoutPolicyY layout of Y
36-
* @tparam LayoutPolicyZ layout of Z
37-
* @param[in] x input raft::device_matrix_view
38-
* @param[in] y input raft::device_matrix_view
39-
* @returns dense matrix descriptor to be used by cuSparse API
40-
*/
41-
template <typename ValueType, typename IndexType, typename LayoutPolicyY, typename LayoutPolicyZ>
42-
bool is_row_major(raft::device_matrix_view<const ValueType, IndexType, LayoutPolicyY>& y,
43-
raft::device_matrix_view<ValueType, IndexType, LayoutPolicyZ>& z)
44-
{
45-
bool is_row_major = z.stride(1) == 1 && y.stride(1) == 1;
46-
bool is_col_major = z.stride(0) == 1 && y.stride(0) == 1;
47-
ASSERT(is_row_major || is_col_major, "Both matrices need to be either row or col major");
48-
return is_row_major;
49-
}
50-
51-
/**
52-
* @brief create a cuSparse dense descriptor
53-
* @tparam ValueType Data type of dense_view (float/double)
54-
* @tparam IndexType Type of dense_view
55-
* @tparam LayoutPolicy layout of dense_view
56-
* @param[in] dense_view input raft::device_matrix_view
57-
* @param[in] is_row_major data layout of raft::device_matrix_view
58-
* @returns dense matrix descriptor to be used by cuSparse API
59-
*/
60-
template <typename ValueType, typename IndexType, typename LayoutPolicy>
61-
cusparseDnMatDescr_t create_descriptor(
62-
raft::device_matrix_view<ValueType, IndexType, LayoutPolicy>& dense_view, const bool is_row_major)
63-
{
64-
auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL;
65-
IndexType ld = is_row_major ? dense_view.stride(0) : dense_view.stride(1);
66-
cusparseDnMatDescr_t descr;
67-
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat(
68-
&descr,
69-
dense_view.extent(0),
70-
dense_view.extent(1),
71-
ld,
72-
const_cast<std::remove_const_t<ValueType>*>(dense_view.data_handle()),
73-
order));
74-
return descr;
75-
}
76-
77-
/**
78-
* @brief create a cuSparse sparse descriptor
79-
* @tparam ValueType Data type of sparse_view (float/double)
80-
* @tparam IndptrType Data type of csr_matrix_view index pointers
81-
* @tparam IndicesType Data type of csr_matrix_view indices
82-
* @tparam NZType Type of sparse_view
83-
* @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns
84-
* @returns sparse matrix descriptor to be used by cuSparse API
85-
*/
86-
template <typename ValueType, typename IndptrType, typename IndicesType, typename NZType>
87-
cusparseSpMatDescr_t create_descriptor(
88-
raft::device_csr_matrix_view<ValueType, IndptrType, IndicesType, NZType>& sparse_view)
89-
{
90-
cusparseSpMatDescr_t descr;
91-
auto csr_structure = sparse_view.structure_view();
92-
RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr(
93-
&descr,
94-
static_cast<int64_t>(csr_structure.get_n_rows()),
95-
static_cast<int64_t>(csr_structure.get_n_cols()),
96-
static_cast<int64_t>(csr_structure.get_nnz()),
97-
const_cast<IndptrType*>(csr_structure.get_indptr().data()),
98-
const_cast<IndicesType*>(csr_structure.get_indices().data()),
99-
const_cast<std::remove_const_t<ValueType>*>(sparse_view.get_elements().data())));
100-
return descr;
101-
}
102-
10331
/**
10432
* @brief SPMM function designed for handling all CSR * DENSE
10533
* combinations of operand layouts for cuSparse.

0 commit comments

Comments
 (0)