// tile_reduce overload that yields one value per tile row: the visible return
// type and the final xetla_cvt both have mat_t::tile_size_y elements.
//
// Template parameters (as used below):
//   reduce_kind - reduction operator, forwarded to reduce_helper.
//   dtype_out   - element type of the returned vector.
//   dtype_acc   - element type used while accumulating.
//   dim         - reduction dimension selector (presumably 1 for this overload;
//                 the enable_if condition is not visible here - TODO confirm).
//   mat_t       - tile type providing tile/block sizes, dtype and `reg`.
//
// NOTE(review): the leading numbers on each line look like listing line numbers
// fused onto the code, and the gaps between them (e.g. 30 -> 32 -> 34) mean some
// original lines (function name/signature, braces, a few assignment left-hand
// sides) are not present in this excerpt.
29template <
reduce_op reduce_kind,
typename dtype_out,
typename dtype_acc,
30 int dim,
typename mat_t>
32 xetla_vector<dtype_out, mat_t::tile_size_y>>
// Local aliases for the tile geometry exposed by mat_t.
34 static constexpr uint32_t tile_size_y = mat_t::tile_size_y;
35 static constexpr uint32_t tile_size_x = mat_t::tile_size_x;
36 static constexpr uint32_t block_size_y = mat_t::block_size_y;
37 static constexpr uint32_t block_size_x = mat_t::block_size_x;
38 static constexpr uint32_t block_elems = mat_t::block_elems;
39 static constexpr int32_t num_block_x = mat_t::num_block_x;
40 using dtype =
typename mat_t::dtype;
// Pass over every full row of blocks (tile_size_y / block_size_y of them).
47 for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
// First block of the row (j == 0): convert it to dtype_acc and copy it into
// this row's accumulator slot, which seeds the accumulator.
50 auto src_reg = (src.reg).xetla_select<block_elems, 1>(
51 (i * num_block_x) * block_elems);
// The left-hand side of the next two assignments is on lines missing from this excerpt.
53 = xetla_cvt<dtype_acc, dtype, block_elems>(src_reg);
55 = acc.xetla_select<block_elems, 1>(i * block_elems);
56 dst_reg_acc = src_reg_acc;
// Remaining blocks of the row (j >= 1): fold each one into the same
// accumulator slot with reduce_helper<reduce_kind, ...>.
59 for (uint32_t j = 1; j < num_block_x; j++) {
60 auto src_reg = (src.reg).xetla_select<block_elems, 1>(
61 (i * num_block_x + j) * block_elems);
63 = xetla_cvt<dtype_acc, dtype, block_elems>(src_reg);
65 = acc.xetla_select<block_elems, 1>(i * block_elems);
66 dst_reg_acc = reduce_helper<reduce_kind, dtype_acc, block_elems>(
67 dst_reg_acc, src_reg_acc);
// Tail: rows left over when tile_size_y is not a multiple of block_size_y.
// These blocks are only tail_size_y rows tall, hence tail_block_elems.
72 if constexpr ((tile_size_y % block_size_y) != 0) {
73 constexpr uint32_t tail_start_y
74 = tile_size_y / block_size_y * block_size_y;
75 constexpr uint32_t tail_size_y = tile_size_y % block_size_y;
76 constexpr uint32_t tail_block_elems = tail_size_y * block_size_x;
// First tail block: source starts after tail_start_y full rows of the tile;
// the accumulator slot starts at tail_start_y * block_size_x.
79 auto src_reg = (src.reg).xetla_select<tail_block_elems, 1>(
80 tail_start_y * tile_size_x);
82 = xetla_cvt<dtype_acc, dtype, tail_block_elems>(src_reg);
83 auto dst_reg_acc = acc.xetla_select<tail_block_elems, 1>(
84 tail_start_y * block_size_x);
85 dst_reg_acc = src_reg_acc;
// Remaining tail blocks are reduced into the same accumulator slot.
88 for (uint32_t j = 1; j < num_block_x; j++) {
89 auto src_reg = (src.reg).xetla_select<tail_block_elems, 1>(
90 tail_start_y * tile_size_x + j * tail_block_elems);
92 = xetla_cvt<dtype_acc, dtype, tail_block_elems>(src_reg);
93 auto dst_reg_acc = acc.xetla_select<tail_block_elems, 1>(
94 tail_start_y * block_size_x);
96 = reduce_helper<reduce_kind, dtype_acc, tail_block_elems>(
97 dst_reg_acc, src_reg_acc);
// Final step over the whole accumulator, parameterised by
// <dtype_acc, block_size_x, tile_size_y>. The callee name is on a missing line
// (presumably recur_col_reduce - TODO confirm against the full file).
102 dtype_acc, block_size_x, tile_size_y>(acc);
// Convert the tile_size_y results from dtype_acc to dtype_out and return them.
104 return xetla_cvt<dtype_out, dtype_acc, tile_size_y>(out);
// tile_reduce overload that yields one value per tile column: the final
// xetla_cvt returns tile_size_x elements.
//
// Template parameters match the other overload: reduce_kind (operator passed to
// reduce_helper), dtype_out (result element type), dtype_acc (accumulation
// element type), dim (presumably 0 here; the enable_if condition is not visible
// - TODO confirm) and mat_t (tile type).
//
// Visible strategy: keep an accumulator of acc_size_y rows per block column,
// seed it from the first block, fold all remaining rows into it in chunks of
// acc_size_y rows, then fold the accumulator rows into `out`.
//
// NOTE(review): the leading numbers are listing line numbers fused onto the
// code; gaps between them (e.g. 121 -> 123, 156 -> 158 -> 160) mean several
// original lines are missing from this excerpt.
107template <
reduce_op reduce_kind,
typename dtype_out,
typename dtype_acc,
108 int dim,
typename mat_t>
// Local aliases for the tile geometry exposed by mat_t.
112 static constexpr uint32_t tile_size_y = mat_t::tile_size_y;
113 static constexpr uint32_t tile_size_x = mat_t::tile_size_x;
114 static constexpr uint32_t block_size_y = mat_t::block_size_y;
115 static constexpr uint32_t block_size_x = mat_t::block_size_x;
116 static constexpr uint32_t block_elems = mat_t::block_elems;
117 static constexpr int32_t num_block_x = mat_t::num_block_x;
118 using dtype =
typename mat_t::dtype;
// Upper bound on the number of accumulator rows kept per block column.
119 static constexpr uint32_t num_acc = 8;
// Height of the first block that is processed: when the tile holds no full
// block (tile_size_y / block_size_y == 0) it is the tail height. The other
// branch of this conditional is on a line missing from this excerpt.
120 static constexpr uint32_t first_block_size_y
121 = (tile_size_y / block_size_y == 0) ? (tile_size_y % block_size_y)
// Accumulator height: the smaller of num_acc and first_block_size_y.
123 static constexpr uint32_t acc_size_y
124 = (num_acc > first_block_size_y) ? first_block_size_y : num_acc;
131 static constexpr uint32_t first_block_elems
132 = first_block_size_y * block_size_x;
133 static constexpr uint32_t acc_block_elems = acc_size_y * block_size_x;
// Seed: for every block column j, convert the first block to dtype_acc and copy
// its first acc_block_elems elements into accumulator slot j.
136 for (uint32_t j = 0; j < num_block_x; j++) {
137 auto src_reg = (src.reg).xetla_select<first_block_elems, 1>(
138 j * first_block_elems);
// The left-hand side of this assignment is on a line missing from this excerpt.
140 = xetla_cvt<dtype_acc, dtype, first_block_elems>(src_reg);
141 acc.xetla_select<acc_block_elems, 1>(j * acc_block_elems)
142 = src_reg_acc.xetla_select<acc_block_elems, 1>(0);
// Main pass over all full blocks: row of blocks i, block column j.
146 for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
148 for (uint32_t j = 0; j < num_block_x; j++) {
149 auto src_reg = (src.reg).xetla_select<block_elems, 1>(
150 (i * num_block_x + j) * block_elems);
152 = xetla_cvt<dtype_acc, dtype, block_elems>(src_reg);
154 = acc.xetla_select<acc_block_elems, 1>(j * acc_block_elems);
// Walk the block in chunks of acc_size_y rows and reduce each chunk into the
// accumulator slot for column j.
156 for (uint32_t row_i = 0; row_i < block_size_y / acc_size_y;
// The very first chunk (i == 0, row_i == 0) was already copied into the
// accumulator by the seeding loop, so it is skipped here.
158 if (i == 0 && row_i == 0)
continue;
160 acc_block_elems>(dst_reg_acc,
161 src_reg_acc.xetla_select<acc_block_elems, 1>(
162 row_i * acc_block_elems));
// Rows left in the block when block_size_y is not a multiple of acc_size_y:
// they are reduced into the first acc_tail_block_elems of the accumulator.
165 if constexpr ((block_size_y % acc_size_y) != 0) {
166 constexpr uint32_t acc_tail_start_y
167 = block_size_y / acc_size_y * acc_size_y;
168 constexpr uint32_t acc_tail_size_y = block_size_y % acc_size_y;
169 constexpr uint32_t acc_tail_block_elems
170 = acc_tail_size_y * block_size_x;
171 auto dst_reg_acc_tail
172 = dst_reg_acc.xetla_select<acc_tail_block_elems>(0);
174 acc_tail_block_elems>(dst_reg_acc_tail,
175 src_reg_acc.xetla_select<acc_tail_block_elems, 1>(
176 acc_tail_start_y * block_size_x));
// Tile tail: rows left over when tile_size_y is not a multiple of block_size_y.
// Handled like a full block, but with tail_size_y rows per block.
182 if constexpr ((tile_size_y % block_size_y) != 0) {
183 constexpr uint32_t tail_start_y
184 = tile_size_y / block_size_y * block_size_y;
185 constexpr uint32_t tail_size_y = tile_size_y % block_size_y;
186 constexpr uint32_t tail_block_elems = tail_size_y * block_size_x;
188 for (uint32_t j = 0; j < num_block_x; j++) {
189 auto src_reg = (src.reg).xetla_select<tail_block_elems, 1>(
190 tail_start_y * tile_size_x + j * tail_block_elems);
192 = xetla_cvt<dtype_acc, dtype, tail_block_elems>(src_reg);
194 = acc.xetla_select<acc_block_elems, 1>(j * acc_block_elems);
196 for (uint32_t row_i = 0; row_i < tail_size_y / acc_size_y;
// If the tile has no full block, the tail block is the one that seeded the
// accumulator, so its first chunk must not be reduced a second time.
198 if ((tile_size_y / block_size_y == 0) && row_i == 0)
continue;
200 acc_block_elems>(dst_reg_acc,
201 src_reg_acc.xetla_select<acc_block_elems, 1>(
202 row_i * acc_block_elems));
// Rows left in the tail block when tail_size_y is not a multiple of acc_size_y.
205 if constexpr ((tail_size_y % acc_size_y) != 0) {
206 constexpr uint32_t acc_tail_start_y
207 = tail_size_y / acc_size_y * acc_size_y;
208 constexpr uint32_t acc_tail_size_y = tail_size_y % acc_size_y;
209 constexpr uint32_t acc_tail_block_elems
210 = acc_tail_size_y * block_size_x;
211 auto dst_reg_acc_tail
212 = dst_reg_acc.xetla_select<acc_tail_block_elems>(0);
214 acc_tail_block_elems>(dst_reg_acc_tail,
215 src_reg_acc.xetla_select<acc_tail_block_elems, 1>(
216 acc_tail_start_y * block_size_x));
// Final fold: for each accumulator row i and block column j, the accumulator
// slot is viewed as a 2D array and combined into the block_size_x-wide slice of
// `out` for column j. The exact statements are partly missing from this
// excerpt (presumably row 0 initialises `out` and later rows go through
// reduce_helper - TODO confirm).
223 for (uint32_t i = 0; i < acc_size_y; i++) {
225 for (uint32_t j = 0; j < num_block_x; j++) {
227 = acc.xetla_select<acc_block_elems, 1>(j * acc_block_elems);
228 auto reg_acc_2d = reg_acc.xetla_format<dtype_acc, acc_size_y,
231 out.xetla_select<block_size_x, 1>(j * block_size_x)
234 out.xetla_select<block_size_x, 1>(j * block_size_x)
235 = reduce_helper<reduce_kind, dtype_acc, block_size_x>(
236 out.xetla_select<block_size_x, 1>(
// Convert the tile_size_x results from dtype_acc to dtype_out and return them.
243 return xetla_cvt<dtype_out, dtype_acc, tile_size_x>(out);
// tile_row_reduce: sums the rows of the source tile `src` into the single-row
// destination tile `dst` (the enable_if below requires T_dst::tile_size_y == 1
// and equal tile_size_x). Per the message string below, this is reduce-add only
// and slated for deprecation in favour of tile_reduce.
//
// Template parameters:
//   T_dst      - destination tile type; must use reg_layout::tiled.
//   T_src      - source tile type.
//   accumulate - when true (the default), the current contents of dst.reg are
//                included in the sum.
//   dtype_acc  - element type used while accumulating (default float).
//   num_acc    - number of accumulator rows summed together at the end
//                (default 4).
//
// NOTE(review): the leading numbers are listing line numbers fused onto the
// code; gaps between them (e.g. 258 -> 260, 279 -> 289) mean several original
// lines are missing from this excerpt, including the rest of the signature and
// the declaration of `acc`.
257template <
typename T_dst,
typename T_src,
bool accumulate =
true,
258 typename dtype_acc = float, uint32_t num_acc = 4>
260 "This is only for reduce add, and will be deprecated in future. "
261 "Please use tile_reduce instead.")
263 typename std::enable_if_t<(T_dst::register_layout ==
reg_layout::
tiled)
265 && (T_dst::tile_size_x == T_src::tile_size_x)
266 && (T_dst::tile_size_y == 1)> tile_row_reduce(T_dst &dst,
// Local aliases for the source tile geometry.
268 static constexpr uint32_t tile_size_y = T_src::tile_size_y;
269 static constexpr uint32_t tile_size_x = T_src::tile_size_x;
270 static constexpr uint32_t block_size_y = T_src::block_size_y;
271 static constexpr uint32_t block_size_x = T_src::block_size_x;
272 static constexpr uint32_t block_elems = T_src::block_elems;
273 static constexpr int32_t num_block_x = T_src::num_block_x;
274 using dtype_dst =
typename T_dst::dtype;
275 using dtype_src =
typename T_src::dtype;
// Number of dtype_acc elements that fit in 64 bytes.
// NOTE(review): the reference list at the end of this listing also shows a
// `#define SIMD` in gemm_softmax.cpp; if that macro is ever visible here it
// would clash with this constant's name - confirm.
277 static constexpr uint32_t
SIMD = 64 /
sizeof(dtype_acc);
// Width of each add in the inner k-loops. The condition tests whether
// block_size_x is a multiple of SIMD or dtype_src is at least 4 bytes wide;
// the two resulting values are on lines missing from this excerpt.
278 static constexpr uint32_t accum_len
279 = ((block_size_x %
SIMD) && (
sizeof(dtype_src) < 4)) == 0
// View the accumulator as num_acc rows of tile_size_x elements.
289 auto acc_2d = acc.xetla_format<dtype_acc, num_acc, tile_size_x>();
// When accumulating, row 0 starts from the existing destination values.
290 if constexpr (accumulate) {
291 acc_2d.row(0) = xetla_cvt<dtype_acc, dtype_dst, tile_size_x>(dst.reg);
// Main pass over all full blocks: row of blocks i, block column j.
294 for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
296 for (uint32_t j = 0; j < num_block_x; j++) {
297 auto src_reg = (src.reg).xetla_select<block_elems, 1>(
298 (i * num_block_x + j) * block_elems);
299 auto src_reg_dtype_acc
300 = xetla_cvt<dtype_acc, dtype_src, block_elems>(src_reg);
301 auto src_reg_2d = src_reg_dtype_acc.xetla_format<dtype_acc,
302 block_size_y, block_size_x>();
// Rows are taken num_acc at a time; row (row_i + acc_i) of the block is added
// to accumulator row acc_i, so consecutive source rows land in different
// accumulator rows.
304 for (uint32_t row_i = 0; row_i < block_size_y; row_i += num_acc) {
306 for (uint32_t acc_i = 0;
307 (acc_i < num_acc) && (row_i + acc_i < block_size_y);
// Slice of accumulator row acc_i that corresponds to block column j.
309 auto acc_sub = acc.xetla_select<block_size_x, 1>(
310 acc_i * tile_size_x + j * block_size_x);
// Add the source row in pieces of accum_len elements.
312 for (uint32_t k = 0; k < block_size_x / accum_len; k++) {
313 acc_sub.xetla_select<accum_len, 1>(k * accum_len)
314 = acc_sub.xetla_select<accum_len, 1>(
316 + src_reg_2d.row(row_i + acc_i)
317 .xetla_select<accum_len, 1>(
// Tile tail: rows left over when tile_size_y is not a multiple of block_size_y.
// Same accumulation as above, with blocks that are tail_size_y rows tall.
326 if constexpr ((tile_size_y % block_size_y) != 0) {
327 constexpr uint32_t tail_start_y
328 = tile_size_y / block_size_y * block_size_y;
329 constexpr uint32_t tail_size_y = tile_size_y % block_size_y;
330 constexpr uint32_t tail_block_elems = tail_size_y * block_size_x;
332 for (uint32_t j = 0; j < num_block_x; j++) {
333 auto src_reg = (src.reg).xetla_select<tail_block_elems, 1>(
334 tail_start_y * tile_size_x + j * tail_block_elems);
335 auto src_reg_dtype_acc
336 = xetla_cvt<dtype_acc, dtype_src, tail_block_elems>(
338 auto src_reg_2d = src_reg_dtype_acc.xetla_format<dtype_acc,
339 tail_size_y, block_size_x>();
341 for (uint32_t row_i = 0; row_i < tail_size_y; row_i += num_acc) {
343 for (uint32_t acc_i = 0;
344 (acc_i < num_acc) && (row_i + acc_i < tail_size_y);
346 auto acc_sub = acc.xetla_select<block_size_x, 1>(
347 acc_i * tile_size_x + j * block_size_x);
349 for (uint32_t k = 0; k < block_size_x / accum_len; k++) {
350 acc_sub.xetla_select<accum_len, 1>(k * accum_len)
351 = acc_sub.xetla_select<accum_len, 1>(
353 + src_reg_2d.row(row_i + acc_i)
354 .xetla_select<accum_len, 1>(
// Collapse accumulator rows 1..num_acc-1 into row 0.
363 for (uint32_t i = 1; i < num_acc; i++) {
364 acc_2d.row(0) += acc_2d.row(i);
// Write the summed row back to the destination tile in its own element type.
367 dst.reg = xetla_cvt<dtype_dst, dtype_acc, tile_size_x>(acc_2d.row(0));
#define XETLA_MARKER(message)
Definition common.hpp:53
#define __XETLA_API
Definition common.hpp:43
#define SIMD
Definition gemm_softmax.cpp:23
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__XETLA_API std::enable_if_t< N_x==1, xetla_vector< dtype, N_y > > recur_col_reduce(xetla_vector< dtype, N_y > in)
Definition misc.hpp:162
__XETLA_API std::enable_if_t< reduce_kind==reduce_op::sum, xetla_vector< dtype, size > > reduce_helper(xetla_vector< dtype, size > a, xetla_vector< dtype, size > b)
Definition misc.hpp:112
Definition limitation.hpp:457
__XETLA_API std::enable_if_t<(dim==1), xetla_vector< dtype_out, mat_t::tile_size_y > > tile_reduce(mat_t &src)
Definition reduction.hpp:33
reg_layout
tile layout in register. linear: linear layout with one tile; tiled: 2d block stacked in raster order; v...
Definition common.hpp:209
reduce_op
xetla reduce op
Definition common.hpp:217