Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 90 additions & 40 deletions cpp/include/raft/core/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,23 +122,15 @@ template <typename T>
inline constexpr bool is_mdspan_v = is_mdspan_t<T>::value;
} // namespace detail

template <typename...>
struct is_mdspan : std::true_type {
};
template <typename T1>
struct is_mdspan<T1> : detail::is_mdspan_t<T1> {
};
template <typename T1, typename... Tn>
struct is_mdspan<T1, Tn...>
: std::conditional_t<detail::is_mdspan_v<T1>, is_mdspan<Tn...>, std::false_type> {
};

/**
* @\brief Boolean to determine if variadic template types Tn are either
* raft::host_mdspan/raft::device_mdspan or their derived types
*/
template <typename... Tn>
inline constexpr bool is_mdspan_v = is_mdspan<Tn...>::value;
inline constexpr bool is_mdspan_v = std::conjunction_v<detail::is_mdspan_t<Tn>...>;

template <typename... Tn>
using enable_if_mdspan = std::enable_if_t<is_mdspan_v<Tn...>>;

/**
* @brief stdex::mdspan with device tag to avoid accessing incorrect memory location.
Expand All @@ -160,69 +152,83 @@ template <typename ElementType,
using host_mdspan =
mdspan<ElementType, Extents, LayoutPolicy, detail::host_accessor<AccessorPolicy>>;

template <typename ElementType,
typename Extents,
typename LayoutPolicy = layout_c_contiguous,
typename AccessorPolicy = detail::stdex::default_accessor<ElementType>>
using managed_mdspan =
mdspan<ElementType, Extents, LayoutPolicy, detail::managed_accessor<AccessorPolicy>>;

namespace detail {
template <typename T, bool B>
struct is_device_mdspan : std::false_type {
};
template <typename T>
struct is_device_mdspan<T, true> : std::bool_constant<not T::accessor_type::is_host_type::value> {
struct is_device_mdspan<T, true> : std::bool_constant<T::accessor_type::is_device_accessible> {
};

/**
* @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type
*/
template <typename T>
inline constexpr bool is_device_mdspan_v = is_device_mdspan<T, is_mdspan_v<T>>::value;
using is_device_mdspan_t = is_device_mdspan<T, is_mdspan_v<T>>;

template <typename T, bool B>
struct is_host_mdspan : std::false_type {
};
template <typename T>
struct is_host_mdspan<T, true> : T::accessor_type::is_host_type {
struct is_host_mdspan<T, true> : std::bool_constant<T::accessor_type::is_host_accessible> {
};

/**
* @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type
*/
template <typename T>
inline constexpr bool is_host_mdspan_v = is_host_mdspan<T, is_mdspan_v<T>>::value;
} // namespace detail
using is_host_mdspan_t = is_host_mdspan<T, is_mdspan_v<T>>;

