68#define CONFIG_SETTING(m, k, n) \
69 boundary_n = (start_n + wg_tile_n) > n ? n : (start_n + wg_tile_n); \
71 start_x_b = start_n; \
74#define GEMM_CALL(id, acc_id, ptr_a, ptr_b) \
75 mem_desc_a.init({ptr_a}, \
76 {boundary_k_##id, boundary_m, \
77 is_col_major_a ? matrix_m : matrix_k_##id}, \
78 {start_x_a, start_y_a}); \
79 mem_desc_b.init({ptr_b}, \
80 {boundary_n, boundary_k_##id, \
81 is_col_major_b ? matrix_k_##id : matrix_n}, \
82 {start_x_b, start_y_b}); \
83 gemm_args.init(mem_desc_a, mem_desc_b, inner_loop_count_##id); \
84 op(g, matAcc_##acc_id, gemm_args); \
87#define MATC_STORE(ptr_c) \
89 {ptr_c}, {boundary_n, boundary_m, matrix_n}, {start_n, start_m}); \
90 epilogue(g, matAcc_0, mem_desc_c, epilogue_args);
92template <
typename T,
typename Act_T, uint32_t wg_tile_m, uint32_t wg_tile_n,
93 uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k,
94 mem_layout layout_input = mem_layout::row_major,
95 mem_layout layout_weight = mem_layout::row_major,
97 mem_space mem_loc_input = mem_space::global,
98 mem_space mem_loc_weight = mem_space::global,
99 mem_space mem_loc_out = mem_space::global,
118 = layout_input == mem_layout::col_major;
120 = layout_weight == mem_layout::col_major;
136 matAcc_t::tile_size_y, matAcc_t::block_size_x,
137 matAcc_t::block_size_y, reg_layout::tiled>;
141 msg_type_v<matC_tile_desc_t, mem_loc_input>, gpu_arch::Xe>;
143 msg_type::block_2d, gpu_arch::Xe>;
162 uint32_t batch_size, input_size, hidden_size, seq_len;
168 uint32_t matrix_n = hidden_size;
169 uint32_t matrix_m = batch_size;
170 uint32_t matrix_k_0 = input_size;
171 uint32_t matrix_k_1 = hidden_size;
172 int start_x_b, start_y_b, start_x_a, start_y_a;
173 uint32_t boundary_n, boundary_m, boundary_k_0, boundary_k_1;
174 uint32_t wg_tile_k_0, wg_tile_k_1;
175 wg_tile_k_0 = input_size;
176 wg_tile_k_1 = hidden_size;
177 boundary_k_0 = wg_tile_k_0;
178 boundary_k_1 = wg_tile_k_1;
190 uint32_t inner_loop_count_0 = (wg_tile_k_0 + sg_tile_k - 1) / sg_tile_k;
191 uint32_t inner_loop_count_1 = (wg_tile_k_1 + sg_tile_k - 1) / sg_tile_k;
193 int start_m = item.get_group(1) * wg_tile_m;
195 boundary_m = (start_m + wg_tile_m) > batch_size ? batch_size
196 : (start_m + wg_tile_m);
202 int io_size = batch_size * hidden_size;
203 int pre_layer_size = batch_size * input_size;
205 for (uint32_t seq_id = 0; seq_id < seq_len; ++seq_id) {
206 for (
int j = (hidden_size + wg_tile_n - 1) / wg_tile_n - 1; j >= 0;
208 int start_n = (j)*wg_tile_n;
230 matAcc_0.reg = matAcc_1.reg * matAcc_0.reg;
249 mem_desc_c.init({args->
hx_ptr},
250 {boundary_n, boundary_m, matrix_n},
251 {start_n + gemm_op::get_matC_offset_x(g),
252 start_m + gemm_op::get_matC_offset_y(g)});
253 mat_hidden_payload.init(mem_desc_c);
254 tile_load<cache_hint::cached, cache_hint::cached>(
255 mat_hidden, mat_hidden_payload);
256 matAcc_0.reg = matAcc_0.reg * (1 - matAcc_1.reg)
258 * xetla_cvt<Act_T, T, matAcc_t::tile_elems>(
262 if (seq_id == seq_len - 1) {
280template <
typename input_T,
typename Act_T, uint32_t wg_tile_m_t,
281 uint32_t wg_tile_n_t, uint32_t sg_tile_m_t, uint32_t sg_tile_n_t,
282 uint32_t sg_tile_k_t>
305 static void inline run(sycl::nd_item<3> &item, input_T *layer_ptr,
306 input_T *h0_ptr, input_T *W_ir_ptr, input_T *W_hr_ptr,
307 input_T *W_iz_ptr, input_T *W_hz_ptr, input_T *W_in_ptr,
308 input_T *W_hn_ptr, input_T *layer_out_ptr, input_T *hidden_out_ptr,
309 input_T *ping_pong_buffer, input_T *ping_pong_cell,
int batch_size,
310 int input_size,
int hidden_size,
int sequence_length,
312 constexpr uint32_t fused_op_wg_m = wg_tile_m_t;
313 constexpr uint32_t fused_op_wg_n = wg_tile_n_t;
314 constexpr uint32_t fused_op_sg_m = sg_tile_m_t;
315 constexpr uint32_t fused_op_sg_n = sg_tile_n_t;
316 constexpr uint32_t fused_op_sg_k = sg_tile_k_t;
318 using fused_op =
gru_layer<input_T, Act_T, fused_op_wg_m, fused_op_wg_n,
319 fused_op_sg_m, fused_op_sg_n, fused_op_sg_k>;
322 int hidden_io_size = batch_size * hidden_size;
323 int input_weight_size = input_size * hidden_size;
324 int hidden_weight_size = hidden_size * hidden_size;
325 int one_layer_size = sequence_length * batch_size * hidden_size;
335 : (ping_pong_buffer + ping * one_layer_size);
346 fused_op::call(item, &args);
347 ping = (ping + 1) % 2;
348 pong = (pong + 1) % 2;
353 uint32_t current_layer_size = layer_size;
354 for (uint32_t layer_id = 1; layer_id < current_layer_size; ++layer_id) {
355 args.
layer_output = layer_out_ptr + layer_id * hidden_io_size;
356 args.
hx_ptr = (h0_ptr + layer_id * hidden_io_size);
357 args.
W_ir_ptr = (W_ir_ptr + (layer_id - 1) * hidden_weight_size
358 + input_weight_size);
359 args.
W_hr_ptr = (W_hr_ptr + layer_id * hidden_weight_size);
360 args.
W_iz_ptr = (W_iz_ptr + (layer_id - 1) * hidden_weight_size
361 + input_weight_size);
362 args.
W_hz_ptr = (W_hz_ptr + layer_id * hidden_weight_size);
363 args.
W_in_ptr = (W_in_ptr + (layer_id - 1) * hidden_weight_size
364 + input_weight_size);
365 args.
W_hn_ptr = (W_hn_ptr + layer_id * hidden_weight_size);
368 : (ping_pong_buffer + ping * one_layer_size);
369 args.
layer_ptr = ((ping_pong_buffer + pong * one_layer_size));
371 fused_op::call(item, &args);
372 ping = (ping + 1) % 2;
373 pong = (pong + 1) % 2;
Gemm functor.
Definition api.hpp:52
Definition kernel_func.hpp:24
bf16 dtype_in
Definition kernel_func.hpp:26
static constexpr uint32_t sequence_length
sequence_length = 64
Definition kernel_func.hpp:31
static constexpr uint32_t layer_size
layer_size = 3
Definition kernel_func.hpp:29
static constexpr uint32_t sg_tile_k
Definition kernel_func.hpp:43
static constexpr uint32_t hidden_size
hidden_size = 688;
Definition kernel_func.hpp:37
static constexpr uint32_t wg_tile_n
Definition kernel_func.hpp:40
static constexpr uint32_t sg_tile_m
Definition kernel_func.hpp:41
static constexpr uint32_t sg_tile_n
Definition kernel_func.hpp:42
float dtype_acc
Definition kernel_func.hpp:27
static constexpr uint32_t wg_tile_m
launch config
Definition kernel_func.hpp:39
static constexpr uint32_t input_size
input_size = 384
Definition kernel_func.hpp:35
static constexpr uint32_t batch_size
batch_size = 512
Definition kernel_func.hpp:33
#define SW_BARRIER()
SW_BARRIER: inserts a software scheduling barrier, giving finer control over the generated code.
Definition common.hpp:227
sycl::ext::oneapi::bfloat16 bf16
xetla bf16 data type.
Definition base_types.hpp:40
#define GEMM_CALL(id, acc_id, ptr_a, ptr_b)
Definition kernel_func.hpp:74
#define CONFIG_SETTING(m, k, n)
Definition kernel_func.hpp:68
#define MATC_STORE(ptr_c)
Definition kernel_func.hpp:87
Definition limitation.hpp:607
Definition limitation.hpp:457
Definition arch_config.hpp:24
mem_space
Definition common.hpp:77
mem_layout
Definition common.hpp:76
Definition kernel_func.hpp:47
uint32_t batch_size
Definition kernel_func.hpp:50
T * W_hr_ptr
Definition kernel_func.hpp:56
T * cell_out_ptr
cell output = sequence_length x batch_size x hidden_size
Definition kernel_func.hpp:62
T * W_hn_ptr
Definition kernel_func.hpp:60
uint32_t hidden_size
Definition kernel_func.hpp:49
T * W_iz_ptr
Definition kernel_func.hpp:57
T * layer_output
layer_output = layer_size x batch_size x hidden_size
Definition kernel_func.hpp:64
T * W_hz_ptr
Definition kernel_func.hpp:58
T * hx_ptr
h_x input = batch_size x hidden_size
Definition kernel_func.hpp:54
T * one_cell_ptr
Definition kernel_func.hpp:65
uint32_t sequence_length
Definition kernel_func.hpp:51
uint32_t input_size
Definition kernel_func.hpp:48
T * W_ir_ptr
Definition kernel_func.hpp:55
T * W_in_ptr
Definition kernel_func.hpp:59
T * layer_ptr
layer_input = sequence_length x batch_size x input_size
Definition kernel_func.hpp:53
Compute attribute for gemm.
Definition common.hpp:32
Compute policy for xmx engine.
Definition compute_policy.hpp:35
Fine-tune knobs for gemm.
Definition common.hpp:43
Workgroup level tile shape description.
Definition tile_shape.hpp:34
Definition memory_descriptor.hpp:139
Describes the memory information.
Definition api.hpp:44
Element-wise sigmoid op functor.
Definition tile_op_functor.hpp:89
Element-wise tanh op functor.
Definition tile_op_functor.hpp:62
Describes the tile information of a sub-matrix.
Definition api.hpp:64
A struct that contains register-file storage.
Definition api.hpp:99
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
Definition kernel_func.hpp:101
typename subgroup::sigmoid_op_t sigmoid_t
Definition kernel_func.hpp:144
typename epilogue_t::arguments_t epilogue_args_t
Definition kernel_func.hpp:133
static constexpr bool is_col_major_b
Definition kernel_func.hpp:120
tile_shape_t< wg_tile_n, wg_tile_m, sg_tile_n, sg_tile_m > tile_shape
Definition kernel_func.hpp:115
epilogue_t< epilogue_policy_default< gpu_arch::Xe >, tile_shape, mem_desc_c_t > epilogue_t
Definition kernel_func.hpp:132
static void call(sycl::nd_item< 3 > &item, fused_config_t< T > *args)
Definition kernel_func.hpp:146
static constexpr uint32_t prefetch_distance
Definition kernel_func.hpp:102
mem_desc_t< T, layout_out, mem_loc_out > mem_desc_c_t
Definition kernel_func.hpp:127
mem_desc_t< T, layout_input, mem_loc_input > mem_desc_a_t
Definition kernel_func.hpp:109
typename subgroup::tanh_op_t tanh_t
Definition kernel_func.hpp:145
group::compute_attr_t< T, T, Act_T > compute_attr
Definition kernel_func.hpp:106
typename gemm_op::matAcc_t matAcc_t
Definition kernel_func.hpp:125
perf_tuning_knob_t< sg_tile_k, prefetch_distance, periodic_sync_interval > perf_tuning_knob
Definition kernel_func.hpp:104
tile_desc_t< matAcc_t::tile_size_x, matAcc_t::tile_size_y, matAcc_t::block_size_x, matAcc_t::block_size_y, reg_layout::tiled > matC_tile_desc_t
Definition kernel_func.hpp:137
typename gemm_op::arguments_t gemm_arguments
Definition kernel_func.hpp:124
static constexpr bool is_col_major_a
Definition kernel_func.hpp:118
typename gemm_op::work_group_t work_group_t
Definition kernel_func.hpp:123
Definition kernel_func.hpp:283
static void run(sycl::nd_item< 3 > &item, input_T *layer_ptr, input_T *h0_ptr, input_T *W_ir_ptr, input_T *W_hr_ptr, input_T *W_iz_ptr, input_T *W_hz_ptr, input_T *W_in_ptr, input_T *W_hn_ptr, input_T *layer_out_ptr, input_T *hidden_out_ptr, input_T *ping_pong_buffer, input_T *ping_pong_cell, int batch_size, int input_size, int hidden_size, int sequence_length, int layer_size)
Definition kernel_func.hpp:305