XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
multi_layer_perceptron.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
22#include "xetla.hpp"
23
24namespace gpu::xetla::kernel {
25
26template <typename gemm_layer1_t_, typename epilogue_layer1_t_,
27 typename gemm_layer2_t_, typename epilogue_layer2_t_,
28 gpu_arch arch_tag_>
30 using gemm_layer1_t = gemm_layer1_t_;
31 using epilogue_layer1_t = epilogue_layer1_t_;
32 using gemm_layer1_args_t = typename gemm_layer1_t::arguments_t;
33 using epilogue_layer1_args_t = typename epilogue_layer1_t::arguments_t;
34
35 using tile_shape_layer1 = typename gemm_layer1_t::tile_shape;
36 static constexpr uint32_t wg_tile_m_layer1
37 = tile_shape_layer1::wg_tile_size_y;
38 static constexpr uint32_t wg_tile_n_layer1
39 = tile_shape_layer1::wg_tile_size_x;
40 static constexpr uint32_t sg_tile_m_layer1
41 = tile_shape_layer1::sg_tile_size_y;
42 static constexpr uint32_t sg_tile_n_layer1
43 = tile_shape_layer1::sg_tile_size_x;
44 static constexpr uint32_t wg_size_y_layer1 = tile_shape_layer1::wg_size_y;
45 static constexpr uint32_t wg_size_x_layer1 = tile_shape_layer1::wg_size_x;
46 static constexpr uint32_t real_wg_tile_m_layer1
47 = sg_tile_m_layer1 * wg_size_y_layer1;
48 static constexpr uint32_t real_wg_tile_n_layer1
49 = sg_tile_n_layer1 * wg_size_x_layer1;
50
51 static constexpr uint32_t k_stride_layer1 = gemm_layer1_t::k_stride;
52 using work_group_layer1_t = typename gemm_layer1_t::work_group_t;
53
54 static constexpr gpu_arch arch_tag = arch_tag_;
55 static_assert(
56 arch_tag == gemm_layer1_t::arch_tag, "arch_tag should be the same");
57 static_assert(arch_tag == epilogue_layer1_t::arch_tag,
58 "arch_tag should be the same");
59 static_assert(std::is_same<typename gemm_layer1_t::tile_shape,
60 typename epilogue_layer1_t::tile_shape>::value,
61 "tile_shape should be the same");
62
63 using mem_desc_a_t = typename gemm_layer1_t::mem_desc_a_t;
64 using mem_desc_w_t = typename gemm_layer1_t::mem_desc_b_t;
65 using mem_desc_b_t = typename epilogue_layer1_t::mem_desc_c_t;
66 using matA_base_t = typename mem_desc_a_t::base_t;
67 using matW_base_t = typename mem_desc_w_t::base_t;
68 using matB_base_t = typename mem_desc_b_t::base_t;
69 using dtype_a = typename mem_desc_a_t::dtype;
70 using dtype_w = typename mem_desc_w_t::dtype;
71 using dtype_b = typename mem_desc_b_t::dtype;
72 using matAcc_layer1_t = typename gemm_layer1_t::matAcc_t;
73
74 using gemm_layer2_t = gemm_layer2_t_;
75 using epilogue_layer2_t = epilogue_layer2_t_;
76 using gemm_layer2_args_t = typename gemm_layer2_t::arguments_t;
77 using epilogue_layer2_args_t = typename epilogue_layer2_t::arguments_t;
78
79 using tile_shape_layer2 = typename gemm_layer2_t::tile_shape;
80 static constexpr uint32_t wg_tile_m_layer2
81 = tile_shape_layer2::wg_tile_size_y;
82 static constexpr uint32_t wg_tile_n_layer2
83 = tile_shape_layer2::wg_tile_size_x;
84 static constexpr uint32_t sg_tile_m_layer2
85 = tile_shape_layer2::sg_tile_size_y;
86 static constexpr uint32_t sg_tile_n_layer2
87 = tile_shape_layer2::sg_tile_size_x;
88 static constexpr uint32_t wg_size_y_layer2 = tile_shape_layer2::wg_size_y;
89 static constexpr uint32_t wg_size_x_layer2 = tile_shape_layer2::wg_size_x;
90 static constexpr uint32_t real_wg_tile_m_layer2
91 = sg_tile_m_layer2 * wg_size_y_layer2;
92 static constexpr uint32_t real_wg_tile_n_layer2
93 = sg_tile_n_layer2 * wg_size_x_layer2;
94
95 static constexpr uint32_t k_stride_layer2 = gemm_layer2_t::k_stride;
96 using work_group_layer2_t = typename gemm_layer2_t::work_group_t;
97
98 static_assert(
99 arch_tag == gemm_layer2_t::arch_tag, "arch_tag should be the same");
100 static_assert(arch_tag == epilogue_layer2_t::arch_tag,
101 "arch_tag should be the same");
102 static_assert(std::is_same<typename gemm_layer2_t::tile_shape,
103 typename epilogue_layer2_t::tile_shape>::value,
104 "tile_shape should be the same");
105
106 // using mem_desc_b_t = typename gemm1_t::mem_desc_a_t;
107 static_assert(std::is_same<typename epilogue_layer1_t::mem_desc_c_t,
108 typename gemm_layer2_t::mem_desc_a_t>::value,
109 "the output of first gemm should be the left input og second "
110 "gemm!");
111 using mem_desc_v_t = typename gemm_layer2_t::mem_desc_b_t;
112 using mem_desc_c_t = typename epilogue_layer2_t::mem_desc_c_t;
113 using matV_base_t = typename mem_desc_v_t::base_t;
114 using matC_base_t = typename mem_desc_c_t::base_t;
115 using dtype_v = typename mem_desc_v_t::dtype;
116 using dtype_c = typename mem_desc_c_t::dtype;
117 using matAcc_layer2_t = typename gemm_layer2_t::matAcc_t;
118
119public:
120 struct arguments_t {
134 uint32_t matA_ld;
136 uint32_t matW_ld;
138 uint32_t matB_ld;
140 uint32_t matV_ld;
142 uint32_t matC_ld;
144 matA_base_t matA_base;
146 matW_base_t matW_base;
148 matB_base_t matB_base;
150 matV_base_t matV_base;
152 matC_base_t matC_base;
154 epilogue_layer1_args_t epilogue_layer1_args;
156 epilogue_layer2_args_t epilogue_layer2_args;
157
159 inline arguments_t() = default;
160 // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor)
161 // Please check if you need to add self-define destructor
162 // ~arguments_t(){}
163
165 static constexpr bool host_callable = true;
166
186 inline arguments_t(uint32_t matrix_m_layer1_, uint32_t matrix_k_layer1_,
187 uint32_t matrix_n_layer1_, uint32_t matrix_m_layer2_,
188 uint32_t matrix_k_layer2_, uint32_t matrix_n_layer2_,
189 matA_base_t matA_base_, uint32_t matA_ld_,
190 matW_base_t matW_base_, uint32_t matW_ld_,
191 matB_base_t matB_base_, uint32_t matB_ld_,
192 matV_base_t matV_base_, uint32_t matV_ld_,
193 matC_base_t matC_base_, uint32_t matC_ld_,
194 epilogue_layer1_args_t epilogue_layer1_args_ = {},
195 epilogue_layer2_args_t epilogue_layer2_args_ = {})
196 : matrix_m_layer1(matrix_m_layer1_)
197 , matrix_k_layer1(matrix_k_layer1_)
198 , matrix_n_layer1(matrix_n_layer1_)
199 , matrix_m_layer2(matrix_m_layer2_)
200 , matrix_k_layer2(matrix_k_layer2_)
201 , matrix_n_layer2(matrix_n_layer2_)
202 , matA_ld(matA_ld_)
203 , matW_ld(matW_ld_)
204 , matB_ld(matB_ld_)
205 , matV_ld(matV_ld_)
206 , matC_ld(matC_ld_)
207 , matA_base(matA_base_)
208 , matW_base(matW_base_)
209 , matB_base(matB_base_)
210 , matV_base(matV_base_)
211 , matC_base(matC_base_)
212 , epilogue_layer1_args(epilogue_layer1_args_)
213 , epilogue_layer2_args(epilogue_layer2_args_) {}
214 // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor)
215 // Please check if you need to add self-define destructor
216 // inline ~arguments_t(){}
217 inline arguments_t(const arguments_t &args)
224 , matA_ld(args.matA_ld)
225 , matW_ld(args.matW_ld)
226 , matB_ld(args.matB_ld)
227 , matV_ld(args.matV_ld)
228 , matC_ld(args.matC_ld)
229 , matA_base(args.matA_base)
230 , matW_base(args.matW_base)
231 , matB_base(args.matB_base)
232 , matV_base(args.matV_base)
233 , matC_base(args.matC_base)
236 inline arguments_t &operator=(const arguments_t &args) {
237 this->matrix_m_layer1 = args.matrix_m_layer1;
238 this->matrix_k_layer1 = args.matrix_k_layer1;
239 this->matrix_n_layer1 = args.matrix_n_layer1;
240 this->matrix_m_layer2 = args.matrix_m_layer2;
241 this->matrix_k_layer2 = args.matrix_k_layer2;
242 this->matrix_n_layer2 = args.matrix_n_layer2;
243 this->matA_base = args.matA_base;
244 this->matA_ld = args.matA_ld;
245 this->matW_base = args.matW_base;
246 this->matW_ld = args.matW_ld;
247 this->matB_base = args.matB_base;
248 this->matB_ld = args.matB_ld;
249 this->matV_base = args.matV_base;
250 this->matV_ld = args.matV_ld;
251 this->matC_base = args.matC_base;
252 this->matC_ld = args.matC_ld;
253 this->epilogue_layer1_args = args.epilogue_layer1_args;
254 this->epilogue_layer2_args = args.epilogue_layer2_args;
255 return *this;
256 }
257 };
258
262 __XETLA_API static constexpr uint32_t get_barrier_count() {
263 constexpr uint32_t count = gemm_layer1_t::barrier_count
264 + epilogue_layer1_t::barrier_count + 1
265 > gemm_layer2_t::barrier_count
266 + epilogue_layer2_t::barrier_count
267 ? gemm_layer1_t::barrier_count
268 + epilogue_layer1_t::barrier_count + 1
269 : gemm_layer2_t::barrier_count
270 + epilogue_layer2_t::barrier_count;
271 static_assert(
272 count <= 32, "The named_barrier count should be less than 32!");
273 return count;
274 }
275
279 __XETLA_API static constexpr uint32_t get_slm_size() {
280 // In this MLP example we don't use SLM for load/store or intermediate result storage
281 // So the final slm size should equal to 0
282 return 0;
283 };
284
287 static cl::sycl::range<3> get_local_range() {
288 // make sure first layer and second layer use same subgroup number.
289 static_assert(work_group_layer1_t::size == work_group_layer2_t::size,
290 "we should make sure first gemm and second gemm use same "
291 "subgroup number!");
292 uint32_t local_range_m
293 = (wg_tile_m_layer2 + sg_tile_m_layer2 - 1) / sg_tile_m_layer2;
294 uint32_t local_range_n
295 = (wg_tile_n_layer2 + sg_tile_n_layer2 - 1) / sg_tile_n_layer2;
296 std::cout << "Local range: {" << 1 << ", " << local_range_m << ", "
297 << local_range_n << "} \n";
298 assert(local_range_m * local_range_n <= 32);
299 return cl::sycl::range<3> {1, local_range_m, local_range_n};
300 };
301
306 static cl::sycl::range<3> get_group_range(arguments_t &args) {
307 // make sure first layer and second layer meet the condition to be fused.
308 static_assert(wg_tile_m_layer1 == wg_tile_m_layer2,
309 "first gemm and second gemm should have the same wg_tile_m");
310 assert(args.matrix_m_layer1 == args.matrix_m_layer2);
311 assert(((args.matrix_n_layer1 + wg_tile_n_layer1 - 1)
312 / wg_tile_n_layer1)
313 == 1
314 && ((args.matrix_n_layer2 + wg_tile_n_layer2 - 1)
315 / wg_tile_n_layer2)
316 == 1);
317 uint32_t group_range_m = (args.matrix_m_layer1 + wg_tile_m_layer1 - 1)
318 / wg_tile_m_layer1;
319 uint32_t group_range_n = (args.matrix_n_layer1 + wg_tile_n_layer1 - 1)
320 / wg_tile_n_layer1;
321 std::cout << "Group range: {1"
322 << ", " << group_range_m << ", " << group_range_n << "} \n";
323 return cl::sycl::range<3> {1, group_range_m, group_range_n};
324 };
325
329 static cl::sycl::nd_range<3> get_nd_range(arguments_t &args) {
330 cl::sycl::range<3> local_range = get_local_range();
331 cl::sycl::range<3> group_range = get_group_range(args);
332 return cl::sycl::nd_range<3> {group_range * local_range, local_range};
333 };
334
338 static bool can_implement(arguments_t &args) {
339 bool implementable = true;
340 if (gemm_layer1_t::msg_type_a != msg_type::unaligned_2d) {
341 if (gemm_layer1_t::msg_type_a == msg_type::block_2d) {
342 implementable &= kernel::block_2d<gpu_arch::Xe,
343 dtype_a>::check_tensor((uint64_t)(args.matA_base.base),
345 args.matA_ld);
346 } else {
347 implementable &= kernel::general_1d<gpu_arch::Xe,
348 dtype_a>::check_alignment(args.matA_base.base,
349 args.matA_ld);
350 }
351 }
352 if (gemm_layer1_t::msg_type_b != msg_type::unaligned_2d) {
353 if (gemm_layer1_t::msg_type_b == msg_type::block_2d) {
354 implementable &= kernel::block_2d<gpu_arch::Xe,
355 dtype_w>::check_tensor((uint64_t)(args.matW_base.base),
357 args.matW_ld);
358 } else {
359 implementable &= kernel::general_1d<gpu_arch::Xe,
360 dtype_w>::check_alignment(args.matW_base.base,
361 args.matW_ld);
362 }
363 }
364 if (epilogue_layer1_t::msg_type_c != msg_type::unaligned_2d) {
365 if (epilogue_layer1_t::msg_type_c == msg_type::block_2d) {
366 implementable &= kernel::block_2d<gpu_arch::Xe,
367 dtype_b>::check_tensor((uint64_t)(args.matB_base.base),
369 args.matB_ld);
370 } else {
371 implementable &= kernel::general_1d<gpu_arch::Xe,
372 dtype_b>::check_alignment(args.matB_base.base,
373 args.matB_ld);
374 }
375 }
376 if (gemm_layer2_t::msg_type_a != msg_type::unaligned_2d) {
377 if (gemm_layer2_t::msg_type_a == msg_type::block_2d) {
378 implementable &= kernel::block_2d<gpu_arch::Xe,
379 dtype_b>::check_tensor((uint64_t)(args.matB_base.base),
381 args.matB_ld);
382 } else {
383 implementable &= kernel::general_1d<gpu_arch::Xe,
384 dtype_a>::check_alignment(args.matB_base.base,
385 args.matB_ld);
386 }
387 }
388 if (gemm_layer2_t::msg_type_b != msg_type::unaligned_2d) {
389 if (gemm_layer2_t::msg_type_b == msg_type::block_2d) {
390 implementable &= kernel::block_2d<gpu_arch::Xe,
391 dtype_v>::check_tensor((uint64_t)(args.matV_base.base),
393 args.matV_ld);
394 } else {
395 implementable &= kernel::general_1d<gpu_arch::Xe,
396 dtype_v>::check_alignment(args.matV_base.base,
397 args.matV_ld);
398 }
399 }
400 if (epilogue_layer2_t::msg_type_c != msg_type::unaligned_2d) {
401 if (epilogue_layer2_t::msg_type_c == msg_type::block_2d) {
402 implementable &= kernel::block_2d<gpu_arch::Xe,
403 dtype_c>::check_tensor((uint64_t)(args.matC_base.base),
405 args.matC_ld);
406 } else {
407 implementable &= kernel::general_1d<gpu_arch::Xe,
408 dtype_c>::check_alignment(args.matC_base.base,
409 args.matC_ld);
410 }
411 }
412
413 return implementable;
414 }
415
422 __XETLA_API KERNEL_FUNC void operator()(sycl::nd_item<3> &item,
423 const arguments_t &args, uint32_t slm_base = 0,
424 uint32_t nbarrier_base = 0) {
425 // set up workgroup level coordinates and boundaries
426 int start_n = item.get_group(2) * wg_tile_n_layer1;
427 int start_m = item.get_group(1) * wg_tile_m_layer1;
428 int start_k = 0;
429 uint32_t wg_tile_k = args.matrix_k_layer1;
430 uint32_t boundary_n
431 = (start_n + wg_tile_n_layer1) > args.matrix_n_layer1
432 ? args.matrix_n_layer1
433 : (start_n + wg_tile_n_layer1);
434 uint32_t boundary_m
435 = (start_m + wg_tile_m_layer1) > args.matrix_m_layer1
436 ? args.matrix_m_layer1
437 : (start_m + wg_tile_m_layer1);
438 uint32_t boundary_k = wg_tile_k;
439
440 uint32_t gemm_layer1_nbarr_base = nbarrier_base;
441 uint32_t epilogue_layer1_nbarr_base
442 = gemm_layer1_nbarr_base + gemm_layer1_t::barrier_count;
443 uint32_t global_nbarr_base
444 = epilogue_layer1_nbarr_base + epilogue_layer1_t::barrier_count;
445 // Reuse named barrier
446 uint32_t gemm_layer2_nbarr_base = nbarrier_base;
447 uint32_t epilogue_layer2_nbarr_base
448 = gemm_layer2_nbarr_base + gemm_layer2_t::barrier_count;
449
450 uint32_t gemm_layer1_slm_base = slm_base;
451 uint32_t epilogue_layer1_slm_base
452 = gemm_layer1_slm_base + gemm_layer1_t::slm_size;
453 uint32_t gemm_layer2_slm_base
454 = epilogue_layer1_slm_base + epilogue_layer2_t::slm_size;
455 uint32_t epilogue_layer2_slm_base
456 = gemm_layer2_slm_base + gemm_layer2_t::slm_size;
457
458 // set up arguments
459 work_group_layer1_t g_layer1;
460 g_layer1.init(item.get_local_linear_id());
461 mem_desc_a_t mem_desc_a;
462 mem_desc_w_t mem_desc_w;
463 mem_desc_b_t mem_desc_b;
464 //setup for matA
465 mem_desc_a.init(args.matA_base, {boundary_k, boundary_m, args.matA_ld},
466 {start_k, start_m});
467 //setup for matB
468 mem_desc_w.init(args.matW_base, {boundary_n, boundary_k, args.matW_ld},
469 {start_n, start_k});
470 //setup for matC
471 mem_desc_b.init(args.matB_base, {boundary_n, boundary_m, args.matB_ld},
472 {start_n, start_m});
473
474 uint32_t inner_loop_count
475 = (wg_tile_k + k_stride_layer1 - 1) / k_stride_layer1;
476 gemm_layer1_args_t gemm_layer1_args(
477 mem_desc_a, mem_desc_w, inner_loop_count);
478 gemm_layer1_t gemm_layer1;
479 epilogue_layer1_t epilogue_layer1;
480
481 matAcc_layer1_t matAcc_layer1(0);
482 gemm_layer1(g_layer1, matAcc_layer1, gemm_layer1_args,
483 gemm_layer1_slm_base, gemm_layer1_nbarr_base);
484 epilogue_layer1(g_layer1, matAcc_layer1, mem_desc_b,
485 args.epilogue_layer1_args, epilogue_layer1_slm_base,
486 epilogue_layer1_nbarr_base);
487
488 // fence & barrier between two gemm
489 xetla_fence();
490 xetla_nbarrier_t<work_group_layer2_t::size, work_group_layer2_t::size,
491 gpu_arch::Xe>
492 nbarrier_global;
493 nbarrier_global.init_nbarrier(
494 global_nbarr_base, nbarrier_role::producer_consumer);
495 nbarrier_global.arrive_wait();
496
497 // set up workgroup level coordinates and boundaries
498 start_n = item.get_group(2) * wg_tile_n_layer2;
499 start_m = item.get_group(1) * wg_tile_m_layer2;
500 start_k = 0;
501 wg_tile_k = args.matrix_k_layer2;
502 boundary_n = (start_n + wg_tile_n_layer2) > args.matrix_n_layer2
503 ? args.matrix_n_layer2
504 : (start_n + wg_tile_n_layer2);
505 boundary_m = (start_m + wg_tile_m_layer2) > args.matrix_m_layer2
506 ? args.matrix_m_layer2
507 : (start_m + wg_tile_m_layer2);
508 boundary_k = wg_tile_k;
509
510 // set up arguments
511 // reuse mem_desc_b
512 work_group_layer2_t g_layer2;
513 g_layer2.init(item.get_local_linear_id());
514 mem_desc_v_t mem_desc_v;
515 mem_desc_c_t mem_desc_c;
516 //setup for matA
517 mem_desc_b.init(args.matB_base, {boundary_k, boundary_m, args.matB_ld},
518 {start_k, start_m});
519 //setup for matB
520 mem_desc_v.init(args.matV_base, {boundary_n, boundary_k, args.matV_ld},
521 {start_n, start_k});
522 //setup for matC
523 mem_desc_c.init(args.matC_base, {boundary_n, boundary_m, args.matC_ld},
524 {start_n, start_m});
525
526 inner_loop_count = (wg_tile_k + k_stride_layer2 - 1) / k_stride_layer2;
527 gemm_layer2_args_t gemm_layer2_args(
528 mem_desc_b, mem_desc_v, inner_loop_count);
529 gemm_layer2_t gemm_layer2;
530 epilogue_layer2_t epilogue_layer2;
531
532 matAcc_layer2_t matAcc_layer2(0);
533 gemm_layer2(g_layer2, matAcc_layer2, gemm_layer2_args,
534 gemm_layer2_slm_base, gemm_layer2_nbarr_base);
535 epilogue_layer2(g_layer2, matAcc_layer2, mem_desc_c,
536 args.epilogue_layer2_args, epilogue_layer2_slm_base,
537 epilogue_layer2_nbarr_base);
538 }
539};
540
542
543} // namespace gpu::xetla::kernel
Definition limitation.hpp:738
Definition limitation.hpp:736
Definition multi_layer_perceptron.hpp:29
static cl::sycl::range< 3 > get_group_range(arguments_t &args)
Host helper function to get the expected group range under the current MLP config.
Definition multi_layer_perceptron.hpp:306
__XETLA_API KERNEL_FUNC void operator()(sycl::nd_item< 3 > &item, const arguments_t &args, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Main execution function for MLP.
Definition multi_layer_perceptron.hpp:422
static bool can_implement(arguments_t &args)
Check if the arguments can be implemented.
Definition multi_layer_perceptron.hpp:338
static __XETLA_API constexpr uint32_t get_slm_size()
Gets local memory size consumption.
Definition multi_layer_perceptron.hpp:279
static cl::sycl::nd_range< 3 > get_nd_range(arguments_t &args)
Host helper function to get the expected nd_range under the current MLP config.
Definition multi_layer_perceptron.hpp:329
static __XETLA_API constexpr uint32_t get_barrier_count()
Gets named_barrier id consumption count.
Definition multi_layer_perceptron.hpp:262
static cl::sycl::range< 3 > get_local_range()
Host helper function to get the expected local range under the current MLP config.
Definition multi_layer_perceptron.hpp:287
#define __XETLA_API
Definition common.hpp:43
__XETLA_API void xetla_fence(xetla_mask< N > pred=1)
Memory fence.
Definition memory.hpp:638
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:734
gpu_arch
Definition common.hpp:73
Definition multi_layer_perceptron.hpp:120
uint32_t matV_ld
Is the leading dimension (pitch) size of the matrix V in memory.
Definition multi_layer_perceptron.hpp:140
matV_base_t matV_base
Is the base address of matrix V.
Definition multi_layer_perceptron.hpp:150
arguments_t(const arguments_t &args)
Definition multi_layer_perceptron.hpp:217
uint32_t matrix_n_layer1
Is the size of the n dimension of the matrix multiplication (m x k x n).
Definition multi_layer_perceptron.hpp:126
epilogue_layer2_args_t epilogue_layer2_args
Is the epilogue arguments of second gemm.
Definition multi_layer_perceptron.hpp:156
uint32_t matrix_k_layer2
Is the size of the k dimension of the matrix multiplication (m x k x n).
Definition multi_layer_perceptron.hpp:130
matC_base_t matC_base
Is the base address of matrix C.
Definition multi_layer_perceptron.hpp:152
uint32_t matW_ld
Is the leading dimension (pitch) size of the matrix W in memory.
Definition multi_layer_perceptron.hpp:136
uint32_t matB_ld
Is the leading dimension (pitch) size of the matrix B in memory.
Definition multi_layer_perceptron.hpp:138
arguments_t & operator=(const arguments_t &args)
Definition multi_layer_perceptron.hpp:236
matA_base_t matA_base
Is the base address of matrix A.
Definition multi_layer_perceptron.hpp:144
arguments_t()=default
Constructs arguments with default method.
arguments_t(uint32_t matrix_m_layer1_, uint32_t matrix_k_layer1_, uint32_t matrix_n_layer1_, uint32_t matrix_m_layer2_, uint32_t matrix_k_layer2_, uint32_t matrix_n_layer2_, matA_base_t matA_base_, uint32_t matA_ld_, matW_base_t matW_base_, uint32_t matW_ld_, matB_base_t matB_base_, uint32_t matB_ld_, matV_base_t matV_base_, uint32_t matV_ld_, matC_base_t matC_base_, uint32_t matC_ld_, epilogue_layer1_args_t epilogue_layer1_args_={}, epilogue_layer2_args_t epilogue_layer2_args_={})
Constructs arguments with initialization list.
Definition multi_layer_perceptron.hpp:186
uint32_t matrix_m_layer1
Is the size of the m dimension of the matrix multiplication (m x k x n).
Definition multi_layer_perceptron.hpp:122
uint32_t matA_ld
Is the leading dimension (pitch) size of the matrix A in memory.
Definition multi_layer_perceptron.hpp:134
uint32_t matrix_n_layer2
Is the size of the n dimension of the matrix multiplication (m x k x n).
Definition multi_layer_perceptron.hpp:132
matB_base_t matB_base
Is the base address of matrix B.
Definition multi_layer_perceptron.hpp:148
uint32_t matC_ld
Is the leading dimension (pitch) size of the matrix C in memory.
Definition multi_layer_perceptron.hpp:142
uint32_t matrix_k_layer1
Is the size of the k dimension of the matrix multiplication (m x k x n).
Definition multi_layer_perceptron.hpp:124
uint32_t matrix_m_layer2
Is the size of the m dimension of the matrix multiplication (m x k x n).
Definition multi_layer_perceptron.hpp:128
static constexpr bool host_callable
Set for device copyable.
Definition multi_layer_perceptron.hpp:165
matW_base_t matW_base
Is the base address of matrix W.
Definition multi_layer_perceptron.hpp:146
epilogue_layer1_args_t epilogue_layer1_args
Is the epilogue arguments of first gemm.
Definition multi_layer_perceptron.hpp:154
C++ API.