template <typename...>
struct is_device_mdspan : std::true_type {
};
template <typename T1>
struct is_device_mdspan<T1> : detail::is_device_mdspan<T1, detail::is_mdspan_v<T1>> {
template <typename T, bool B>
struct is_managed_mdspan : std::false_type {
};
template <typename T1, typename... Tn>
struct is_device_mdspan<T1, Tn...>
: std::conditional_t<detail::is_device_mdspan_v<T1>, is_device_mdspan<Tn...>, std::false_type> {
template <typename T>
struct is_managed_mdspan<T, true> : std::bool_constant<T::accessor_type::is_managed_accessible> {
};

/**
* @\brief Boolean to determine if template type T is either raft::managed_mdspan or a derived type
*/
template <typename T>
using is_managed_mdspan_t = is_managed_mdspan<T, is_mdspan_v<T>>;
} // namespace detail

/**
* @\brief Boolean to determine if variadic template types Tn are either raft::device_mdspan or a
* derived type
*/
template <typename... Tn>
inline constexpr bool is_device_mdspan_v = is_device_mdspan<Tn...>::value;
inline constexpr bool is_device_mdspan_v = std::conjunction_v<detail::is_device_mdspan_t<Tn>...>;

template <typename...>
struct is_host_mdspan : std::true_type {
};
template <typename T1>
struct is_host_mdspan<T1> : detail::is_host_mdspan<T1, detail::is_mdspan_v<T1>> {
};
template <typename T1, typename... Tn>
struct is_host_mdspan<T1, Tn...>
: std::conditional_t<detail::is_host_mdspan_v<T1>, is_host_mdspan<Tn...>, std::false_type> {
};
template <typename... Tn>
using enable_if_device_mdspan = std::enable_if_t<is_device_mdspan_v<Tn...>>;

/**
* @\brief Boolean to determine if variadic template types Tn are either raft::host_mdspan or a
* derived type
*/
template <typename... Tn>
inline constexpr bool is_host_mdspan_v = is_host_mdspan<Tn...>::value;
inline constexpr bool is_host_mdspan_v = std::conjunction_v<detail::is_host_mdspan_t<Tn>...>;

template <typename... Tn>
using enable_if_host_mdspan = std::enable_if_t<is_host_mdspan_v<Tn...>>;

/**
* @\brief Boolean to determine if variadic template types Tn are either raft::managed_mdspan or a
* derived type
*/
template <typename... Tn>
inline constexpr bool is_managed_mdspan_v = std::conjunction_v<detail::is_managed_mdspan_t<Tn>...>;

template <typename... Tn>
using enable_if_managed_mdspan = std::enable_if_t<is_managed_mdspan_v<Tn...>>;

/**
* @brief Interface to implement an owning multi-dimensional array
Expand Down Expand Up @@ -348,7 +354,7 @@ class mdarray
typename container_policy_type::const_accessor_policy,
typename container_policy_type::accessor_policy>>
using view_type_impl =
std::conditional_t<container_policy_type::is_host_type::value,
std::conditional_t<container_policy_type::is_host_accessible,
host_mdspan<E, extents_type, layout_type, ViewAccessorPolicy>,
device_mdspan<E, extents_type, layout_type, ViewAccessorPolicy>>;

Expand Down Expand Up @@ -672,6 +678,50 @@ template <typename ElementType,
typename LayoutPolicy = layout_c_contiguous>
using device_matrix_view = device_mdspan<ElementType, matrix_extent<IndexType>, LayoutPolicy>;

/**
* @brief Create a raft::mdspan
* @tparam ElementType the data type of the matrix elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @tparam is_host_accessible whether the data is accessible on host
* @tparam is_device_accessible whether the data is accessible on device
* @param ptr Pointer to the data
* @param exts dimensionality of the array (series of integers)
* @return raft::mdspan
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
bool is_host_accessible = false,
bool is_device_accessible = true,
size_t... Extents>
auto make_mdspan(ElementType* ptr, extents<IndexType, Extents...> exts)
{
using accessor_type = detail::accessor_mixin<std::experimental::default_accessor<ElementType>,
is_host_accessible,
is_device_accessible>;

return mdspan<ElementType, decltype(exts), LayoutPolicy, accessor_type>{ptr, exts};
}

/**
* @brief Create a raft::managed_mdspan
* @tparam ElementType the data type of the matrix elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @param ptr Pointer to the data
* @param exts dimensionality of the array (series of integers)
* @return raft::managed_mdspan
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous,
size_t... Extents>
auto make_managed_mdspan(ElementType* ptr, extents<IndexType, Extents...> exts)
{
return make_mdspan<ElementType, IndexType, LayoutPolicy, true, true>(ptr, exts);
}

/**
* @brief Create a 0-dim (scalar) mdspan instance for host value.
*
Expand Down Expand Up @@ -983,7 +1033,7 @@ auto make_device_vector(raft::handle_t const& handle, IndexType n)
* @return raft::host_mdspan or raft::device_mdspan with vector_extent
* depending on AccessoryPolicy
*/
template <typename mdspan_type, std::enable_if_t<is_mdspan_v<mdspan_type>>* = nullptr>
template <typename mdspan_type, typename = enable_if_mdspan<mdspan_type>>
auto flatten(mdspan_type mds)
{
RAFT_EXPECTS(mds.is_exhaustive(), "Input must be contiguous.");
Expand Down Expand Up @@ -1024,7 +1074,7 @@ auto flatten(const array_interface_type& mda)
template <typename mdspan_type,
typename IndexType = std::uint32_t,
size_t... Extents,
std::enable_if_t<is_mdspan_v<mdspan_type>>* = nullptr>
typename = enable_if_mdspan<mdspan_type>>
auto reshape(mdspan_type mds, extents<IndexType, Extents...> new_shape)
{
RAFT_EXPECTS(mds.is_exhaustive(), "Input must be contiguous.");
Expand Down
18 changes: 13 additions & 5 deletions cpp/include/raft/detail/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,21 +231,29 @@ class host_vector_policy {
/**
* @brief A mixin to distinguish host and device memory.
*/
template <typename AccessorPolicy, bool is_host>
template <typename AccessorPolicy, bool is_host, bool is_device>
struct accessor_mixin : public AccessorPolicy {
using accessor_type = AccessorPolicy;
using is_host_type = std::conditional_t<is_host, std::true_type, std::false_type>;
using accessor_type = AccessorPolicy;
using is_host_type = std::conditional_t<is_host, std::true_type, std::false_type>;
using is_device_type = std::conditional_t<is_device, std::true_type, std::false_type>;
using is_managed_type = std::conditional_t<is_device && is_host, std::true_type, std::false_type>;
static constexpr bool is_host_accessible = is_host;
static constexpr bool is_device_accessible = is_device;
static constexpr bool is_managed_accessible = is_device && is_host;
// make sure the explicit ctor can fall through
using AccessorPolicy::AccessorPolicy;
using offset_policy = accessor_mixin;
accessor_mixin(AccessorPolicy const& that) : AccessorPolicy{that} {} // NOLINT
};

template <typename AccessorPolicy>
using host_accessor = accessor_mixin<AccessorPolicy, true>;
using host_accessor = accessor_mixin<AccessorPolicy, true, false>;

template <typename AccessorPolicy>
using device_accessor = accessor_mixin<AccessorPolicy, false>;
using device_accessor = accessor_mixin<AccessorPolicy, false, true>;

template <typename AccessorPolicy>
using managed_accessor = accessor_mixin<AccessorPolicy, true, true>;

namespace stdex = std::experimental;

Expand Down
12 changes: 12 additions & 0 deletions cpp/test/mdarray.cu
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,18 @@ void test_factory_methods()
auto view = make_host_scalar_view(h_scalar.data_handle());
ASSERT_EQ(view(0), 17.0);
}

// managed
{
raft::handle_t handle{};
auto mda = make_device_vector<int>(handle, 10);

auto mdv = make_managed_mdspan(mda.data_handle(), raft::vector_extent<int>{10});

static_assert(decltype(mdv)::accessor_type::is_managed_accessible, "Not managed mdspan");

ASSERT_EQ(mdv.size(), 10);
}
}
} // anonymous namespace

Expand Down
16 changes: 2 additions & 14 deletions cpp/test/mdspan_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ void test_template_asserts()

// Checking if types are host_mdspan
static_assert(!is_host_mdspan_v<device_matrix_view<float>>,
"device_matrix_view type not a host_mdspan");
"device_matrix_view type is a host_mdspan");
static_assert(is_host_mdspan_v<host_matrix_view<float>>,
"host_matrix_view type is a host_mdspan");
"host_matrix_view type is not a host_mdspan");

