#include "subgroup/subgroup.hpp"

/// @brief Cross group global reduction.
/// Primary template; specialized below per architecture and reduction count.
template <reduce_op reduce_kind, typename tile_shape_acc,
        typename tile_shape_cnt, typename mem_desc_acc_t,
        typename mem_desc_cnt_t, uint32_t num_group_reduction,
        uint32_t counter_size, gpu_arch arch_tag, class enable = void>
class global_reduce_t {};
/// @brief Cross group global reduction, specialized for reduce_op::sum on the Xe architecture.
template <typename tile_shape_acc_, typename tile_shape_cnt_,
        typename mem_desc_acc_t_, typename mem_desc_cnt_t_,
        uint32_t num_group_reduction, uint32_t counter_size,
        gpu_arch arch_tag_>
class global_reduce_t<reduce_op::sum, tile_shape_acc_, tile_shape_cnt_,
        mem_desc_acc_t_, mem_desc_cnt_t_, num_group_reduction, counter_size,
        arch_tag_, std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
public:
    static constexpr gpu_arch arch_tag = arch_tag_;
    using tile_shape_acc = tile_shape_acc_;
    using tile_shape_cnt = tile_shape_cnt_;
    using mem_desc_acc_t = mem_desc_acc_t_;
    using mem_desc_cnt_t = mem_desc_cnt_t_;
    using dtype_acc = typename mem_desc_acc_t::dtype;
    using dtype_cnt = typename mem_desc_cnt_t::dtype;

    // Per-subgroup tile sizes for the accumulation (acc) and counter (cnt) surfaces.
    static constexpr uint32_t acc_sg_tile_y = tile_shape_acc::sg_tile_size_y;
    static constexpr uint32_t acc_sg_tile_x = tile_shape_acc::sg_tile_size_x;
    static constexpr uint32_t cnt_sg_tile_y = tile_shape_cnt::sg_tile_size_y;
    static constexpr uint32_t cnt_sg_tile_x = tile_shape_cnt::sg_tile_size_x;
    static constexpr uint32_t wg_size_x = tile_shape_acc::wg_size_x;
    static constexpr uint32_t wg_size_y = tile_shape_acc::wg_size_y;
    static_assert((tile_shape_acc::wg_size_x == tile_shape_cnt::wg_size_x)
                    && (tile_shape_acc::wg_size_y == tile_shape_cnt::wg_size_y),
            "acc and cnt workgroup shapes need to match");
    using work_group_t = typename tile_shape_acc::work_group_t;
    // Move each subgroup's coordinates in the acc/cnt descriptors to its own tile.
    inline void update_sg_tile_tdesc(work_group_t &g,
            mem_desc_acc_t &mem_desc_acc, mem_desc_cnt_t &mem_desc_cnt) {
        int32_t sg_idx = g.get_id() % wg_size_x;
        int32_t sg_idy = g.get_id() / wg_size_x;
        int32_t acc_tile_offset_x = sg_idx * acc_sg_tile_x;
        int32_t acc_tile_offset_y = sg_idy * acc_sg_tile_y;
        mem_desc_acc.update_coord(acc_tile_offset_x, acc_tile_offset_y);
        int32_t cnt_tile_offset_x = sg_idx * cnt_sg_tile_x;
        int32_t cnt_tile_offset_y = sg_idy * cnt_sg_tile_y;
        mem_desc_cnt.update_coord(cnt_tile_offset_x, cnt_tile_offset_y);
    }
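    // Worked example (illustrative numbers, not from this header): with
    // wg_size_x = 4, acc_sg_tile_x = 16 and acc_sg_tile_y = 8, subgroup id 6
    // maps to sg_idx = 6 % 4 = 2 and sg_idy = 6 / 4 = 1, so its accumulator
    // tile starts at column 2 * 16 = 32 and row 1 * 8 = 8 of the workgroup tile.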
    // Atomically bump this tile's counter; returns the pre-increment value,
    // i.e. how many workgroups had already finished their partial store.
    inline uint32_t update_reduce_counter(mem_desc_cnt_t &mem_desc_cnt) {
        constexpr uint32_t SIMD = 16;
        uint32_t pitch_in_bytes
                = mem_desc_cnt.shape.stride * sizeof(dtype_cnt) * counter_size;
        uint32_t offset_x = mem_desc_cnt.coord.x;
        uint32_t offset_y = mem_desc_cnt.coord.y;
        uint64_t address = (uint64_t)mem_desc_cnt.base.base
                + offset_y * pitch_in_bytes
                + offset_x * sizeof(dtype_cnt) * counter_size;
        xetla_vector<uint32_t, SIMD> offsets
                = xetla_vector_gen<uint32_t, SIMD>(0, 1);
        offsets *= sizeof(dtype_cnt);
        // Only lane 0 participates; atomic_op::iinc returns the old counter value
        // (data-size / cache-hint template arguments omitted in this listing).
        xetla_mask<SIMD> pred(0);
        pred[0] = 1;
        xetla_vector<dtype_cnt, SIMD> ret
                = xetla_atomic_global<atomic_op::iinc, dtype_cnt, SIMD>(
                        (dtype_cnt *)address, offsets, pred);
        return ret[0];
    }
    // Reset this tile's counter to zero so the buffer can be reused.
    inline void clean_reduce_counter(mem_desc_cnt_t &mem_desc_cnt) {
        uint32_t pitch_in_bytes
                = mem_desc_cnt.shape.stride * sizeof(dtype_cnt) * counter_size;
        uint32_t offset_x = mem_desc_cnt.coord.x;
        uint32_t offset_y = mem_desc_cnt.coord.y;
        uint64_t address = (uint64_t)mem_desc_cnt.base.base
                + offset_y * pitch_in_bytes
                + offset_x * sizeof(dtype_cnt) * counter_size;
        xetla_vector<dtype_cnt, 1> zeros(0);
        // Scattered store of a single zero element
        // (data-size / cache-hint template arguments omitted in this listing).
        xetla_store_global<dtype_cnt, 1>(
                (dtype_cnt *)address, 0, zeros);
    }
    static constexpr uint32_t barrier_count = 0;
    static constexpr uint32_t slm_size = 0;
    uint32_t reduce_id = 0;

    // True only for the workgroup that performed the final accumulation.
    inline bool is_last_group() {
        return reduce_id == (num_group_reduction - 1);
    }
    /// @brief Global reduction.
    /// Stores this group's partial tile with atomic adds, then uses the counter
    /// to detect the last contributing group, which loads the fully reduced tile
    /// back into matAcc and resets the buffers.
    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(work_group_t &g, matAcc_t &matAcc,
            mem_desc_acc_t mem_desc_acc, mem_desc_cnt_t mem_desc_cnt,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        static_assert(std::is_same<typename matAcc_t::dtype, dtype_acc>::value,
                "matAcc_t::dtype should match dtype_acc");
        update_sg_tile_tdesc(g, mem_desc_acc, mem_desc_cnt);
        using matAcc_tile_desc_t = typename matAcc_t::tile_desc;
        // Payload that scatters the partial tile into the accumulation buffer with atomic adds.
        using matAcc_store_payload_t = subgroup::mem_payload_t<mem_desc_acc_t,
                matAcc_tile_desc_t, msg_type::atomic_add, arch_tag>;
        matAcc_store_payload_t matAcc_store_payload(mem_desc_acc);
        subgroup::tile_store<cache_hint::uncached, cache_hint::write_back>(
                matAcc, matAcc_store_payload);
        // Make the partial-sum stores globally visible before touching the counter.
        xetla_fence<memory_kind::untyped_global>();
        SW_BARRIER();
        reduce_id = update_reduce_counter(mem_desc_cnt);
        if (reduce_id == (num_group_reduction - 1)) {
            // Last group: read back the fully reduced tile, then zero the
            // counter and the accumulation buffer for reuse.
            using matAcc_payload_t = subgroup::mem_payload_t<mem_desc_acc_t,
                    matAcc_tile_desc_t, msg_type::block_2d, arch_tag>;
            matAcc_payload_t matAcc_payload(mem_desc_acc);
            subgroup::tile_load(matAcc, matAcc_payload);
            clean_reduce_counter(mem_desc_cnt);
            matAcc_t mat_zero;
            mat_zero.reg = 0;
            subgroup::tile_store<cache_hint::uncached, cache_hint::write_back>(
                    mat_zero, matAcc_payload);
            SW_BARRIER();
        }
    }
};
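The control flow above is the familiar "last workgroup does the cleanup" pattern: every group atomically adds its partial tile into a scratch buffer, then atomically increments a per-tile counter, and the group that sees the counter at num_group_reduction - 1 knows all partial sums have landed. The host-side C++ sketch below illustrates only that protocol and assumes nothing about the XeTLA API; std::atomic::fetch_add stands in for the atomic_op::iinc counter update, and the thread count and partial values are made up for the example.

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    constexpr int num_groups = 8;          // stands in for num_group_reduction
    std::atomic<int> accumulator{0};       // stands in for the acc scratch buffer
    std::atomic<int> counter{0};           // stands in for the cnt buffer

    std::vector<std::thread> groups;
    for (int g = 0; g < num_groups; ++g) {
        groups.emplace_back([&, g] {
            int partial = g + 1;                 // this group's partial result
            accumulator.fetch_add(partial);      // atomic-add store of the tile
            int old = counter.fetch_add(1);      // "update_reduce_counter"
            if (old == num_groups - 1) {         // "is_last_group"
                int total = accumulator.load();  // tile_load of the reduced result
                counter.store(0);                // "clean_reduce_counter"
                accumulator.store(0);            // store of mat_zero
                std::printf("reduced total = %d\n", total);
            }
        });
    }
    for (auto &t : groups) t.join();
    return 0;
}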
/// @brief Cross group global reduction, specialized for num_group_reduction == 1,
/// where a single workgroup owns each tile and no cross-group accumulation is needed.
template <typename tile_shape_acc_, typename tile_shape_cnt_,
        typename mem_desc_acc_t_, typename mem_desc_cnt_t_,
        uint32_t counter_size_, gpu_arch arch_tag_>
class global_reduce_t<reduce_op::sum, tile_shape_acc_, tile_shape_cnt_,
        mem_desc_acc_t_, mem_desc_cnt_t_, 1, counter_size_, arch_tag_,
        std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
public:
    using tile_shape_acc = tile_shape_acc_;
    using tile_shape_cnt = tile_shape_cnt_;
    using mem_desc_acc_t = mem_desc_acc_t_;
    using mem_desc_cnt_t = mem_desc_cnt_t_;
    using dtype_acc = typename mem_desc_acc_t::dtype;
    using work_group_t = typename tile_shape_acc::work_group_t;

    static constexpr uint32_t barrier_count = 0;
    static constexpr uint32_t slm_size = 0;

    inline bool is_last_group() { return true; }

    // No cross-group reduction to perform; the functor is a no-op.
    template <typename matAcc_t>
    KERNEL_FUNC void operator()([[maybe_unused]] work_group_t &g,
            [[maybe_unused]] matAcc_t &matAcc,
            [[maybe_unused]] mem_desc_acc_t mem_desc_acc,
            [[maybe_unused]] mem_desc_cnt_t mem_desc_cnt,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {}
};
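For orientation, here is a rough sketch of how a split-k style kernel might instantiate this reducer. The tile_shape_t and mem_desc_t spellings, argument orders, and the concrete tile sizes below are assumptions made for illustration (they are not defined in this header) and may differ between XeTLA releases.

// Illustrative instantiation only; tile_shape_t / mem_desc_t argument orders are assumed.
using tile_shape = gpu::xetla::group::tile_shape_t<256, 128, 64, 32>;  // wg and sg tile sizes (assumed order)
using mem_desc_acc = gpu::xetla::mem_desc_t<float,                     // partial-sum scratch buffer
        gpu::xetla::mem_layout::row_major, gpu::xetla::mem_space::global>;
using mem_desc_cnt = gpu::xetla::mem_desc_t<uint32_t,                  // per-tile counters
        gpu::xetla::mem_layout::row_major, gpu::xetla::mem_space::global>;

// Four workgroups (e.g. four k-slices) cooperate on each output tile.
using global_reduce = gpu::xetla::group::global_reduce_t<
        gpu::xetla::reduce_op::sum, tile_shape, tile_shape,
        mem_desc_acc, mem_desc_cnt,
        /*num_group_reduction*/ 4, /*counter_size*/ 1, gpu::xetla::gpu_arch::Xe>;

// Inside the kernel, after each group has produced its partial matAcc:
//     global_reduce reducer;
//     reducer(g, matAcc, mem_desc_acc, mem_desc_cnt);
//     if (reducer.is_last_group()) {
//         // matAcc now holds the fully reduced tile; write it to the output.
//     }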