23#include "subgroup/subgroup.hpp"
33template <
reduce_op reduce_kind,
typename tile_shape,
typename matAcc_t,
34 uint32_t num_cooperative_wg,
gpu_arch arch_tag,
class enable =
void>
38template <
reduce_op reduce_kind,
typename tile_shape_,
typename matAcc_t,
39 uint32_t num_cooperative_wg,
gpu_arch arch_tag_>
41 num_cooperative_wg, arch_tag_,
42 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
44 static constexpr gpu_arch arch_tag = arch_tag_;
46 using dtype =
typename matAcc_t::dtype;
49 static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
50 static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x;
51 static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
52 static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
53 static constexpr uint32_t real_wg_tile_m = sg_tile_m * wg_size_y;
54 static constexpr uint32_t real_wg_tile_n = sg_tile_n * wg_size_x;
56 static constexpr uint32_t wg_tile_size
57 = real_wg_tile_m * real_wg_tile_n *
sizeof(
dtype);
58 using work_group_t =
typename tile_shape::work_group_t;
61 static_assert((num_cooperative_wg & (num_cooperative_wg - 1)) == 0,
62 "num_cooperative_wg should be power of 2");
65 static constexpr uint32_t coop_num_y
68 static constexpr uint32_t coop_remain_num_x
69 = num_cooperative_wg / coop_num_y;
70 static constexpr bool has_redundant_wg
71 = (coop_remain_num_x * 16) > sg_tile_n;
72 static constexpr uint32_t tile_size_y = sg_tile_m / coop_num_y;
73 static constexpr uint32_t tile_size_x
74 = has_redundant_wg ? 16 : sg_tile_n / coop_remain_num_x;
75 static constexpr uint32_t coop_num_x = sg_tile_n / tile_size_x;
76 static constexpr uint32_t num_reduce_wg = coop_num_x * coop_num_y;
79 static constexpr uint32_t src_block_size_x = matAcc_t::block_size_x;
80 static constexpr uint32_t src_block_size_y = matAcc_t::block_size_y;
82 static constexpr uint32_t block_size_x
84 src_block_size_x>::value;
85 static constexpr uint32_t block_size_y
86 = (tile_size_y > src_block_size_y) ? src_block_size_y : tile_size_y;
94 subgroup::msg_type_v<local_st_tile_desc_t, mem_space::local>,
102 subgroup::msg_type_v<local_ld_tile_desc_t, mem_space::local>,
108 static constexpr uint32_t barrier_count = work_group_size;
109 static constexpr uint32_t slm_size = wg_tile_size * num_cooperative_wg;
115 coop_id_x = coop_id % coop_remain_num_x;
116 coop_id_y = coop_id / coop_remain_num_x;
131 matAcc_t &matAcc, uint32_t slm_base = 0,
132 uint32_t nbarrier_base = 0) {
133 uint32_t sg_idx = g.get_id() % wg_size_x;
134 uint32_t sg_idy = g.get_id() / wg_size_x;
136 int32_t slm_store_offset_x = sg_idx * sg_tile_n;
137 int32_t slm_store_offset_y
138 = coop_id * real_wg_tile_m + sg_idy * sg_tile_m;
141 real_wg_tile_m * num_cooperative_wg, real_wg_tile_n,
142 slm_store_offset_x, slm_store_offset_y);
143 local_st.
reg = matAcc.reg;
144 tile_store(local_st, local_st_payload);
148 uint32_t nbar_id = nbarrier_base + g.get_id();
150 xetla_fence<memory_kind::shared_local>();
154 if (is_valid_post_process_wg()) {
157 int32_t slm_load_offset_x
158 = sg_idx * sg_tile_n + coop_id_x * tile_size_x;
159 int32_t slm_load_offset_y
160 = sg_idy * sg_tile_m + coop_id_y * tile_size_y;
164 real_wg_tile_m * num_cooperative_wg, real_wg_tile_n,
165 slm_load_offset_x, slm_load_offset_y);
167 tile_load(local_ld, local_ld_payload);
168 mat_slice.
reg = local_ld.
reg;
170 for (uint32_t i = 1; i < num_cooperative_wg; i++) {
171 local_ld_payload.template update_tdesc<tdesc_update_dir::y_dir>(
173 tile_load(local_ld, local_ld_payload);
174 mat_slice.
reg = reduce_helper<reduce_kind, dtype>(
175 mat_slice.
reg, local_ld.
reg);
183template <
reduce_op reduce_kind,
typename tile_shape_,
typename matAcc_t,
186 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
190 using dtype =
typename matAcc_t::dtype;
193 using work_group_t =
typename tile_shape::work_group_t;
197 static constexpr uint32_t barrier_count = 0;
198 static constexpr uint32_t slm_size = 0;
199 static constexpr uint32_t coop_num_x = 1;
200 static constexpr uint32_t coop_num_y = 1;
213 [[maybe_unused]] uint32_t slm_base = 0,
214 [[maybe_unused]] uint32_t nbarrier_base = 0) {
215 mat_slice.reg = matAcc.reg;
uint32_t coop_id
Definition cooperative_reduction.hpp:111
KERNEL_FUNC void operator()(work_group_t &g, mat_slice_t &mat_slice, matAcc_t &matAcc, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Cooperative workgroup reduction.
Definition cooperative_reduction.hpp:130
typename matAcc_t::dtype dtype
Definition cooperative_reduction.hpp:46
tile_shape_ tile_shape
Definition cooperative_reduction.hpp:45
uint32_t coop_id_x
Definition cooperative_reduction.hpp:112
cooperative_reduce_t(uint32_t coop_id_)
Definition cooperative_reduction.hpp:114
uint32_t coop_id_y
Definition cooperative_reduction.hpp:113
bool is_valid_post_process_wg()
Definition cooperative_reduction.hpp:118
bool is_valid_post_process_wg()
Definition cooperative_reduction.hpp:209
uint32_t coop_id_x
Definition cooperative_reduction.hpp:202
matAcc_t mat_slice_t
Definition cooperative_reduction.hpp:196
KERNEL_FUNC void operator()(work_group_t &g, mat_slice_t &mat_slice, matAcc_t &matAcc, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Definition cooperative_reduction.hpp:211
uint32_t coop_id_y
Definition cooperative_reduction.hpp:203
uint32_t coop_id
Definition cooperative_reduction.hpp:201
tile_shape_ tile_shape
Definition cooperative_reduction.hpp:189
cooperative_reduce_t(uint32_t coop_id_)
Definition cooperative_reduction.hpp:204
typename matAcc_t::dtype dtype
Definition cooperative_reduction.hpp:190
Workgroups that perform the cooperative reduction.
Definition cooperative_reduction.hpp:35
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
reduce_op
xetla reduce op
Definition common.hpp:217
gpu_arch
Definition common.hpp:73
Definition memory_descriptor.hpp:139
Describes the memory information.
Definition api.hpp:44
Describes the tile information of a sub-matrix.
Definition api.hpp:64
A struct that contains a register file.
Definition api.hpp:99
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
static constexpr uint32_t size
Definition work_group.hpp:39
XeTLA named-barrier (nbarrier) definition API.
Definition raw_send_nbarrier.hpp:43
__XETLA_API void arrive()
Signals the named barrier from the subgroup.
Definition raw_send_nbarrier.hpp:65
__XETLA_API void init_nbarrier(uint8_t nbarrier_id, nbarrier_role role=nbarrier_role::producer_consumer)
Definition raw_send_nbarrier.hpp:55
__XETLA_API void wait()
Waits on the named barrier within the subgroup.
Definition raw_send_nbarrier.hpp:76