XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
kernel_func.hpp
1/*******************************************************************************
2 * Copyright (c) 2022-2023 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16#pragma once
17
18#include "xetla.hpp"
19
20using namespace gpu::xetla;
21using namespace gpu::xetla::group;
22using namespace gpu::xetla::subgroup;
23
25public:
26 using dtype_in = bf16;
27 using dtype_acc = float;
29 static constexpr uint32_t layer_size = 3;
31 static constexpr uint32_t sequence_length = 2;
33 static constexpr uint32_t batch_size = 512;
35 static constexpr uint32_t input_size = 384;
37 static constexpr uint32_t hidden_size = 704;
39 static constexpr uint32_t wg_tile_m = 64;
40 static constexpr uint32_t wg_tile_n = 128;
41 static constexpr uint32_t sg_tile_m = 16;
42 static constexpr uint32_t sg_tile_n = 16;
43 static constexpr uint32_t sg_tile_k = 32;
44};
45
46template <typename T>
47struct fused_config_t {
48 uint32_t input_size;
49 uint32_t hidden_size;
50 uint32_t batch_size;
51 uint32_t sequence_length = 1;
53 T *layer_ptr = nullptr;
54 T *hx_ptr = nullptr;
55 T *W_ir_ptr = nullptr;
56 T *W_hr_ptr = nullptr;
57 T *W_iz_ptr = nullptr;
58 T *W_hz_ptr = nullptr;
59 T *W_in_ptr = nullptr;
60 T *W_hn_ptr = nullptr;
62 T *cell_out_ptr = nullptr;
64 T *layer_output = nullptr;
65 T *one_cell_ptr = nullptr;
66};
67
68#define CONFIG_SETTING(m, k, n) \
69 boundary_n = (start_n + wg_tile_n) > n ? n : (start_n + wg_tile_n); \
70 matrix_n = n; \
71 start_x_b = start_n; \
72 start_y_b = start_k;
73
74#define GEMM_CALL(id, acc_id, ptr_a, ptr_b) \
75 mem_desc_a.init({ptr_a}, \
76 {boundary_k_##id, boundary_m, \
77 is_col_major_a ? matrix_m : matrix_k_##id}, \
78 {start_x_a, start_y_a}); \
79 mem_desc_b.init({ptr_b}, \
80 {boundary_n, boundary_k_##id, \
81 is_col_major_b ? matrix_k_##id : matrix_n}, \
82 {start_x_b, start_y_b}); \
83 gemm_args.init(mem_desc_a, mem_desc_b, inner_loop_count_##id); \
84 op(g, matAcc_##acc_id, gemm_args); \
85 SW_BARRIER();
86
87#define MATC_STORE(ptr_c) \
88 mem_desc_c.init( \
89 {ptr_c}, {boundary_n, boundary_m, matrix_n}, {start_n, start_m}); \
90 epilogue(g, matAcc_0, mem_desc_c, epilogue_args);
91
92template <typename T, typename Act_T, uint32_t wg_tile_m, uint32_t wg_tile_n,
93 uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k,
94 mem_layout layout_input = mem_layout::row_major,
95 mem_layout layout_weight = mem_layout::row_major,
96 mem_layout layout_out = mem_layout::row_major,
97 mem_space mem_loc_input = mem_space::global,
98 mem_space mem_loc_weight = mem_space::global,
99 mem_space mem_loc_out = mem_space::global,
100 uint32_t periodic_sync_interval = 0>
101struct gru_layer {
102 static constexpr uint32_t prefetch_distance = 3;
103 using perf_tuning_knob = perf_tuning_knob_t<sg_tile_k,
104 prefetch_distance, periodic_sync_interval>;
105
106 using compute_attr = group::compute_attr_t<T, T, Act_T>;
107 using compute_policy = compute_policy_default_xmx<compute_attr,
108 perf_tuning_knob, gpu_arch::Xe>;
109 using mem_desc_a_t = mem_desc_t<T, layout_input, mem_loc_input>;
110 using mem_desc_b_t = mem_desc_t<T, layout_weight, mem_loc_weight>;
111 // Org the compute shape for sub-matrix
112 using tile_shape = tile_shape_t<wg_tile_n, // workgroup size in N dim
113 wg_tile_m, // workgroup size in M dim
114 sg_tile_n, // subgroup size in N dim
115 sg_tile_m>; // subgroup size in M dim
116
117 static constexpr bool is_col_major_a
118 = layout_input == mem_layout::col_major;
119 static constexpr bool is_col_major_b
120 = layout_weight == mem_layout::col_major;
121 using gemm_op = gemm_t<compute_policy, tile_shape,
122 mem_desc_a_t, mem_desc_b_t>;
123 using work_group_t = typename gemm_op::work_group_t;
124 using gemm_arguments = typename gemm_op::arguments_t;
125 using matAcc_t = typename gemm_op::matAcc_t;
126
127 using mem_desc_c_t = mem_desc_t<T, layout_out, mem_loc_out>;
128
129 // define arguments for each epilogue_tile_op in chained_tile_op_t<>
130
133 using epilogue_args_t = typename epilogue_t::arguments_t;
134
135 using matC_tile_desc_t = tile_desc_t<matAcc_t::tile_size_x,
136 matAcc_t::tile_size_y, matAcc_t::block_size_x,
137 matAcc_t::block_size_y, reg_layout::tiled>;
141 msg_type_v<matC_tile_desc_t, mem_loc_input>, gpu_arch::Xe>;
143 msg_type::block_2d, gpu_arch::Xe>;
144 using sigmoid_t = typename subgroup::sigmoid_op_t;
145 using tanh_t = typename subgroup::tanh_op_t;
146 static void inline call(sycl::nd_item<3> &item, fused_config_t<T> *args) {
147 gemm_op op;
148 sigmoid_t sigmoid;
149 tanh_t tanh;
150 // declare two accumulators to store the results of the two GEMMs
151 // and their activations
152 matAcc_t matAcc_0, matAcc_1;
153 gemm_arguments gemm_args;
154 mat_hidden_t mat_hidden;
155 mat_hidden_payload_t mat_hidden_payload;
156 mem_desc_a_t mem_desc_a;
157 mem_desc_b_t mem_desc_b;
158 mem_desc_c_t mem_desc_c;
159 epilogue_t epilogue;
160 epilogue_args_t epilogue_args {};
161
162 uint32_t batch_size, input_size, hidden_size, seq_len;
163 batch_size = args->batch_size;
164 input_size = args->input_size;
165 hidden_size = args->hidden_size;
166 seq_len = args->sequence_length;
167
168 uint32_t matrix_n = hidden_size;
169 uint32_t matrix_m = batch_size;
170 uint32_t matrix_k_0 = input_size;
171 uint32_t matrix_k_1 = hidden_size;
172 int start_x_b, start_y_b, start_x_a, start_y_a;
173 uint32_t boundary_n, boundary_m, boundary_k_0, boundary_k_1;
174 uint32_t wg_tile_k_0, wg_tile_k_1;
175 wg_tile_k_0 = input_size;
176 wg_tile_k_1 = hidden_size;
177 boundary_k_0 = wg_tile_k_0;
178 boundary_k_1 = wg_tile_k_1;
179
180 // layer_0:
181 // hidden out matrix = 512 x 704
182 // matmul input 512 x 384 : 384 x 704
183 // matmul hidden 512 x 704 : 704 x 704
184
185 // layer_1, layer_2:
186 // hidden out matrix = 512 x 704
187 // matmul input 512 x 704 : 704 x 704
188 // matmul hidden 512 x 704 : 704 x 704
189 // two GEMMs will have different loop counts on k dim
190 uint32_t inner_loop_count_0 = (wg_tile_k_0 + sg_tile_k - 1) / sg_tile_k;
191 uint32_t inner_loop_count_1 = (wg_tile_k_1 + sg_tile_k - 1) / sg_tile_k;
192
193 int start_m = item.get_group(1) * wg_tile_m;
194
195 boundary_m = (start_m + wg_tile_m) > batch_size ? batch_size
196 : (start_m + wg_tile_m);
197
198 int start_k = 0;
199
200 start_x_a = start_k;
201 start_y_a = start_m;
202 int io_size = batch_size * hidden_size;
203 int pre_layer_size = batch_size * input_size;
204 work_group_t g(item.get_local_linear_id());
205 for (uint32_t seq_id = 0; seq_id < seq_len; ++seq_id) {
206 for (int j = (hidden_size + wg_tile_n - 1) / wg_tile_n - 1; j >= 0;
207 j--) {
208 int start_n = (j)*wg_tile_n;
209 CONFIG_SETTING(batch_size, -1, hidden_size);
210 matAcc_0.init(0);
211 SW_BARRIER();
212
213 // calculate reset gate: r_t = \sigmoid(X_t x W_ir + h_{t - 1} x W_hr)
214 // acc0 = X_t x W_ir
215 // acc0 += h_{t - 1} x W_hr
216 // acc0 = sigmoid(acc0)
217 // Mathematically elemwise_op is a map that applies to each element:
218 // elemwise_op: [m, n] -> [m, n], acc |-> tile_op_t(acc)
219 GEMM_CALL(0, 0, args->layer_ptr + seq_id * pre_layer_size,
220 args->W_ir_ptr);
221 GEMM_CALL(1, 0, args->hx_ptr, args->W_hr_ptr);
222 sigmoid(matAcc_0, 0);
223 // calculate new gate: n_t = tanh(X_t x W_in + r_t * (h_{t - 1} x W_hn))
224 // acc1 = h_{t - 1} x W_hn; acc0 *= acc1
225 // acc0 += X_t x W_in; acc0 = tanh(acc0)
226 // Mathematically elemwise_op is a map that applies to each element:
227 // elemwise_op: [m, n] -> [m, n], acc |-> tile_op_t(acc)
228 matAcc_1.init(0);
229 GEMM_CALL(1, 1, args->hx_ptr, args->W_hn_ptr);
230 matAcc_0.reg = matAcc_1.reg * matAcc_0.reg;
231 GEMM_CALL(0, 0, args->layer_ptr + seq_id * pre_layer_size,
232 args->W_in_ptr);
233
234 tanh(matAcc_0, 0);
235 // calculate update gate: z_t = \sigmoid(X_t x W_iz + h_{t - 1} x W_hz)
236 // acc1 = X_t x W_iz
237 // acc1 += h_{t - 1} x W_hz
238 // acc1 = sigmoid(acc1)
239 // Mathematically elemwise_op is a map that applies to each element:
240 // elemwise_op: [m, n] -> [m, n], acc |-> tile_op_t(acc)
241 matAcc_1.init(0);
242 GEMM_CALL(1, 1, args->hx_ptr, args->W_hz_ptr);
243 GEMM_CALL(0, 1, args->layer_ptr + seq_id * pre_layer_size,
244 args->W_iz_ptr);
245 sigmoid(matAcc_1, 0);
246 // calculate h_t = (1 - z_t) * n_t + z_t * h_{t - 1}
247 // NOTICE: z_t is in Acc1, n_t is in Acc0; reload h_{t - 1}
248 // acc0 = acc0 * (1 - acc1) + acc1 * h_{t - 1}
249 mem_desc_c.init({args->hx_ptr},
250 {boundary_n, boundary_m, matrix_n},
251 {start_n + gemm_op::get_matC_offset_x(g),
252 start_m + gemm_op::get_matC_offset_y(g)});
253 mat_hidden_payload.init(mem_desc_c);
254 tile_load<cache_hint::cached, cache_hint::cached>(
255 mat_hidden, mat_hidden_payload);
256 matAcc_0.reg = matAcc_0.reg * (1 - matAcc_1.reg)
257 + matAcc_1.reg
258 * xetla_cvt<Act_T, T, matAcc_t::tile_elems>(
259 mat_hidden.reg);
260 SW_BARRIER();
261
262 if (seq_id == seq_len - 1) {
263 MATC_STORE(args->layer_output);
264 SW_BARRIER();
265 __esimd_barrier();
266 }
267 MATC_STORE(args->cell_out_ptr + seq_id * io_size);
268 SW_BARRIER();
269 __esimd_barrier();
270
271 MATC_STORE(args->one_cell_ptr + (seq_id % 2) * io_size);
272 SW_BARRIER();
273 __esimd_barrier();
274 }
275 args->hx_ptr = args->one_cell_ptr + (seq_id % 2) * io_size;
276 }
277 }
278};
279
280template <typename input_T, typename Act_T, uint32_t wg_tile_m_t,
281 uint32_t wg_tile_n_t, uint32_t sg_tile_m_t, uint32_t sg_tile_n_t,
282 uint32_t sg_tile_k_t>
305 static void inline run(sycl::nd_item<3> &item, input_T *layer_ptr,
306 input_T *h0_ptr, input_T *W_ir_ptr, input_T *W_hr_ptr,
307 input_T *W_iz_ptr, input_T *W_hz_ptr, input_T *W_in_ptr,
308 input_T *W_hn_ptr, input_T *layer_out_ptr, input_T *hidden_out_ptr,
309 input_T *ping_pong_buffer, input_T *ping_pong_cell, int batch_size,
310 int input_size, int hidden_size, int sequence_length,
311 int layer_size) {
312 constexpr uint32_t fused_op_wg_m = wg_tile_m_t;
313 constexpr uint32_t fused_op_wg_n = wg_tile_n_t;
314 constexpr uint32_t fused_op_sg_m = sg_tile_m_t;
315 constexpr uint32_t fused_op_sg_n = sg_tile_n_t;
316 constexpr uint32_t fused_op_sg_k = sg_tile_k_t;
317
318 using fused_op = gru_layer<input_T, Act_T, fused_op_wg_m, fused_op_wg_n,
319 fused_op_sg_m, fused_op_sg_n, fused_op_sg_k>;
320
321 fused_config_t<input_T> args;
322 int hidden_io_size = batch_size * hidden_size;
323 int input_weight_size = input_size * hidden_size;
324 int hidden_weight_size = hidden_size * hidden_size;
325 int one_layer_size = sequence_length * batch_size * hidden_size;
326 int ping = 0;
327 int pong = 1;
328 args.one_cell_ptr = ping_pong_cell;
329 args.input_size = input_size;
330 args.batch_size = batch_size;
331 args.hidden_size = hidden_size;
332 args.sequence_length = sequence_length;
333 args.cell_out_ptr = layer_size == 1
334 ? hidden_out_ptr
335 : (ping_pong_buffer + ping * one_layer_size);
336 args.layer_ptr = (layer_ptr);
337 args.hx_ptr = (h0_ptr);
338 args.layer_output = layer_out_ptr;
339 args.W_ir_ptr = (W_ir_ptr);
340 args.W_hr_ptr = (W_hr_ptr);
341 args.W_iz_ptr = (W_iz_ptr);
342 args.W_hz_ptr = (W_hz_ptr);
343 args.W_in_ptr = (W_in_ptr);
344 args.W_hn_ptr = (W_hn_ptr);
345 SW_BARRIER();
346 fused_op::call(item, &args);
347 ping = (ping + 1) % 2;
348 pong = (pong + 1) % 2;
349
350 args.input_size = hidden_size;
351 args.batch_size = batch_size;
352 args.hidden_size = hidden_size;
353 uint32_t current_layer_size = layer_size;
354 for (uint32_t layer_id = 1; layer_id < current_layer_size; ++layer_id) {
355 args.layer_output = layer_out_ptr + layer_id * hidden_io_size;
356 args.hx_ptr = (h0_ptr + layer_id * hidden_io_size);
357 args.W_ir_ptr = (W_ir_ptr + (layer_id - 1) * hidden_weight_size
358 + input_weight_size);
359 args.W_hr_ptr = (W_hr_ptr + layer_id * hidden_weight_size);
360 args.W_iz_ptr = (W_iz_ptr + (layer_id - 1) * hidden_weight_size
361 + input_weight_size);
362 args.W_hz_ptr = (W_hz_ptr + layer_id * hidden_weight_size);
363 args.W_in_ptr = (W_in_ptr + (layer_id - 1) * hidden_weight_size
364 + input_weight_size);
365 args.W_hn_ptr = (W_hn_ptr + layer_id * hidden_weight_size);
366 args.cell_out_ptr = layer_id == current_layer_size - 1
367 ? hidden_out_ptr
368 : (ping_pong_buffer + ping * one_layer_size);
369 args.layer_ptr = ((ping_pong_buffer + pong * one_layer_size));
370 SW_BARRIER();
371 fused_op::call(item, &args);
372 ping = (ping + 1) % 2;
373 pong = (pong + 1) % 2;
374 }
375 }
376};
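
For reference, the per-step math that gru_layer::call accumulates above (reset gate, new gate, update gate, then the hidden-state blend) is the standard GRU cell update. Restated compactly in the notation of the in-code comments, with \odot denoting element-wise multiplication:

\begin{aligned}
r_t &= \sigma\left(X_t W_{ir} + h_{t-1} W_{hr}\right) \\
z_t &= \sigma\left(X_t W_{iz} + h_{t-1} W_{hz}\right) \\
n_t &= \tanh\left(X_t W_{in} + r_t \odot \left(h_{t-1} W_{hn}\right)\right) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}

Each of the six matrix products maps onto one GEMM_CALL, the gate nonlinearities onto the sigmoid/tanh tile ops, and the final blend onto the element-wise expression over matAcc_0, matAcc_1 and the reloaded mat_hidden. Bias terms do not appear in this fused kernel.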
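Plugging the launch configuration above (batch_size = 512, input_size = 384, hidden_size = 704, wg_tile_m = 64, wg_tile_n = 128, sg_tile_k = 32) into the decomposition used by gru_layer::call gives the trip counts below. The snippet is only a hand check of that arithmetic; the variable names are illustrative and not part of the XeTLA API.

#include <cstdint>
#include <cstdio>

int main() {
    constexpr uint32_t batch_size = 512, input_size = 384, hidden_size = 704;
    constexpr uint32_t wg_tile_m = 64, wg_tile_n = 128, sg_tile_k = 32;

    // One workgroup per 64-row slice of the batch: 512 / 64 = 8 groups in M.
    constexpr uint32_t groups_m = (batch_size + wg_tile_m - 1) / wg_tile_m;  // 8
    // The j-loop inside gru_layer::call walks the N tiles: ceil(704 / 128) = 6.
    constexpr uint32_t n_tiles = (hidden_size + wg_tile_n - 1) / wg_tile_n;  // 6
    // K-loop trip counts of the two GEMMs (input weights vs. hidden weights).
    constexpr uint32_t k0 = (input_size + sg_tile_k - 1) / sg_tile_k;        // 12
    constexpr uint32_t k1 = (hidden_size + sg_tile_k - 1) / sg_tile_k;       // 22

    std::printf("groups_m=%u n_tiles=%u k0=%u k1=%u\n", groups_m, n_tiles, k0, k1);
    return 0;
}

So each workgroup owns a 64-row slice of the batch, sweeps six 128-column tiles of the hidden state per time step, and runs 12 (input GEMM) or 22 (hidden GEMM) k-iterations per tile, matching inner_loop_count_0 and inner_loop_count_1 in the listing.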
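The listing does not show the host side. Below is a minimal dispatch sketch, not the XeTLA test harness: it assumes one workgroup per wg_tile_m rows of the batch (matching item.get_group(1) * wg_tile_m above), one ESIMD work-item per subgroup tile, and USM device or shared pointers. The wrapper name launch_gru and the template parameter fused_op_t are hypothetical stand-ins for the fused-op struct defined around line 283, whose name is not visible in this listing; the exact nd_range and work-item ordering used in practice may differ.

#include <cstdint>
#include <sycl/sycl.hpp>
#include <sycl/ext/intel/esimd.hpp>

// Hypothetical host-side wrapper; fused_op_t is expected to expose the
// static run(...) shown above, instantiated with the tile sizes below.
template <typename fused_op_t, typename T>
void launch_gru(sycl::queue &q, T *layer_ptr, T *h0_ptr, T *W_ir, T *W_hr,
        T *W_iz, T *W_hz, T *W_in, T *W_hn, T *layer_out, T *hidden_out,
        T *ping_pong_buffer, T *ping_pong_cell, int batch_size, int input_size,
        int hidden_size, int seq_len, int layer_size) {
    // Must match the template parameters used to instantiate fused_op_t.
    constexpr uint32_t wg_tile_m = 64, wg_tile_n = 128;
    constexpr uint32_t sg_tile_m = 16, sg_tile_n = 16;
    // Dim 1 carries the batch tiles; the hidden dimension is swept by the
    // j-loop inside gru_layer::call, so it is not split across groups here.
    sycl::range<3> group_range {1, (batch_size + wg_tile_m - 1) / wg_tile_m, 1};
    // One ESIMD work-item (subgroup) per sg tile of the workgroup tile.
    sycl::range<3> local_range {1, wg_tile_m / sg_tile_m, wg_tile_n / sg_tile_n};
    sycl::nd_range<3> nd_range(group_range * local_range, local_range);

    q.submit([&](sycl::handler &cgh) {
         cgh.parallel_for(nd_range, [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL {
             fused_op_t::run(item, layer_ptr, h0_ptr, W_ir, W_hr, W_iz, W_hz,
                     W_in, W_hn, layer_out, hidden_out, ping_pong_buffer,
                     ping_pong_cell, batch_size, input_size, hidden_size,
                     seq_len, layer_size);
         });
     }).wait();
}

The local range of 4 x 8 work-items mirrors tile_shape_t<wg_tile_n, wg_tile_m, sg_tile_n, sg_tile_m>, since work_group_t is constructed from item.get_local_linear_id() inside the kernel.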