// checking variadics
static_assert(!is_mdspan_v<three_d_mdspan, std::vector<int>>, "variadics mdspans");
Expand Down Expand Up @@ -171,12 +171,6 @@ void test_reshape()
three_d_mdarray mda{layout, policy};

auto flat_view = reshape(mda, raft::extents<int, dynamic_extent>{27});
// this confirms aliasing works as intended
static_assert(std::is_same_v<decltype(flat_view),
host_vector_view<typename decltype(flat_view)::element_type,
typename decltype(flat_view)::index_type,
typename decltype(flat_view)::layout_type>>,
"types not the same");

ASSERT_EQ(flat_view.extents().rank(), 1);
ASSERT_EQ(flat_view.size(), mda.size());
Expand All @@ -195,12 +189,6 @@ void test_reshape()
four_d_mdarray mda{layout, policy};

auto matrix = reshape(mda, raft::extents<int, dynamic_extent, dynamic_extent>{4, 4});
// this confirms aliasing works as intended
static_assert(std::is_same_v<decltype(matrix),
device_matrix_view<typename decltype(matrix)::element_type,
typename decltype(matrix)::index_type,
typename decltype(matrix)::layout_type>>,
"types not the same");

ASSERT_EQ(matrix.extents().rank(), 2);
ASSERT_EQ(matrix.extent(0), 4);
Expand Down