32template <
typename matAcc_t_,
mem_layout mem_layout_,
33 uint32_t num_cooperative_wg,
gpu_arch arch_tag_,
class enable =
void>
37template <
typename matAcc_t_, u
int32_t num_cooperative_wg, gpu_arch arch_tag_>
39 num_cooperative_wg, arch_tag_,
40 std::enable_if_t<gpu_arch::Xe == arch_tag_>> {
42 static constexpr gpu_arch arch_tag = arch_tag_;
44 using dtype =
typename matAcc_t::dtype;
50 static_assert((num_cooperative_wg & (num_cooperative_wg - 1)) == 0,
51 "num_cooperative_wg should be power of 2");
61 static constexpr uint32_t coop_num_y
63 src_tile_size_y>::value;
64 static constexpr uint32_t coop_remain_num_x
65 = num_cooperative_wg / coop_num_y;
66 static constexpr bool has_redundant_wg
67 = (coop_remain_num_x * 16) > src_tile_size_x;
68 static constexpr uint32_t tile_size_y = src_tile_size_y / coop_num_y;
69 static constexpr uint32_t tile_size_x
70 = has_redundant_wg ? 16 : src_tile_size_x / coop_remain_num_x;
71 static constexpr uint32_t coop_num_x = src_tile_size_x / tile_size_x;
74 static constexpr uint32_t block_size_x
76 src_block_size_x>::value;
77 static constexpr uint32_t block_size_y
78 = (tile_size_y > src_block_size_y) ? src_block_size_y : tile_size_y;
87 return coop_id % coop_remain_num_x * tile_size_x;
91 return coop_id / coop_remain_num_x * tile_size_y;
96template <
typename matAcc_t_, u
int32_t num_cooperative_wg, gpu_arch arch_tag_>
98 num_cooperative_wg, arch_tag_,
99 std::enable_if_t<gpu_arch::Xe == arch_tag_>> {
103 using dtype =
typename matAcc_t::dtype;
109 static_assert((num_cooperative_wg & (num_cooperative_wg - 1)) == 0,
110 "num_cooperative_wg should be power of 2");
118 static constexpr uint32_t coop_num_x
120 src_tile_size_x>::value;
121 static constexpr uint32_t coop_remain_num_y
122 = num_cooperative_wg / coop_num_x;
123 static constexpr bool has_redundant_wg
124 = (coop_remain_num_y * 16) > src_tile_size_y;
125 static constexpr uint32_t tile_size_x = src_tile_size_x / coop_num_x;
126 static constexpr uint32_t tile_size_y
127 = has_redundant_wg ? 16 : src_tile_size_y / coop_remain_num_y;
128 static constexpr uint32_t coop_num_y = src_tile_size_y / tile_size_y;
131 static constexpr uint32_t block_size_y
133 src_block_size_y>::value;
134 static constexpr uint32_t block_size_x
135 = (tile_size_x > src_block_size_x) ? src_block_size_x : tile_size_x;
144 return coop_id / coop_remain_num_y * tile_size_x;
148 return coop_id % coop_remain_num_y * tile_size_y;
static int32_t get_offset_y(uint32_t coop_id)
Definition cooperative_load_helper.hpp:90
static int32_t get_offset_x(uint32_t coop_id)
Definition cooperative_load_helper.hpp:86
cooperative_load_helper_t()=default
matAcc_t_ matAcc_t
Definition cooperative_load_helper.hpp:43
typename matAcc_t::tile_desc tile_desc_t
Definition cooperative_load_helper.hpp:45
typename matAcc_t::dtype dtype
Definition cooperative_load_helper.hpp:44
cooperative_load_helper_t()=default
matAcc_t_ matAcc_t
Definition cooperative_load_helper.hpp:102
typename matAcc_t::dtype dtype
Definition cooperative_load_helper.hpp:103
static int32_t get_offset_x(uint32_t coop_id)
Definition cooperative_load_helper.hpp:143
static int32_t get_offset_y(uint32_t coop_id)
Definition cooperative_load_helper.hpp:147
typename matAcc_t::tile_desc tile_desc_t
Definition cooperative_load_helper.hpp:104
Helper for performing a load cooperatively across multiple workgroups.
Definition cooperative_load_helper.hpp:34
Definition limitation.hpp:457
gpu_arch
Definition common.hpp:73
mem_layout
Definition common.hpp:76
Describes the tile information of a sub-matrix.
Definition api.hpp:64
static constexpr uint32_t tile_size_y
Definition api.hpp:66
static constexpr uint32_t block_size_x
Definition api.hpp:68
static constexpr uint32_t tile_size_x
Definition api.hpp:65
static constexpr uint32_t block_size_y
Definition api.hpp:69