26template <
typename gemm_layer1_t_,
typename epilogue_layer1_t_,
27 typename gemm_layer2_t_,
typename epilogue_layer2_t_,
30 using gemm_layer1_t = gemm_layer1_t_;
31 using epilogue_layer1_t = epilogue_layer1_t_;
32 using gemm_layer1_args_t =
typename gemm_layer1_t::arguments_t;
33 using epilogue_layer1_args_t =
typename epilogue_layer1_t::arguments_t;
35 using tile_shape_layer1 =
typename gemm_layer1_t::tile_shape;
36 static constexpr uint32_t wg_tile_m_layer1
37 = tile_shape_layer1::wg_tile_size_y;
38 static constexpr uint32_t wg_tile_n_layer1
39 = tile_shape_layer1::wg_tile_size_x;
40 static constexpr uint32_t sg_tile_m_layer1
41 = tile_shape_layer1::sg_tile_size_y;
42 static constexpr uint32_t sg_tile_n_layer1
43 = tile_shape_layer1::sg_tile_size_x;
44 static constexpr uint32_t wg_size_y_layer1 = tile_shape_layer1::wg_size_y;
45 static constexpr uint32_t wg_size_x_layer1 = tile_shape_layer1::wg_size_x;
46 static constexpr uint32_t real_wg_tile_m_layer1
47 = sg_tile_m_layer1 * wg_size_y_layer1;
48 static constexpr uint32_t real_wg_tile_n_layer1
49 = sg_tile_n_layer1 * wg_size_x_layer1;
51 static constexpr uint32_t k_stride_layer1 = gemm_layer1_t::k_stride;
52 using work_group_layer1_t =
typename gemm_layer1_t::work_group_t;
54 static constexpr gpu_arch arch_tag = arch_tag_;
56 arch_tag == gemm_layer1_t::arch_tag,
"arch_tag should be the same");
57 static_assert(arch_tag == epilogue_layer1_t::arch_tag,
58 "arch_tag should be the same");
59 static_assert(std::is_same<
typename gemm_layer1_t::tile_shape,
60 typename epilogue_layer1_t::tile_shape>::value,
61 "tile_shape should be the same");
63 using mem_desc_a_t =
typename gemm_layer1_t::mem_desc_a_t;
64 using mem_desc_w_t =
typename gemm_layer1_t::mem_desc_b_t;
65 using mem_desc_b_t =
typename epilogue_layer1_t::mem_desc_c_t;
66 using matA_base_t =
typename mem_desc_a_t::base_t;
67 using matW_base_t =
typename mem_desc_w_t::base_t;
68 using matB_base_t =
typename mem_desc_b_t::base_t;
69 using dtype_a =
typename mem_desc_a_t::dtype;
70 using dtype_w =
typename mem_desc_w_t::dtype;
71 using dtype_b =
typename mem_desc_b_t::dtype;
72 using matAcc_layer1_t =
typename gemm_layer1_t::matAcc_t;
74 using gemm_layer2_t = gemm_layer2_t_;
75 using epilogue_layer2_t = epilogue_layer2_t_;
76 using gemm_layer2_args_t =
typename gemm_layer2_t::arguments_t;
77 using epilogue_layer2_args_t =
typename epilogue_layer2_t::arguments_t;
79 using tile_shape_layer2 =
typename gemm_layer2_t::tile_shape;
80 static constexpr uint32_t wg_tile_m_layer2
81 = tile_shape_layer2::wg_tile_size_y;
82 static constexpr uint32_t wg_tile_n_layer2
83 = tile_shape_layer2::wg_tile_size_x;
84 static constexpr uint32_t sg_tile_m_layer2
85 = tile_shape_layer2::sg_tile_size_y;
86 static constexpr uint32_t sg_tile_n_layer2
87 = tile_shape_layer2::sg_tile_size_x;
88 static constexpr uint32_t wg_size_y_layer2 = tile_shape_layer2::wg_size_y;
89 static constexpr uint32_t wg_size_x_layer2 = tile_shape_layer2::wg_size_x;
90 static constexpr uint32_t real_wg_tile_m_layer2
91 = sg_tile_m_layer2 * wg_size_y_layer2;
92 static constexpr uint32_t real_wg_tile_n_layer2
93 = sg_tile_n_layer2 * wg_size_x_layer2;
95 static constexpr uint32_t k_stride_layer2 = gemm_layer2_t::k_stride;
96 using work_group_layer2_t =
typename gemm_layer2_t::work_group_t;
99 arch_tag == gemm_layer2_t::arch_tag,
"arch_tag should be the same");
100 static_assert(arch_tag == epilogue_layer2_t::arch_tag,
101 "arch_tag should be the same");
102 static_assert(std::is_same<
typename gemm_layer2_t::tile_shape,
103 typename epilogue_layer2_t::tile_shape>::value,
104 "tile_shape should be the same");
107 static_assert(std::is_same<
typename epilogue_layer1_t::mem_desc_c_t,
108 typename gemm_layer2_t::mem_desc_a_t>::value,
109 "the output of first gemm should be the left input og second "
111 using mem_desc_v_t =
typename gemm_layer2_t::mem_desc_b_t;
112 using mem_desc_c_t =
typename epilogue_layer2_t::mem_desc_c_t;
113 using matV_base_t =
typename mem_desc_v_t::base_t;
114 using matC_base_t =
typename mem_desc_c_t::base_t;
115 using dtype_v =
typename mem_desc_v_t::dtype;
116 using dtype_c =
typename mem_desc_c_t::dtype;
117 using matAcc_layer2_t =
typename gemm_layer2_t::matAcc_t;
186 inline arguments_t(uint32_t matrix_m_layer1_, uint32_t matrix_k_layer1_,
187 uint32_t matrix_n_layer1_, uint32_t matrix_m_layer2_,
188 uint32_t matrix_k_layer2_, uint32_t matrix_n_layer2_,
189 matA_base_t matA_base_, uint32_t matA_ld_,
190 matW_base_t matW_base_, uint32_t matW_ld_,
191 matB_base_t matB_base_, uint32_t matB_ld_,
192 matV_base_t matV_base_, uint32_t matV_ld_,
193 matC_base_t matC_base_, uint32_t matC_ld_,
194 epilogue_layer1_args_t epilogue_layer1_args_ = {},
195 epilogue_layer2_args_t epilogue_layer2_args_ = {})
263 constexpr uint32_t count = gemm_layer1_t::barrier_count
264 + epilogue_layer1_t::barrier_count + 1
265 > gemm_layer2_t::barrier_count
266 + epilogue_layer2_t::barrier_count
267 ? gemm_layer1_t::barrier_count
268 + epilogue_layer1_t::barrier_count + 1
269 : gemm_layer2_t::barrier_count
270 + epilogue_layer2_t::barrier_count;
272 count <= 32,
"The named_barrier count should be less than 32!");
289 static_assert(work_group_layer1_t::size == work_group_layer2_t::size,
290 "we should make sure first gemm and second gemm use same "
292 uint32_t local_range_m
293 = (wg_tile_m_layer2 + sg_tile_m_layer2 - 1) / sg_tile_m_layer2;
294 uint32_t local_range_n
295 = (wg_tile_n_layer2 + sg_tile_n_layer2 - 1) / sg_tile_n_layer2;
296 std::cout <<
"Local range: {" << 1 <<
", " << local_range_m <<
", "
297 << local_range_n <<
"} \n";
298 assert(local_range_m * local_range_n <= 32);
299 return cl::sycl::range<3> {1, local_range_m, local_range_n};
308 static_assert(wg_tile_m_layer1 == wg_tile_m_layer2,
309 "first gemm and second gemm should have the same wg_tile_m");
321 std::cout <<
"Group range: {1"
322 <<
", " << group_range_m <<
", " << group_range_n <<
"} \n";
323 return cl::sycl::range<3> {1, group_range_m, group_range_n};
332 return cl::sycl::nd_range<3> {group_range * local_range, local_range};
339 bool implementable =
true;
343 dtype_a>::check_tensor((uint64_t)(args.
matA_base.base),
348 dtype_a>::check_alignment(args.
matA_base.base,
355 dtype_w>::check_tensor((uint64_t)(args.
matW_base.base),
360 dtype_w>::check_alignment(args.
matW_base.base,
367 dtype_b>::check_tensor((uint64_t)(args.
matB_base.base),
372 dtype_b>::check_alignment(args.
matB_base.base,
379 dtype_b>::check_tensor((uint64_t)(args.
matB_base.base),
384 dtype_a>::check_alignment(args.
matB_base.base,
391 dtype_v>::check_tensor((uint64_t)(args.
matV_base.base),
396 dtype_v>::check_alignment(args.
matV_base.base,
403 dtype_c>::check_tensor((uint64_t)(args.
matC_base.base),
408 dtype_c>::check_alignment(args.
matC_base.base,
413 return implementable;
424 uint32_t nbarrier_base = 0) {
426 int start_n = item.get_group(2) * wg_tile_n_layer1;
427 int start_m = item.get_group(1) * wg_tile_m_layer1;
433 : (start_n + wg_tile_n_layer1);
437 : (start_m + wg_tile_m_layer1);
440 uint32_t gemm_layer1_nbarr_base = nbarrier_base;
441 uint32_t epilogue_layer1_nbarr_base
442 = gemm_layer1_nbarr_base + gemm_layer1_t::barrier_count;
443 uint32_t global_nbarr_base
444 = epilogue_layer1_nbarr_base + epilogue_layer1_t::barrier_count;
446 uint32_t gemm_layer2_nbarr_base = nbarrier_base;
447 uint32_t epilogue_layer2_nbarr_base
448 = gemm_layer2_nbarr_base + gemm_layer2_t::barrier_count;
450 uint32_t gemm_layer1_slm_base = slm_base;
451 uint32_t epilogue_layer1_slm_base
452 = gemm_layer1_slm_base + gemm_layer1_t::slm_size;
453 uint32_t gemm_layer2_slm_base
454 = epilogue_layer1_slm_base + epilogue_layer2_t::slm_size;
455 uint32_t epilogue_layer2_slm_base
456 = gemm_layer2_slm_base + gemm_layer2_t::slm_size;
459 work_group_layer1_t g_layer1;
460 g_layer1.init(item.get_local_linear_id());
461 mem_desc_a_t mem_desc_a;
462 mem_desc_w_t mem_desc_w;
463 mem_desc_b_t mem_desc_b;
468 mem_desc_w.init(args.matW_base, {boundary_n, boundary_k, args.matW_ld},
471 mem_desc_b.init(args.matB_base, {boundary_n, boundary_m, args.matB_ld},
474 uint32_t inner_loop_count
475 = (
wg_tile_k + k_stride_layer1 - 1) / k_stride_layer1;
476 gemm_layer1_args_t gemm_layer1_args(
477 mem_desc_a, mem_desc_w, inner_loop_count);
478 gemm_layer1_t gemm_layer1;
479 epilogue_layer1_t epilogue_layer1;
481 matAcc_layer1_t matAcc_layer1(0);
482 gemm_layer1(g_layer1, matAcc_layer1, gemm_layer1_args,
483 gemm_layer1_slm_base, gemm_layer1_nbarr_base);
484 epilogue_layer1(g_layer1, matAcc_layer1, mem_desc_b,
485 args.epilogue_layer1_args, epilogue_layer1_slm_base,
486 epilogue_layer1_nbarr_base);
490 xetla_nbarrier_t<work_group_layer2_t::size, work_group_layer2_t::size,
493 nbarrier_global.init_nbarrier(
494 global_nbarr_base, nbarrier_role::producer_consumer);
495 nbarrier_global.arrive_wait();
498 start_n = item.get_group(2) * wg_tile_n_layer2;
499 start_m = item.get_group(1) * wg_tile_m_layer2;
502 boundary_n = (start_n + wg_tile_n_layer2) > args.matrix_n_layer2
503 ? args.matrix_n_layer2
504 : (start_n + wg_tile_n_layer2);
505 boundary_m = (start_m + wg_tile_m_layer2) > args.matrix_m_layer2
506 ? args.matrix_m_layer2
507 : (start_m + wg_tile_m_layer2);
512 work_group_layer2_t g_layer2;
513 g_layer2.init(item.get_local_linear_id());
514 mem_desc_v_t mem_desc_v;
515 mem_desc_c_t mem_desc_c;
517 mem_desc_b.init(args.matB_base, {boundary_k, boundary_m, args.matB_ld},
520 mem_desc_v.init(args.matV_base, {boundary_n, boundary_k, args.matV_ld},
523 mem_desc_c.init(args.matC_base, {boundary_n, boundary_m, args.matC_ld},
526 inner_loop_count = (
wg_tile_k + k_stride_layer2 - 1) / k_stride_layer2;
527 gemm_layer2_args_t gemm_layer2_args(
528 mem_desc_b, mem_desc_v, inner_loop_count);
529 gemm_layer2_t gemm_layer2;
530 epilogue_layer2_t epilogue_layer2;
532 matAcc_layer2_t matAcc_layer2(0);
533 gemm_layer2(g_layer2, matAcc_layer2, gemm_layer2_args,
534 gemm_layer2_slm_base, gemm_layer2_nbarr_base);
535 epilogue_layer2(g_layer2, matAcc_layer2, mem_desc_c,
536 args.epilogue_layer2_args, epilogue_layer2_slm_base,
537 epilogue_layer2_nbarr_base);
Definition limitation.hpp:738
Definition limitation.hpp:736
Definition multi_layer_perceptron.hpp:29
static cl::sycl::range< 3 > get_group_range(arguments_t &args)
Host helper function to get the expected group range under the current MLP config.
Definition multi_layer_perceptron.hpp:306
__XETLA_API KERNEL_FUNC void operator()(sycl::nd_item< 3 > &item, const arguments_t &args, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Main execution function for MLP.
Definition multi_layer_perceptron.hpp:422
static bool can_implement(arguments_t &args)
Checks whether the given arguments can be implemented.
Definition multi_layer_perceptron.hpp:338
static __XETLA_API constexpr uint32_t get_slm_size()
Gets local memory size consumption.
Definition multi_layer_perceptron.hpp:279
static cl::sycl::nd_range< 3 > get_nd_range(arguments_t &args)
Host helper function to get the expected nd_range under the current MLP config.
Definition multi_layer_perceptron.hpp:329
static __XETLA_API constexpr uint32_t get_barrier_count()
Gets named_barrier id consumption count.
Definition multi_layer_perceptron.hpp:262
static cl::sycl::range< 3 > get_local_range()
Host helper function to get the expected local range under the current MLP config.
Definition multi_layer_perceptron.hpp:287
#define __XETLA_API
Definition common.hpp:43
__XETLA_API void xetla_fence(xetla_mask< N > pred=1)
Memory fence.
Definition memory.hpp:638
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:734
gpu_arch
Definition common.hpp:73
Definition multi_layer_perceptron.hpp:120
uint32_t matV_ld
Is the leading dimension (pitch) size of the matrix V in memory.
Definition multi_layer_perceptron.hpp:140
matV_base_t matV_base
Is the base address of matrix V.
Definition multi_layer_perceptron.hpp:150
arguments_t(const arguments_t &args)
Definition multi_layer_perceptron.hpp:217
uint32_t matrix_n_layer1
Is the size of the n dimension of the first gemm's matrix multiplication (m x k x n).
Definition multi_layer_perceptron.hpp:126
epilogue_layer2_args_t epilogue_layer2_args
Are the epilogue arguments of the second gemm.
Definition multi_layer_perceptron.hpp:156
uint32_t matrix_k_layer2
Is the size of the k dimension of the second gemm's matrix multiplication (m x k x n).
Definition multi_layer_perceptron.hpp:130
matC_base_t matC_base
Is the base address of matrix C.
Definition multi_layer_perceptron.hpp:152
uint32_t matW_ld
Is the leading dimension (pitch) size of the matrix W in memory.
Definition multi_layer_perceptron.hpp:136
uint32_t matB_ld
Is the leading dimension (pitch) size of the matrix B in memory.
Definition multi_layer_perceptron.hpp:138
arguments_t & operator=(const arguments_t &args)
Definition multi_layer_perceptron.hpp:236
matA_base_t matA_base
Is the base address of matrix A.
Definition multi_layer_perceptron.hpp:144
arguments_t()=default
Constructs arguments with the default constructor.
arguments_t(uint32_t matrix_m_layer1_, uint32_t matrix_k_layer1_, uint32_t matrix_n_layer1_, uint32_t matrix_m_layer2_, uint32_t matrix_k_layer2_, uint32_t matrix_n_layer2_, matA_base_t matA_base_, uint32_t matA_ld_, matW_base_t matW_base_, uint32_t matW_ld_, matB_base_t matB_base_, uint32_t matB_ld_, matV_base_t matV_base_, uint32_t matV_ld_, matC_base_t matC_base_, uint32_t matC_ld_, epilogue_layer1_args_t epilogue_layer1_args_={}, epilogue_layer2_args_t epilogue_layer2_args_={})
Constructs arguments with an initialization list.
Definition multi_layer_perceptron.hpp:186
uint32_t matrix_m_layer1
Is the size of the m dimension of the first gemm's matrix multiplication (m x k x n).
Definition multi_layer_perceptron.hpp:122
uint32_t matA_ld
Is the leading dimension (pitch) size of the matrix A in memory.
Definition multi_layer_perceptron.hpp:134
uint32_t matrix_n_layer2
Is the size of the n dimension of the second gemm's matrix multiplication (m x k x n).
Definition multi_layer_perceptron.hpp:132
matB_base_t matB_base
Is the base address of matrix B.
Definition multi_layer_perceptron.hpp:148
uint32_t matC_ld
Is the leading dimension (pitch) size of the matrix C in memory.
Definition multi_layer_perceptron.hpp:142
uint32_t matrix_k_layer1
Is the size of the k dimension of the first gemm's matrix multiplication (m x k x n).
Definition multi_layer_perceptron.hpp:124
uint32_t matrix_m_layer2
Is the size of the m dimension of the second gemm's matrix multiplication (m x k x n).
Definition multi_layer_perceptron.hpp:128
static constexpr bool host_callable
Set to mark the type as device copyable.
Definition multi_layer_perceptron.hpp:165
matW_base_t matW_base
Is the base address of matrix W.
Definition multi_layer_perceptron.hpp:146
epilogue_layer1_args_t epilogue_layer1_args
Are the epilogue arguments of the first gemm.
Definition multi_layer_perceptron.hpp:154