XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
reduction.hpp
/*******************************************************************************
* Copyright (c) 2022-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include "subgroup/tile/api.hpp"

namespace gpu::xetla::subgroup {

// dim = 0 : reduce along the y direction (one result per column);
// dim = 1 : reduce along the x direction (one result per row);
template <reduce_op reduce_kind, typename dtype_out, typename dtype_acc,
        int dim, typename mat_t>
__XETLA_API typename std::enable_if_t<(dim == 1),
        xetla_vector<dtype_out, mat_t::tile_size_y>>
tile_reduce(mat_t &src) {
    static constexpr uint32_t tile_size_y = mat_t::tile_size_y;
    static constexpr uint32_t tile_size_x = mat_t::tile_size_x;
    static constexpr uint32_t block_size_y = mat_t::block_size_y;
    static constexpr uint32_t block_size_x = mat_t::block_size_x;
    static constexpr uint32_t block_elems = mat_t::block_elems;
    static constexpr int32_t num_block_x = mat_t::num_block_x;
    using dtype = typename mat_t::dtype;
    // temporary accumulator: one block_size_x-wide strip per tile row
    xetla_vector<dtype_acc, tile_size_y * block_size_x> acc;
#pragma unroll
    for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
        // j = 0: initialize the accumulation buffer
        {
            auto src_reg = (src.reg).xetla_select<block_elems, 1>(
                    (i * num_block_x) * block_elems);
            auto src_reg_acc
                    = xetla_cvt<dtype_acc, dtype, block_elems>(src_reg);
            auto dst_reg_acc
                    = acc.xetla_select<block_elems, 1>(i * block_elems);
            dst_reg_acc = src_reg_acc;
        }
#pragma unroll
        for (uint32_t j = 1; j < num_block_x; j++) {
            auto src_reg = (src.reg).xetla_select<block_elems, 1>(
                    (i * num_block_x + j) * block_elems);
            auto src_reg_acc
                    = xetla_cvt<dtype_acc, dtype, block_elems>(src_reg);
            auto dst_reg_acc
                    = acc.xetla_select<block_elems, 1>(i * block_elems);
            dst_reg_acc = reduce_helper<reduce_kind, dtype_acc, block_elems>(
                    dst_reg_acc, src_reg_acc);
        }
    }

    // process the tail
    if constexpr ((tile_size_y % block_size_y) != 0) {
        constexpr uint32_t tail_start_y
                = tile_size_y / block_size_y * block_size_y;
        constexpr uint32_t tail_size_y = tile_size_y % block_size_y;
        constexpr uint32_t tail_block_elems = tail_size_y * block_size_x;
        // j = 0: initialize the accumulation buffer
        {
            auto src_reg = (src.reg).xetla_select<tail_block_elems, 1>(
                    tail_start_y * tile_size_x);
            auto src_reg_acc
                    = xetla_cvt<dtype_acc, dtype, tail_block_elems>(src_reg);
            auto dst_reg_acc = acc.xetla_select<tail_block_elems, 1>(
                    tail_start_y * block_size_x);
            dst_reg_acc = src_reg_acc;
        }
#pragma unroll
        for (uint32_t j = 1; j < num_block_x; j++) {
            auto src_reg = (src.reg).xetla_select<tail_block_elems, 1>(
                    tail_start_y * tile_size_x + j * tail_block_elems);
            auto src_reg_acc
                    = xetla_cvt<dtype_acc, dtype, tail_block_elems>(src_reg);
            auto dst_reg_acc = acc.xetla_select<tail_block_elems, 1>(
                    tail_start_y * block_size_x);
            dst_reg_acc
                    = reduce_helper<reduce_kind, dtype_acc, tail_block_elems>(
                            dst_reg_acc, src_reg_acc);
        }
    }

    // collapse the block_size_x-wide accumulator down to one value per row
    xetla_vector<dtype_acc, tile_size_y> out = recur_col_reduce<reduce_kind,
            dtype_acc, block_size_x, tile_size_y>(acc);

    return xetla_cvt<dtype_out, dtype_acc, tile_size_y>(out);
}

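// Usage sketch for the dim == 1 overload (illustrative, caller-side only;
// tile_t / tile_desc_t and their parameter order are assumed from
// subgroup/tile/api.hpp rather than defined in this file):
//
//   using tile_desc = subgroup::tile_desc_t<32, 16, 16, 8, reg_layout::tiled>;
//   subgroup::tile_t<float, tile_desc> mat;
//   // ... fill mat.reg ...
//   // dim == 1: one result per row (each 32-wide row is reduced along x)
//   xetla_vector<float, tile_desc::tile_size_y> row_sum
//           = subgroup::tile_reduce<reduce_op::sum, float, float, 1>(mat);
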
template <reduce_op reduce_kind, typename dtype_out, typename dtype_acc,
        int dim, typename mat_t>
__XETLA_API typename std::enable_if_t<(dim == 0),
        xetla_vector<dtype_out, mat_t::tile_size_x>>
tile_reduce(mat_t &src) {
    static constexpr uint32_t tile_size_y = mat_t::tile_size_y;
    static constexpr uint32_t tile_size_x = mat_t::tile_size_x;
    static constexpr uint32_t block_size_y = mat_t::block_size_y;
    static constexpr uint32_t block_size_x = mat_t::block_size_x;
    static constexpr uint32_t block_elems = mat_t::block_elems;
    static constexpr int32_t num_block_x = mat_t::num_block_x;
    using dtype = typename mat_t::dtype;
    static constexpr uint32_t num_acc = 8;
    static constexpr uint32_t first_block_size_y
            = (tile_size_y / block_size_y == 0) ? (tile_size_y % block_size_y)
                                                : block_size_y;
    static constexpr uint32_t acc_size_y
            = (num_acc > first_block_size_y) ? first_block_size_y : num_acc;
    // fold at most acc_size_y rows at a time into a per-column accumulator
    static constexpr uint32_t first_block_elems
            = first_block_size_y * block_size_x;
    static constexpr uint32_t acc_block_elems = acc_size_y * block_size_x;
    xetla_vector<dtype_acc, acc_size_y * tile_size_x> acc;
#pragma unroll
    for (uint32_t j = 0; j < num_block_x; j++) {
        auto src_reg = (src.reg).xetla_select<first_block_elems, 1>(
                j * first_block_elems);
        auto src_reg_acc
                = xetla_cvt<dtype_acc, dtype, first_block_elems>(src_reg);
        acc.xetla_select<acc_block_elems, 1>(j * acc_block_elems)
                = src_reg_acc.xetla_select<acc_block_elems, 1>(0);
    }

#pragma unroll
    for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
#pragma unroll
        for (uint32_t j = 0; j < num_block_x; j++) {
            auto src_reg = (src.reg).xetla_select<block_elems, 1>(
                    (i * num_block_x + j) * block_elems);
            auto src_reg_acc
                    = xetla_cvt<dtype_acc, dtype, block_elems>(src_reg);
            auto dst_reg_acc
                    = acc.xetla_select<acc_block_elems, 1>(j * acc_block_elems);
#pragma unroll
            for (uint32_t row_i = 0; row_i < block_size_y / acc_size_y;
                    row_i++) {
                if (i == 0 && row_i == 0) continue;
                dst_reg_acc = reduce_helper<reduce_kind, dtype_acc,
                        acc_block_elems>(dst_reg_acc,
                        src_reg_acc.xetla_select<acc_block_elems, 1>(
                                row_i * acc_block_elems));
            }
            // process the tail
            if constexpr ((block_size_y % acc_size_y) != 0) {
                constexpr uint32_t acc_tail_start_y
                        = block_size_y / acc_size_y * acc_size_y;
                constexpr uint32_t acc_tail_size_y = block_size_y % acc_size_y;
                constexpr uint32_t acc_tail_block_elems
                        = acc_tail_size_y * block_size_x;
                auto dst_reg_acc_tail
                        = dst_reg_acc.xetla_select<acc_tail_block_elems>(0);
                dst_reg_acc_tail = reduce_helper<reduce_kind, dtype_acc,
                        acc_tail_block_elems>(dst_reg_acc_tail,
                        src_reg_acc.xetla_select<acc_tail_block_elems, 1>(
                                acc_tail_start_y * block_size_x));
            }
        }
    }

    // process the tail
    if constexpr ((tile_size_y % block_size_y) != 0) {
        constexpr uint32_t tail_start_y
                = tile_size_y / block_size_y * block_size_y;
        constexpr uint32_t tail_size_y = tile_size_y % block_size_y;
        constexpr uint32_t tail_block_elems = tail_size_y * block_size_x;
#pragma unroll
        for (uint32_t j = 0; j < num_block_x; j++) {
            auto src_reg = (src.reg).xetla_select<tail_block_elems, 1>(
                    tail_start_y * tile_size_x + j * tail_block_elems);
            auto src_reg_acc
                    = xetla_cvt<dtype_acc, dtype, tail_block_elems>(src_reg);
            auto dst_reg_acc
                    = acc.xetla_select<acc_block_elems, 1>(j * acc_block_elems);
#pragma unroll
            for (uint32_t row_i = 0; row_i < tail_size_y / acc_size_y;
                    row_i++) {
                if ((tile_size_y / block_size_y == 0) && row_i == 0) continue;
                dst_reg_acc = reduce_helper<reduce_kind, dtype_acc,
                        acc_block_elems>(dst_reg_acc,
                        src_reg_acc.xetla_select<acc_block_elems, 1>(
                                row_i * acc_block_elems));
            }
            // process the tail
            if constexpr ((tail_size_y % acc_size_y) != 0) {
                constexpr uint32_t acc_tail_start_y
                        = tail_size_y / acc_size_y * acc_size_y;
                constexpr uint32_t acc_tail_size_y = tail_size_y % acc_size_y;
                constexpr uint32_t acc_tail_block_elems
                        = acc_tail_size_y * block_size_x;
                auto dst_reg_acc_tail
                        = dst_reg_acc.xetla_select<acc_tail_block_elems>(0);
                dst_reg_acc_tail = reduce_helper<reduce_kind, dtype_acc,
                        acc_tail_block_elems>(dst_reg_acc_tail,
                        src_reg_acc.xetla_select<acc_tail_block_elems, 1>(
                                acc_tail_start_y * block_size_x));
            }
        }
    }

    // collapse the acc_size_y accumulator rows down to one value per column
    xetla_vector<dtype_acc, tile_size_x> out;
#pragma unroll
    for (uint32_t i = 0; i < acc_size_y; i++) {
#pragma unroll
        for (uint32_t j = 0; j < num_block_x; j++) {
            auto reg_acc
                    = acc.xetla_select<acc_block_elems, 1>(j * acc_block_elems);
            auto reg_acc_2d = reg_acc.xetla_format<dtype_acc, acc_size_y,
                    block_size_x>();
            if (i == 0) {
                out.xetla_select<block_size_x, 1>(j * block_size_x)
                        = reg_acc_2d.row(i);
            } else {
                out.xetla_select<block_size_x, 1>(j * block_size_x)
                        = reduce_helper<reduce_kind, dtype_acc, block_size_x>(
                                out.xetla_select<block_size_x, 1>(
                                        j * block_size_x),
                                reg_acc_2d.row(i));
            }
        }
    }

    return xetla_cvt<dtype_out, dtype_acc, tile_size_x>(out);
}

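// Usage sketch for the dim == 0 overload (illustrative; same assumed tile
// types as in the sketch above):
//
//   // dim == 0: one result per column (each 16-high column is reduced along y)
//   xetla_vector<float, tile_desc::tile_size_x> col_sum
//           = subgroup::tile_reduce<reduce_op::sum, float, float, 0>(mat);
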
template <typename T_dst, typename T_src, bool accumulate = true,
        typename dtype_acc = float, uint32_t num_acc = 4>
XETLA_MARKER(
        "This is only for reduce add, and will be deprecated in future. "
        "Please use tile_reduce instead.")
__XETLA_API
        typename std::enable_if_t<(T_dst::register_layout == reg_layout::tiled)
                && (T_src::register_layout == reg_layout::tiled)
                && (T_dst::tile_size_x == T_src::tile_size_x)
                && (T_dst::tile_size_y == 1)> tile_row_reduce(T_dst &dst,
                T_src &src) {
    static constexpr uint32_t tile_size_y = T_src::tile_size_y;
    static constexpr uint32_t tile_size_x = T_src::tile_size_x;
    static constexpr uint32_t block_size_y = T_src::block_size_y;
    static constexpr uint32_t block_size_x = T_src::block_size_x;
    static constexpr uint32_t block_elems = T_src::block_elems;
    static constexpr int32_t num_block_x = T_src::num_block_x;
    using dtype_dst = typename T_dst::dtype;
    using dtype_src = typename T_src::dtype;

    static constexpr uint32_t SIMD = 64 / sizeof(dtype_acc);
    static constexpr uint32_t accum_len
            = ((block_size_x % SIMD) && (sizeof(dtype_src) < 4)) == 0
            ? SIMD
            : block_size_x;
    // num_acc rows of partial sums; zero-initialized so that rows which are
    // never written do not perturb the final accumulation
    xetla_vector<dtype_acc, num_acc * tile_size_x> acc = 0;
    auto acc_2d = acc.xetla_format<dtype_acc, num_acc, tile_size_x>();
    if constexpr (accumulate) {
        acc_2d.row(0) = xetla_cvt<dtype_acc, dtype_dst, tile_size_x>(dst.reg);
    }
#pragma unroll
    for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
#pragma unroll
        for (uint32_t j = 0; j < num_block_x; j++) {
            auto src_reg = (src.reg).xetla_select<block_elems, 1>(
                    (i * num_block_x + j) * block_elems);
            auto src_reg_dtype_acc
                    = xetla_cvt<dtype_acc, dtype_src, block_elems>(src_reg);
            auto src_reg_2d = src_reg_dtype_acc.xetla_format<dtype_acc,
                    block_size_y, block_size_x>();
#pragma unroll
            for (uint32_t row_i = 0; row_i < block_size_y; row_i += num_acc) {
#pragma unroll
                for (uint32_t acc_i = 0;
                        (acc_i < num_acc) && (row_i + acc_i < block_size_y);
                        acc_i++) {
                    auto acc_sub = acc.xetla_select<block_size_x, 1>(
                            acc_i * tile_size_x + j * block_size_x);
#pragma unroll
                    for (uint32_t k = 0; k < block_size_x / accum_len; k++) {
                        acc_sub.xetla_select<accum_len, 1>(k * accum_len)
                                = acc_sub.xetla_select<accum_len, 1>(
                                          k * accum_len)
                                + src_reg_2d.row(row_i + acc_i)
                                          .xetla_select<accum_len, 1>(
                                                  k * accum_len);
                    }
                }
            }
        }
    }

    // process the tail
    if constexpr ((tile_size_y % block_size_y) != 0) {
        constexpr uint32_t tail_start_y
                = tile_size_y / block_size_y * block_size_y;
        constexpr uint32_t tail_size_y = tile_size_y % block_size_y;
        constexpr uint32_t tail_block_elems = tail_size_y * block_size_x;
#pragma unroll
        for (uint32_t j = 0; j < num_block_x; j++) {
            auto src_reg = (src.reg).xetla_select<tail_block_elems, 1>(
                    tail_start_y * tile_size_x + j * tail_block_elems);
            auto src_reg_dtype_acc
                    = xetla_cvt<dtype_acc, dtype_src, tail_block_elems>(
                            src_reg);
            auto src_reg_2d = src_reg_dtype_acc.xetla_format<dtype_acc,
                    tail_size_y, block_size_x>();
#pragma unroll
            for (uint32_t row_i = 0; row_i < tail_size_y; row_i += num_acc) {
#pragma unroll
                for (uint32_t acc_i = 0;
                        (acc_i < num_acc) && (row_i + acc_i < tail_size_y);
                        acc_i++) {
                    auto acc_sub = acc.xetla_select<block_size_x, 1>(
                            acc_i * tile_size_x + j * block_size_x);
#pragma unroll
                    for (uint32_t k = 0; k < block_size_x / accum_len; k++) {
                        acc_sub.xetla_select<accum_len, 1>(k * accum_len)
                                = acc_sub.xetla_select<accum_len, 1>(
                                          k * accum_len)
                                + src_reg_2d.row(row_i + acc_i)
                                          .xetla_select<accum_len, 1>(
                                                  k * accum_len);
                    }
                }
            }
        }
    }

#pragma unroll
    for (uint32_t i = 1; i < num_acc; i++) {
        acc_2d.row(0) += acc_2d.row(i);
    }

    dst.reg = xetla_cvt<dtype_dst, dtype_acc, tile_size_x>(acc_2d.row(0));
}
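
// Usage sketch for the deprecated add-only helper (illustrative; assumes the
// same tile types as above and a 1-row destination tile; prefer tile_reduce):
//
//   subgroup::tile_t<float,
//           subgroup::tile_desc_t<32, 1, 32, 1, reg_layout::tiled>> dst;
//   subgroup::tile_row_reduce(dst, mat); // dst.reg accumulates per-column sums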

} // namespace gpu::xetla::subgroup