Skip to content

Commit

Permalink
[CUDAX] Rename hierarchy_dimensions_fragment to `hierarchy_dimensio…
Browse files Browse the repository at this point in the history
…ns` and remove the old alias (#3496)

* Remove hierarchy_dimensions_fragment

* Fix format
  • Loading branch information
pciolkosz authored Jan 25, 2025
1 parent 32ef7af commit 3a15507
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,8 @@ _CCCL_NODISCARD _CUDAX_API constexpr auto __as_level(_LevelFn* __fn) noexcept ->
// Shorthand for the nested `level_type` member type of _Level.
template <class _Level>
using __level_type_of = typename _Level::level_type;

template <typename BottomUnit, typename... Levels>
struct hierarchy_dimensions_fragment;

// If lowest unit in the hierarchy is thread, it can be considered a full hierarchy and not only a fragment
template <typename... Levels>
using hierarchy_dimensions = hierarchy_dimensions_fragment<thread_level, Levels...>;
// Forward declaration of the hierarchy type; BottomUnit defaults to thread_level when it is omitted.
template <typename BottomUnit = thread_level, typename... Levels>
struct hierarchy_dimensions;

namespace detail
{
Expand All @@ -86,7 +82,7 @@ template <typename QueryLevel, typename Hierarchy>
struct has_level_helper;

// Specialization for a hierarchy type: derives from a fold-or over Levels, so it is true when
// QueryLevel is the same type as the level_type of any entry in Levels.
// The bottom Unit of the hierarchy does not take part in this check.
template <typename QueryLevel, typename Unit, typename... Levels>
struct has_level_helper<QueryLevel, hierarchy_dimensions_fragment<Unit, Levels...>>
struct has_level_helper<QueryLevel, hierarchy_dimensions<Unit, Levels...>>
: public ::cuda::std::__fold_or<::cuda::std::is_same_v<QueryLevel, __level_type_of<Levels>>...>
{};

Expand All @@ -101,7 +97,7 @@ struct has_unit
{};

// Specialization for a hierarchy type: true only when QueryLevel is the same type as the
// hierarchy's bottom Unit; the entries of Levels are not inspected.
template <typename QueryLevel, typename Unit, typename... Levels>
struct has_unit<QueryLevel, hierarchy_dimensions_fragment<Unit, Levels...>> : ::cuda::std::is_same<QueryLevel, Unit>
struct has_unit<QueryLevel, hierarchy_dimensions<Unit, Levels...>> : ::cuda::std::is_same<QueryLevel, Unit>
{};

template <typename QueryLevel>
Expand Down Expand Up @@ -153,13 +149,13 @@ _CUDAX_API constexpr auto __reverse_indices(::cuda::std::index_sequence<_Id...>)
}

template <typename LUnit, bool Reversed = false>
struct __make_hierarchy_fragment
struct __make_hierarchy
{
template <class Levels, size_t... _Ids>
_CCCL_NODISCARD _CUDAX_TRIVIAL_API static constexpr auto
__apply_reverse(const Levels& ls, ::cuda::std::index_sequence<_Ids...>) noexcept
{
return __make_hierarchy_fragment<LUnit, true>()(::cuda::std::get<_Ids>(ls)...);
return __make_hierarchy<LUnit, true>()(::cuda::std::get<_Ids>(ls)...);
}

template <typename... Levels>
Expand All @@ -171,7 +167,7 @@ struct __make_hierarchy_fragment
LUnit>;
if constexpr (__can_stack<UnitOrDefault, Levels...>)
{
return hierarchy_dimensions_fragment(UnitOrDefault{}, ls...);
return hierarchy_dimensions(UnitOrDefault{}, ls...);
}
else if constexpr (!Reversed)
{
Expand Down Expand Up @@ -357,8 +353,7 @@ struct __empty_hierarchy
* This type combines a number of level_dimensions objects to represent dimensions of a (possibly partial)
* hierarchy of CUDA threads. It supports accessing individual levels or queries combining dimensions
* of multiple levels.
* This type should not be created directly and make_hierarchy or make_hierarchy_fragment functions
* should be used instead.
* This type should not be created directly; the make_hierarchy function should be used instead.
* For every level, the unit for its dimensions is implied by the next level in the hierarchy, except
* for the last level, for which it is the BottomUnit template argument.
* In case the BottomUnit type is thread_level, the hierarchy is considered complete and there
Expand All @@ -382,38 +377,37 @@ struct __empty_hierarchy
* level_dimensions instances or types derived from it
*/
template <typename BottomUnit, typename... Levels>
struct hierarchy_dimensions_fragment
struct hierarchy_dimensions
{
static_assert(::cuda::std::is_base_of_v<hierarchy_level, BottomUnit> || ::cuda::std::is_same_v<BottomUnit, void>);
::cuda::std::tuple<Levels...> levels;

_CUDAX_API constexpr hierarchy_dimensions_fragment(const Levels&... ls) noexcept
_CUDAX_API constexpr hierarchy_dimensions(const Levels&... ls) noexcept
: levels(ls...)
{}
_CUDAX_API constexpr hierarchy_dimensions_fragment(const BottomUnit&, const Levels&... ls) noexcept
_CUDAX_API constexpr hierarchy_dimensions(const BottomUnit&, const Levels&... ls) noexcept
: levels(ls...)
{}

_CUDAX_API constexpr hierarchy_dimensions_fragment(const ::cuda::std::tuple<Levels...>& ls) noexcept
_CUDAX_API constexpr hierarchy_dimensions(const ::cuda::std::tuple<Levels...>& ls) noexcept
: levels(ls)
{}

_CUDAX_API constexpr hierarchy_dimensions_fragment(const BottomUnit&, const ::cuda::std::tuple<Levels...>& ls) noexcept
_CUDAX_API constexpr hierarchy_dimensions(const BottomUnit&, const ::cuda::std::tuple<Levels...>& ls) noexcept
: levels(ls)
{}

# if !defined(_CCCL_NO_THREE_WAY_COMPARISON) && !_CCCL_COMPILER(MSVC, <, 19, 39) && !_CCCL_COMPILER(GCC, <, 12)
_CCCL_NODISCARD _CCCL_HIDE_FROM_ABI constexpr bool
operator==(const hierarchy_dimensions_fragment&) const noexcept = default;
_CCCL_NODISCARD _CCCL_HIDE_FROM_ABI constexpr bool operator==(const hierarchy_dimensions&) const noexcept = default;
# else // ^^^ !_CCCL_NO_THREE_WAY_COMPARISON ^^^ / vvv _CCCL_NO_THREE_WAY_COMPARISON vvv
_CCCL_NODISCARD_FRIEND _CUDAX_API constexpr bool
operator==(const hierarchy_dimensions_fragment& left, const hierarchy_dimensions_fragment& right) noexcept
operator==(const hierarchy_dimensions& left, const hierarchy_dimensions& right) noexcept
{
return left.levels == right.levels;
}

_CCCL_NODISCARD_FRIEND _CUDAX_API constexpr bool
operator!=(const hierarchy_dimensions_fragment& left, const hierarchy_dimensions_fragment& right) noexcept
operator!=(const hierarchy_dimensions& left, const hierarchy_dimensions& right) noexcept
{
return left.levels != right.levels;
}
Expand All @@ -425,8 +419,8 @@ private:
_CCCL_NODISCARD _CUDAX_API static constexpr auto
levels_range_static(const ::cuda::std::tuple<Levels...>& levels) noexcept
{
static_assert(has_level<Level, hierarchy_dimensions_fragment<BottomUnit, Levels...>>);
static_assert(has_level_or_unit<Unit, hierarchy_dimensions_fragment<BottomUnit, Levels...>>);
static_assert(has_level<Level, hierarchy_dimensions<BottomUnit, Levels...>>);
static_assert(has_level_or_unit<Unit, hierarchy_dimensions<BottomUnit, Levels...>>);
static_assert(detail::legal_unit_for_level<Unit, Level>);
return ::cuda::std::apply(detail::get_levels_range<Level, Unit, Levels...>, levels);
}
Expand All @@ -444,13 +438,13 @@ private:
template <typename... Selected>
_CCCL_NODISCARD _CUDAX_API constexpr auto operator()(const Selected&... levels) const noexcept
{
return hierarchy_dimensions_fragment<Unit, Selected...>(levels...);
return hierarchy_dimensions<Unit, Selected...>(levels...);
}
};

public:
template <typename, typename...>
friend struct hierarchy_dimensions_fragment;
friend struct hierarchy_dimensions;

template <typename Unit, typename Level>
using extents_type = decltype(::cuda::std::apply(
Expand All @@ -461,7 +455,7 @@ public:
* @brief Get a fragment of this hierarchy
*
* This member function can be used to get a fragment of the hierarchy it is called on.
* It returns a hierarchy_dimensions_fragment that includes levels starting with the
* It returns a hierarchy_dimensions that includes levels starting with the
* level specified in Level and ending with a level before Unit. Together with
* hierarchy_add_level function it can be used to create a new hierarchy that is a modification
* of an existing hierarchy.
Expand Down Expand Up @@ -532,8 +526,8 @@ public:

// template <typename Unit, typename Level>
// using extents_type = ::cuda::std::invoke_result_t<
// decltype(&hierarchy_dimensions_fragment<BottomUnit, Levels...>::template extents<Unit, Level>),
// hierarchy_dimensions_fragment<BottomUnit, Levels...>,
// decltype(&hierarchy_dimensions<BottomUnit, Levels...>::template extents<Unit, Level>),
// hierarchy_dimensions<BottomUnit, Levels...>,
// Unit(),
// Level()>;

Expand Down Expand Up @@ -727,7 +721,7 @@ public:
template <typename Level>
_CUDAX_API constexpr auto level(const Level&) const noexcept
{
static_assert(has_level<Level, hierarchy_dimensions_fragment<BottomUnit, Levels...>>);
static_assert(has_level<Level, hierarchy_dimensions<BottomUnit, Levels...>>);

return ::cuda::std::apply(detail::get_level_helper<Level>{}, levels);
}
Expand All @@ -743,7 +737,7 @@ public:
//!
//! @return Hierarchy holding the combined levels from both hierarchies
template <typename OtherUnit, typename... OtherLevels>
constexpr auto combine(const hierarchy_dimensions_fragment<OtherUnit, OtherLevels...>& other) const
constexpr auto combine(const hierarchy_dimensions<OtherUnit, OtherLevels...>& other) const
{
using this_top_level = __level_type_of<::cuda::std::__type_index_c<0, Levels...>>;
using this_bottom_level = __level_type_of<::cuda::std::__type_index_c<sizeof...(Levels) - 1, Levels...>>;
Expand All @@ -754,8 +748,8 @@ public:
// Easily stackable case, example this is (grid), other is (cluster, block)
return ::cuda::std::apply(fragment_helper<OtherUnit>(), ::cuda::std::tuple_cat(levels, other.levels));
}
else if constexpr (has_level<this_bottom_level, hierarchy_dimensions_fragment<OtherUnit, OtherLevels...>>
&& (!has_level<this_top_level, hierarchy_dimensions_fragment<OtherUnit, OtherLevels...>>
else if constexpr (has_level<this_bottom_level, hierarchy_dimensions<OtherUnit, OtherLevels...>>
&& (!has_level<this_top_level, hierarchy_dimensions<OtherUnit, OtherLevels...>>
|| ::cuda::std::is_same_v<this_top_level, other_top_level>) )
{
// Overlap with this on the top, e.g. this is (grid, cluster), other is (cluster, block), can fully overlap
Expand All @@ -778,8 +772,8 @@ public:
else
{
// Overlap with this on the bottom, e.g. this is (cluster, block), other is (grid, cluster), can fully overlap
static_assert(has_level<other_bottom_level, hierarchy_dimensions_fragment<BottomUnit, Levels...>>
&& (!has_level<this_bottom_level, hierarchy_dimensions_fragment<OtherUnit, OtherLevels...>>
static_assert(has_level<other_bottom_level, hierarchy_dimensions<BottomUnit, Levels...>>
&& (!has_level<this_bottom_level, hierarchy_dimensions<OtherUnit, OtherLevels...>>
|| ::cuda::std::is_same_v<this_bottom_level, other_bottom_level>),
"Can't combine the hierarchies");

Expand All @@ -790,7 +784,7 @@ public:
}

# ifndef _CCCL_DOXYGEN_INVOKED // Do not document
constexpr hierarchy_dimensions_fragment combine([[maybe_unused]] __empty_hierarchy __empty) const
constexpr hierarchy_dimensions combine([[maybe_unused]] __empty_hierarchy __empty) const
{
return *this;
}
Expand Down Expand Up @@ -843,24 +837,13 @@ constexpr auto _CCCL_HOST get_launch_dimensions(const hierarchy_dimensions<Level
}
}

/* TODO consider having LUnit optional argument for template argument deduction
This could have been a single function with make_hierarchy and first template
argument defaulted, but then the above TODO would be impossible and the current
name makes more sense */
template <typename LUnit = void, typename L1, typename... Levels>
constexpr auto make_hierarchy_fragment(L1 l1, Levels... ls) noexcept
{
return detail::__make_hierarchy_fragment<LUnit>()(detail::__as_level(l1), detail::__as_level(ls)...);
}

// TODO consider having LUnit optional argument for template argument deduction
/**
* @brief Creates a hierarchy from passed in levels.
*
* This function takes any number of level_dimensions or derived objects
* and creates a hierarchy out of them. Levels need to be in ascending
* or descending order and the lowest level needs to be valid for thread_level unit.
* To create a hierarchy not ending with thread_level unit, use make_hierarchy_fragment
* instead.
*
* @par Snippet
* @code
Expand All @@ -874,10 +857,10 @@ constexpr auto make_hierarchy_fragment(L1 l1, Levels... ls) noexcept
* @endcode
* @par
*/
template <typename L1, typename... Levels>
template <typename LUnit = void, typename L1, typename... Levels>
constexpr auto make_hierarchy(L1 l1, Levels... ls) noexcept
{
// Every argument is passed through detail::__as_level before being handed to the detail functor,
// which builds and returns the hierarchy object.
return detail::__make_hierarchy_fragment<thread_level>()(detail::__as_level(l1), detail::__as_level(ls)...);
return detail::__make_hierarchy<LUnit>()(detail::__as_level(l1), detail::__as_level(ls)...);
}

/**
Expand All @@ -894,16 +877,16 @@ constexpr auto make_hierarchy(L1 l1, Levels... ls) noexcept
*
* using namespace cuda::experimental;
*
* auto partial1 = make_hierarchy_fragment<block_level>(grid_dims(256), cluster_dims<4>());
* auto partial1 = make_hierarchy<block_level>(grid_dims(256), cluster_dims<4>());
* auto hierarchy1 = hierarchy_add_level(partial1, block_dims<8, 8, 8>());
* auto partial2 = make_hierarchy_fragment<thread_level>(block_dims<8, 8, 8>(), cluster_dims<4>());
* auto partial2 = make_hierarchy<thread_level>(block_dims<8, 8, 8>(), cluster_dims<4>());
* auto hierarchy2 = hierarchy_add_level(partial2, grid_dims(256));
* static_assert(cuda::std::is_same_v<decltype(hierarchy1), decltype(hierarchy2)>);
* @endcode
* @par
*/
template <typename NewLevel, typename Unit, typename... Levels>
constexpr auto hierarchy_add_level(const hierarchy_dimensions_fragment<Unit, Levels...>& hierarchy, NewLevel lnew)
constexpr auto hierarchy_add_level(const hierarchy_dimensions<Unit, Levels...>& hierarchy, NewLevel lnew)
{
auto new_level = detail::__as_level(lnew);
using AddedLevel = decltype(new_level);
Expand All @@ -912,15 +895,15 @@ constexpr auto hierarchy_add_level(const hierarchy_dimensions_fragment<Unit, Lev

if constexpr (detail::can_rhs_stack_on_lhs<top_level, __level_type_of<AddedLevel>>)
{
return hierarchy_dimensions_fragment<Unit, AddedLevel, Levels...>(
return hierarchy_dimensions<Unit, AddedLevel, Levels...>(
::cuda::std::tuple_cat(::cuda::std::make_tuple(new_level), hierarchy.levels));
}
else
{
static_assert(detail::can_rhs_stack_on_lhs<__level_type_of<AddedLevel>, bottom_level>,
"Not supported order of levels in hierarchy");
using NewUnit = detail::__default_unit_below<__level_type_of<AddedLevel>>;
return hierarchy_dimensions_fragment<NewUnit, Levels..., AddedLevel>(
return hierarchy_dimensions<NewUnit, Levels..., AddedLevel>(
::cuda::std::tuple_cat(hierarchy.levels, ::cuda::std::make_tuple(new_level)));
}
}
Expand Down
10 changes: 5 additions & 5 deletions cudax/include/cuda/experimental/__launch/configuration.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ template <typename L1, typename Dims1, typename L2, typename Dims2>
_CUDAX_HOST_API constexpr auto
operator&(const level_dimensions<L1, Dims1>& l1, const level_dimensions<L2, Dims2>& l2) noexcept
{
return kernel_config(make_hierarchy_fragment(l1, l2));
return kernel_config(make_hierarchy(l1, l2));
}

template <typename _Dimensions, typename... _Options>
Expand Down Expand Up @@ -505,9 +505,9 @@ _CCCL_NODISCARD constexpr auto operator&(const hierarchy_dimensions<Levels...>&
*/
template <typename BottomUnit, typename... Levels, typename... Opts>
_CCCL_NODISCARD constexpr auto
make_config(const hierarchy_dimensions_fragment<BottomUnit, Levels...>& dims, const Opts&... opts) noexcept
make_config(const hierarchy_dimensions<BottomUnit, Levels...>& dims, const Opts&... opts) noexcept
{
// kernel_config's template arguments are spelled out explicitly: the hierarchy type first,
// followed by the types of the options; dims and opts are forwarded to its constructor unchanged.
return kernel_config<hierarchy_dimensions_fragment<BottomUnit, Levels...>, Opts...>(dims, opts...);
return kernel_config<hierarchy_dimensions<BottomUnit, Levels...>, Opts...>(dims, opts...);
}

/**
Expand Down Expand Up @@ -544,7 +544,7 @@ _CCCL_NODISCARD constexpr auto __process_config_args(const ::cuda::std::tuple<Pr
}
else
{
return kernel_config(::cuda::std::apply(make_hierarchy_fragment<void, const Prev&...>, previous));
return kernel_config(::cuda::std::apply(make_hierarchy<void, const Prev&...>, previous));
}
}

Expand All @@ -562,7 +562,7 @@ __process_config_args(const ::cuda::std::tuple<Prev...>& previous, const Arg& ar
}
else
{
return kernel_config(::cuda::std::apply(make_hierarchy_fragment<void, const Prev&...>, previous), arg, rest...);
return kernel_config(::cuda::std::apply(make_hierarchy<void, const Prev&...>, previous), arg, rest...);
}
}
else
Expand Down
16 changes: 8 additions & 8 deletions cudax/test/hierarchy/hierarchy_smoke.cu
Original file line number Diff line number Diff line change
Expand Up @@ -478,9 +478,9 @@ TEST_CASE("Examples", "[hierarchy]")
static_assert(decltype(hierarchy.level(cluster).dims)::static_extent(0) == 4);
}
{
auto partial1 = make_hierarchy_fragment<block_level>(grid_dims(256), cluster_dims<4>());
auto partial1 = make_hierarchy<block_level>(grid_dims(256), cluster_dims<4>());
[[maybe_unused]] auto hierarchy1 = hierarchy_add_level(partial1, block_dims<8, 8, 8>());
auto partial2 = make_hierarchy_fragment<thread_level>(block_dims<8, 8, 8>(), cluster_dims<4>());
auto partial2 = make_hierarchy<thread_level>(block_dims<8, 8, 8>(), cluster_dims<4>());
[[maybe_unused]] auto hierarchy2 = hierarchy_add_level(partial2, grid_dims(256));
static_assert(cuda::std::is_same_v<decltype(hierarchy1), decltype(hierarchy2)>);
}
Expand Down Expand Up @@ -533,8 +533,8 @@ TEST_CASE("hierarchy merge", "[hierarchy]")
{
SECTION("Non overlapping")
{
auto h1 = cudax::make_hierarchy_fragment<cudax::block_level>(cudax::grid_dims<2>());
auto h2 = cudax::make_hierarchy_fragment<cudax::thread_level>(cudax::block_dims<3>());
auto h1 = cudax::make_hierarchy<cudax::block_level>(cudax::grid_dims<2>());
auto h2 = cudax::make_hierarchy<cudax::thread_level>(cudax::block_dims<3>());
auto combined = h1.combine(h2);
static_assert(combined.count(cudax::thread) == 6);
static_assert(combined.count(cudax::thread, cudax::block) == 3);
Expand All @@ -549,8 +549,8 @@ TEST_CASE("hierarchy merge", "[hierarchy]")
}
SECTION("Overlapping")
{
auto h1 = cudax::make_hierarchy_fragment<cudax::block_level>(cudax::grid_dims<2>(), cudax::cluster_dims<3>());
auto h2 = cudax::make_hierarchy_fragment<cudax::thread_level>(cudax::block_dims<4>(), cudax::cluster_dims<5>());
auto h1 = cudax::make_hierarchy<cudax::block_level>(cudax::grid_dims<2>(), cudax::cluster_dims<3>());
auto h2 = cudax::make_hierarchy<cudax::thread_level>(cudax::block_dims<4>(), cudax::cluster_dims<5>());
auto combined = h1.combine(h2);
static_assert(combined.count(cudax::thread) == 24);
static_assert(combined.count(cudax::thread, cudax::block) == 4);
Expand All @@ -566,13 +566,13 @@ TEST_CASE("hierarchy merge", "[hierarchy]")
static_assert(cuda::std::is_same_v<decltype(combined), decltype(ultimate_combination)>);
static_assert(ultimate_combination.count(cudax::thread) == 24);

auto block_level_replacement = cudax::make_hierarchy_fragment<cudax::thread_level>(cudax::block_dims<6>());
auto block_level_replacement = cudax::make_hierarchy<cudax::thread_level>(cudax::block_dims<6>());
auto with_block_replaced = block_level_replacement.combine(combined);
static_assert(with_block_replaced.count(cudax::thread) == 36);
static_assert(with_block_replaced.count(cudax::thread, cudax::block) == 6);

auto grid_cluster_level_replacement =
cudax::make_hierarchy_fragment<cudax::block_level>(cudax::grid_dims<7>(), cudax::cluster_dims<8>());
cudax::make_hierarchy<cudax::block_level>(cudax::grid_dims<7>(), cudax::cluster_dims<8>());
auto with_grid_cluster_replaced = grid_cluster_level_replacement.combine(combined);
static_assert(with_grid_cluster_replaced.count(cudax::thread) == 7 * 8 * 4);
static_assert(with_grid_cluster_replaced.count(cudax::block, cudax::cluster) == 8);
Expand Down

0 comments on commit 3a15507

Please sign in to comment.