XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
unaligned_xmx_xe.hpp
/*******************************************************************************
* Copyright (c) 2022-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include "group/gemm/api.hpp"
#include "group/gemm/compute_policy.hpp"

namespace gpu::xetla::group {

template <typename compute_attr_, typename perf_tuning_knob_,
        typename tile_shape_, typename mem_desc_a_t_, typename mem_desc_b_t_,
        typename pre_processing_t_, gpu_arch arch_tag_>
class gemm_t<compute_policy_unaligned_xmx<compute_attr_, perf_tuning_knob_,
                     arch_tag_>,
        tile_shape_, // tile shape of workgroup-level gemm
        mem_desc_a_t_, // memory attribute of matA
        mem_desc_b_t_, // memory attribute of matB
        pre_processing_t_, // pre_processing functor
        std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
public:
    using mem_desc_a_t = mem_desc_a_t_;
    using mem_desc_b_t = mem_desc_b_t_;
    using tile_shape = tile_shape_;
    using pre_processing_t = pre_processing_t_;
    using compute_policy = compute_policy_unaligned_xmx<compute_attr_,
            perf_tuning_knob_, arch_tag_>;

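    // num_cyclic is the depth of the SLM ring buffer used to stage A/B tiles:
    // the prologue fills (num_cyclic - 1) slots ahead of the main K loop.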
    static constexpr uint32_t num_cyclic = 3;

    static constexpr uint32_t k_stride = compute_policy::k_stride;
    static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
    static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x;
    static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
    static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
    using work_group_t = typename tile_shape::work_group_t;

    constexpr static gpu_arch arch_tag = compute_policy::arch_tag;

    static constexpr mem_layout mem_layout_a = mem_desc_a_t::layout;
    static constexpr mem_layout mem_layout_b = mem_desc_b_t::layout;
    static constexpr bool is_col_major_a
            = mem_layout_a == mem_layout::col_major;
    static constexpr bool is_col_major_b
            = mem_layout_b == mem_layout::col_major;

private:
    /******** set data type **********/
    using dtype_a = typename mem_desc_a_t::dtype;
    using dtype_b = typename mem_desc_b_t::dtype;
    using dtype_mma_acc = typename compute_policy::dtype_mma_acc;
    using dtype_mma_a = typename compute_policy::dtype_mma_a;
    using dtype_mma_b = typename compute_policy::dtype_mma_b;

    using check_dtype
            = group::gemm<gpu_arch::Xe>::default_xmx::check_dtype_default<
                    dtype_a, dtype_b, dtype_mma_a, dtype_mma_b>;

    /******** set memory attribute **********/
    static constexpr mem_space mem_space_a = mem_desc_a_t::space;
    static constexpr mem_space mem_space_b = mem_desc_b_t::space;

    static constexpr bool is_local_a = mem_space_a == mem_space::local;
    static constexpr bool is_local_b = mem_space_b == mem_space::local;
    static constexpr tdesc_update_dir update_dir_a = is_col_major_a
            ? tdesc_update_dir::y_dir
            : tdesc_update_dir::x_dir;
    static constexpr tdesc_update_dir update_dir_b = is_col_major_b
            ? tdesc_update_dir::x_dir
            : tdesc_update_dir::y_dir;

    using check_memory
            = group::gemm<gpu_arch::Xe>::default_xmx::check_memory_default<
                    mem_layout_a, mem_layout_b, mem_space_a, mem_space_b>;

    static constexpr uint32_t stages = compute_policy::stages;
    static constexpr uint32_t sync_freq = compute_policy::sync_freq;

    /******** set tile layout && worker scope **********/
    static constexpr uint32_t tile_size_x_a = k_stride;
    static constexpr uint32_t tile_size_y_a = sg_tile_m;
    static constexpr uint32_t tile_size_x_b = sg_tile_n;
    static constexpr uint32_t tile_size_y_b = k_stride;
    static constexpr uint32_t tile_size_x_c = sg_tile_n;
    static constexpr uint32_t tile_size_y_c = sg_tile_m;
    static constexpr uint32_t block_size_x_a = compute_policy::block_size_x_a;
    static constexpr uint32_t block_size_y_a
            = (compute_policy::block_size_y_a > tile_size_y_a)
            ? tile_size_y_a
            : compute_policy::block_size_y_a;
    static constexpr uint32_t block_size_x_b = compute_policy::block_size_x_b;
    static constexpr uint32_t block_size_y_b = compute_policy::block_size_y_b;

    using check_tile_size = group::gemm<
            gpu_arch::Xe>::default_xmx::check_tile_size_default<dtype_mma_a,
            tile_size_x_a, tile_size_y_a, block_size_x_a, block_size_y_a,
            tile_size_x_b, tile_size_y_b, block_size_x_b, block_size_y_b>;

    /******** set tile **********/
    static constexpr reg_layout reg_layout_a = reg_layout::tiled;
    using matA_tile_desc_t = subgroup::tile_desc_t<tile_size_x_a, tile_size_y_a,
            block_size_x_a, block_size_y_a, reg_layout_a>;
    using matA_t = subgroup::tile_t<dtype_a, matA_tile_desc_t>;

    using cooperative_helper_A_t = subgroup::cooperative_load_helper_t<matA_t,
            mem_layout_a, tile_shape::wg_size_x, gpu_arch::Xe>;
    using cooperative_tile_desc_A_t =
            typename cooperative_helper_A_t::co_tile_desc_t;
    using partial_matA_t = subgroup::tile_t<dtype_a, cooperative_tile_desc_A_t>;
    using matA_payload_t = subgroup::mem_payload_t<mem_desc_a_t,
            cooperative_tile_desc_A_t,
            is_local_a ? msg_type::scatter : msg_type::unaligned_2d, arch_tag>;
    // SLM staging payloads for matA (the local row-major mem_desc_t argument
    // is assumed).
    using matA_payload_local_st_t = subgroup::mem_payload_t<
            mem_desc_t<dtype_a, mem_layout::row_major, mem_space::local>,
            cooperative_tile_desc_A_t, msg_type::scatter, arch_tag>;
    using matA_payload_local_ld_t = subgroup::mem_payload_t<
            mem_desc_t<dtype_a, mem_layout::row_major, mem_space::local>,
            matA_tile_desc_t, msg_type::scatter, arch_tag>;
    using matA_acc_t = subgroup::tile_t<dtype_mma_a, matA_tile_desc_t>;
    // Cache-prefetch payload for matA (the prefetch tile descriptor argument
    // is assumed).
    using matA_prefetch_payload_t = subgroup::prefetch_payload_t<mem_desc_a_t,
            subgroup::tile_desc_t<tile_size_x_a, tile_size_y_a, 1, 1>,
            wg_size_x, arch_tag>;

    static constexpr reg_layout reg_layout_b
            = sizeof(dtype_b) < sizeof(uint32_t) ? reg_layout::vnni_tiled
                                                 : reg_layout::tiled;
    using matB_tile_desc_t = subgroup::tile_desc_t<tile_size_x_b, tile_size_y_b,
            block_size_x_b, block_size_y_b, reg_layout_b>;
    using matB_t = subgroup::tile_t<dtype_b, matB_tile_desc_t>;

    using cooperative_helper_B_t = subgroup::cooperative_load_helper_t<matB_t,
            mem_layout_b, tile_shape::wg_size_y, gpu_arch::Xe>;
    using cooperative_tile_desc_B_t =
            typename cooperative_helper_B_t::co_tile_desc_t;

    using partial_matB_t = subgroup::tile_t<dtype_b, cooperative_tile_desc_B_t>;

    using matB_payload_t = subgroup::mem_payload_t<mem_desc_b_t,
            cooperative_tile_desc_B_t,
            is_local_b ? msg_type::scatter : msg_type::unaligned_2d, arch_tag>;
    // SLM staging payloads for matB (the local row-major mem_desc_t argument
    // is assumed).
    using matB_payload_local_st_t = subgroup::mem_payload_t<
            mem_desc_t<dtype_b, mem_layout::row_major, mem_space::local>,
            cooperative_tile_desc_B_t, msg_type::scatter, arch_tag>;
    using matB_payload_local_ld_t = subgroup::mem_payload_t<
            mem_desc_t<dtype_b, mem_layout::row_major, mem_space::local>,
            matB_tile_desc_t, msg_type::scatter, arch_tag>;
    using matB_acc_t = subgroup::tile_t<dtype_mma_b, matB_tile_desc_t>;
    // Cache-prefetch payload for matB (the prefetch tile descriptor argument
    // is assumed).
    using matB_prefetch_payload_t = subgroup::prefetch_payload_t<mem_desc_b_t,
            subgroup::tile_desc_t<tile_size_x_b, tile_size_y_b, 1, 1>,
            wg_size_y, arch_tag>;

public:
    using matAcc_tile_desc_t = subgroup::tile_desc_t<tile_size_x_c,
            tile_size_y_c, block_size_x_b, block_size_y_a, reg_layout::tiled>;
    using matAcc_t = subgroup::tile_t<dtype_mma_acc, matAcc_tile_desc_t>;

private:
    using tile_mma = subgroup::tile_mma_t<matAcc_t, matAcc_t, matB_acc_t,
            matA_acc_t, mma_engine::xmx, arch_tag>;
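    // Named-barrier bookkeeping: subgroups in the same row share their staged
    // A tiles and need one barrier per row (only when a row has more than one
    // subgroup); likewise, subgroups in the same column share B and need one
    // barrier per column.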
    // static constexpr bool enable_periodic_sync = (sync_freq != 0);
    static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0;
    static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0;
    static constexpr uint32_t tile_size_a
            = tile_size_x_a * tile_size_y_a * sizeof(dtype_a);
    static constexpr uint32_t tile_size_b
            = tile_size_x_b * tile_size_y_b * sizeof(dtype_b);
    static constexpr uint32_t slm_size_a = wg_size_y * tile_size_a;
    static constexpr uint32_t slm_size_b = wg_size_x * tile_size_b;

public:
    static constexpr uint32_t barrier_count = barrier_count_x + barrier_count_y;

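    // SLM layout: num_cyclic copies of the workgroup's A staging region come
    // first, followed by num_cyclic copies of the B staging region.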
    static constexpr uint32_t slm_size = (slm_size_a + slm_size_b) * num_cyclic;
    static constexpr uint32_t slm_base_a = 0;
    static constexpr uint32_t slm_base_b = 0 + slm_size_a * num_cyclic;

    static constexpr msg_type msg_type_a = matA_payload_t::message_type;
    static constexpr msg_type msg_type_b = matB_payload_t::message_type;

    using pre_processing_arg_t = typename pre_processing_t::arguments_t;

    struct arguments_t {
        // Memory descriptor (base, shape, coordinate) of matA.
        mem_desc_a_t matA_base_desc;
        // Memory descriptor (base, shape, coordinate) of matB.
        mem_desc_b_t matB_base_desc;
        // Number of inner-loop iterations needed to walk the whole K dimension.
        uint32_t inner_loop_count;
        // Arguments forwarded to the pre_processing functor.
        pre_processing_arg_t pre_processing_args;

        inline arguments_t() = default;
        // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor)
        // Please check if you need to add a self-defined destructor
        // ~arguments_t(){}

        /// Constructs a new arguments_t object.
        inline arguments_t(mem_desc_a_t matA_desc, mem_desc_b_t matB_desc,
                uint32_t loop_count, pre_processing_arg_t args = {})
            : matA_base_desc(matA_desc)
            , matB_base_desc(matB_desc)
            , inner_loop_count(loop_count)
            , pre_processing_args(args) {}
        // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor)
        // Please check if you need to add a self-defined destructor
        // inline ~arguments_t(){}
        inline arguments_t(const arguments_t &args)
            : matA_base_desc(args.matA_base_desc)
            , matB_base_desc(args.matB_base_desc)
            , inner_loop_count(args.inner_loop_count)
            , pre_processing_args(args.pre_processing_args) {}
        inline arguments_t &operator=(const arguments_t &args) {
            this->matA_base_desc = args.matA_base_desc;
            this->matB_base_desc = args.matB_base_desc;
            this->inner_loop_count = args.inner_loop_count;
            this->pre_processing_args = args.pre_processing_args;
            return *this;
        }

        /// Explicit initialization function.
        inline void init(mem_desc_a_t matA_desc, mem_desc_b_t matB_desc,
                uint32_t loop_count, pre_processing_arg_t args = {}) {
            matA_base_desc = matA_desc;
            matB_base_desc = matB_desc;
            inner_loop_count = loop_count;
            pre_processing_args = args;
        }
    };

    // Offsets of this subgroup's C tile inside the workgroup tile
    // (accessor names assumed).
    __XETLA_API static int32_t get_matC_offset_x(work_group_t &g) {
        int32_t sg_idx = g.get_id() % wg_size_x;
        return sg_idx * sg_tile_n;
    }

    __XETLA_API static int32_t get_matC_offset_y(work_group_t &g) {
        int32_t sg_idy = g.get_id() / wg_size_x;
        return sg_idy * sg_tile_m;
    }

282 "This release function will wait until all the r/w and nbarrier "
283 "id used in this gemm have been committed. By default, it will "
284 "use barrier_id 0 to do the entire workgroup sync if wg_size > 1. "
285 "If you call this function, please set a free barrier id or make "
286 "sure barrier_id 0 is not being occupied and you need to allocate "
287 "one more barrier count in addition to the gemm barrier counts.")
288 __XETLA_API static void release(uint8_t nbarrier_id = 0) {
289 static constexpr bool need_local_fence
290 = (mem_space_a == mem_space::local)
291 || (mem_space_b == mem_space::local);
292 if constexpr (need_local_fence) {
293 xetla_fence<memory_kind::shared_local>();
294 }
295 xetla_fence<memory_kind::untyped_global>();
296 static constexpr uint32_t wg_size = wg_size_x * wg_size_y;
297 if constexpr (wg_size > 1) {
299 nbarrier.init_nbarrier(
301 nbarrier.arrive_wait();
302 }
303 }
304
    /// @brief Main execution function for gemm.
    __XETLA_API KERNEL_FUNC void operator()(work_group_t &g, matAcc_t &matAcc,
            arguments_t args, [[maybe_unused]] uint32_t slm_base = 0,
            uint32_t nbarrier_base = 0) {
        int32_t sg_idx = g.get_id() % wg_size_x;
        int32_t sg_idy = g.get_id() / wg_size_x;

        XETLA_ASSERT(g.get_id() < (wg_size_x * wg_size_y),
                "Thread id(%d) should be less than wg_size(%d)", g.get_id(),
                wg_size_x * wg_size_y);

        update_sg_tile_tdesc(args, sg_idx, sg_idy);
        pre_processing_t pre_processing;
        matA_t matA;
        matB_t matB;
        partial_matA_t partial_matA;
        partial_matB_t partial_matB;
        // >>>>>>>>>>>>>>>>>> pre_processing init
        pre_processing.init(g, args.pre_processing_args);
        uint32_t base_A = slm_base_a + sg_idy * tile_size_a;
        uint32_t base_B = slm_base_b + sg_idx * tile_size_b;

        uint32_t store_idx = 0;
        uint32_t load_idx = 0;

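        // Payload setup: matA_payload/matB_payload read from global memory
        // (unaligned-2d messages unless the source is already local), the
        // *_local_st/*_local_ld payloads write to and read back from the SLM
        // ring buffer, and the *_prefetch payloads warm the cache ahead of
        // the global loads.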
        matA_payload_t matA_payload(args.matA_base_desc);
        matA_payload_local_st_t matA_local_st_payload(base_A, tile_size_x_a,
                tile_size_y_a, tile_size_x_a,
                cooperative_helper_A_t::get_offset_x(sg_idx),
                cooperative_helper_A_t::get_offset_y(sg_idx));
        matA_payload_local_ld_t matA_local_ld_payload(
                base_A, tile_size_x_a, tile_size_y_a, tile_size_x_a, 0, 0);

        matB_payload_t matB_payload(args.matB_base_desc);
        matB_payload_local_st_t matB_local_st_payload(base_B, tile_size_x_b,
                tile_size_y_b, tile_size_x_b,
                cooperative_helper_B_t::get_offset_x(sg_idy),
                cooperative_helper_B_t::get_offset_y(sg_idy));
        matB_payload_local_ld_t matB_local_ld_payload(
                base_B, tile_size_x_b, tile_size_y_b, tile_size_x_b, 0, 0);

        matA_prefetch_payload_t matA_prefetch_payload(
                args.matA_base_desc, sg_idx);
        matB_prefetch_payload_t matB_prefetch_payload(
                args.matB_base_desc, sg_idy);

        // Named barriers that hand the SLM slots between producing and
        // consuming subgroups (participant counts assumed: one barrier per
        // subgroup row for A, one per subgroup column for B).
        xetla_nbarrier_t<wg_size_x, wg_size_x, arch_tag> nbarrier_a;
        nbarrier_a.init_nbarrier(
                sg_idy + nbarrier_base, nbarrier_role::producer_consumer);
        xetla_nbarrier_t<wg_size_y, wg_size_y, arch_tag> nbarrier_b;
        nbarrier_b.init_nbarrier(sg_idx + barrier_count_y + nbarrier_base,
                nbarrier_role::producer_consumer);

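        // Prologue: cooperatively load the first K-slice of A and B from
        // global memory, publish it into SLM slot 0, then signal the
        // consumers.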
        tile_load(partial_matA, matA_payload);
        tile_load(partial_matB, matB_payload);

        tile_store(partial_matA, matA_local_st_payload);
        tile_store(partial_matB, matB_local_st_payload);
        store_idx++;

        matA_payload.template update_tdesc<update_dir_a>(matA_t::tile_size_x);
        matB_payload.template update_tdesc<update_dir_b>(matB_t::tile_size_y);
        xetla_fence<memory_kind::shared_local>();
        nbarrier_a.arrive();
        nbarrier_b.arrive();
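        // Fill the remaining (num_cyclic - 2) SLM slots ahead of the main
        // K loop so the global loads can run ahead of the consuming MMA work.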
#pragma unroll
        for (uint32_t i = 1; i < num_cyclic - 1; i++) {
            tile_load(partial_matA, matA_payload);
            tile_load(partial_matB, matB_payload);

            matA_payload.template update_tdesc<update_dir_a>(
                    matA_t::tile_size_x);
            matB_payload.template update_tdesc<update_dir_b>(
                    matB_t::tile_size_y);

            matA_local_st_payload
                    .template update_tdesc<tdesc_update_dir::y_dir>(
                            wg_size_y * matA_t::tile_size_y);
            matB_local_st_payload
                    .template update_tdesc<tdesc_update_dir::y_dir>(
                            wg_size_x * matB_t::tile_size_y);

            tile_store(partial_matA, matA_local_st_payload);
            tile_store(partial_matB, matB_local_st_payload);
            store_idx++;
        }

        matA_prefetch_payload.template update_tdesc<update_dir_a>(
                matA_t::tile_size_x * (num_cyclic - 1));
        matB_prefetch_payload.template update_tdesc<update_dir_b>(
                matB_t::tile_size_y * (num_cyclic - 1));
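        // Prefetch warm-up: issue `stages` rounds of cache prefetch for the
        // global-memory tiles that the main loop will load next.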
#pragma unroll
        for (uint32_t i = 0; i < stages; i++) {
            subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
                    matA_prefetch_payload);
            subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
                    matB_prefetch_payload);
            matA_prefetch_payload.template update_tdesc<update_dir_a>(
                    matA_t::tile_size_x);
            matB_prefetch_payload.template update_tdesc<update_dir_b>(
                    matB_t::tile_size_y);
        }

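        // Main K loop (software pipelined): load the next K-slice from global
        // memory, wait for the subgroups sharing the current SLM slot, read
        // the staged tiles back from SLM, convert A and VNNI-transform B, run
        // the XMX mma into matAcc, then store the freshly loaded slice into
        // the next SLM slot. load_idx/store_idx walk the ring buffer and the
        // payloads rewind by (1 - num_cyclic) slots when they wrap.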
        for (uint32_t i = 0; i < args.inner_loop_count; i++) {
            tile_load(partial_matA, matA_payload);
            tile_load(partial_matB, matB_payload);

            matA_payload.template update_tdesc<update_dir_a>(
                    matA_t::tile_size_x);
            matB_payload.template update_tdesc<update_dir_b>(
                    matB_t::tile_size_y);

            if constexpr (stages != 0) {
                subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
                        matA_prefetch_payload);
                subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
                        matB_prefetch_payload);
            }

            nbarrier_a.wait();
            nbarrier_b.wait();

            tile_load(matA, matA_local_ld_payload);
            tile_load(matB, matB_local_ld_payload);

            load_idx = (load_idx < num_cyclic - 1) ? (load_idx + 1) : 0;

            if (load_idx != 0) {
                matA_local_ld_payload
                        .template update_tdesc<tdesc_update_dir::y_dir>(
                                wg_size_y * matA_t::tile_size_y);
                matB_local_ld_payload
                        .template update_tdesc<tdesc_update_dir::y_dir>(
                                wg_size_x * matB_t::tile_size_y);
            } else {
                matA_local_ld_payload
                        .template update_tdesc<tdesc_update_dir::y_dir>(
                                (1 - num_cyclic) * wg_size_y
                                * matA_t::tile_size_y);
                matB_local_ld_payload
                        .template update_tdesc<tdesc_update_dir::y_dir>(
                                (1 - num_cyclic) * wg_size_x
                                * matB_t::tile_size_y);
            }
            xetla_fence<memory_kind::shared_local>();

            if constexpr (stages != 0) {
                matA_prefetch_payload.template update_tdesc<update_dir_a>(
                        matA_t::tile_size_x);
                matB_prefetch_payload.template update_tdesc<update_dir_b>(
                        matB_t::tile_size_y);
            }

            nbarrier_a.arrive();
            nbarrier_b.arrive();
            SW_BARRIER();
            matA_acc_t matA_acc;
            matB_acc_t matB_acc;
            subgroup::elemwise_cvt(matA_acc, matA);
            subgroup::vnni_transform(matB_acc, matB);
            pre_processing(matA_acc, matB_acc, matA, matB);
            SW_BARRIER();
            tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc);
            SW_BARRIER();

            if (store_idx != 0) {
                matA_local_st_payload
                        .template update_tdesc<tdesc_update_dir::y_dir>(
                                wg_size_y * matA_t::tile_size_y);
                matB_local_st_payload
                        .template update_tdesc<tdesc_update_dir::y_dir>(
                                wg_size_x * matB_t::tile_size_y);
            } else {
                matA_local_st_payload
                        .template update_tdesc<tdesc_update_dir::y_dir>(
                                (1 - num_cyclic) * wg_size_y
                                * matA_t::tile_size_y);
                matB_local_st_payload
                        .template update_tdesc<tdesc_update_dir::y_dir>(
                                (1 - num_cyclic) * wg_size_x
                                * matB_t::tile_size_y);
            }

            tile_store(partial_matA, matA_local_st_payload);
            tile_store(partial_matB, matB_local_st_payload);
            store_idx = (store_idx < num_cyclic - 1) ? (store_idx + 1) : 0;
        }
        SW_BARRIER();
        nbarrier_a.wait();
        nbarrier_b.wait();
    }

private:
    /// @brief Updates the subgroup-level tile base descriptors for this
    /// subgroup's position inside the workgroup tile.
    __XETLA_API static void update_sg_tile_tdesc(
            arguments_t &args, int32_t sg_idx, int32_t sg_idy) {
        int32_t tile_offset_n = sg_idx * sg_tile_n;
        int32_t tile_offset_m = sg_idy * sg_tile_m;

        args.matA_base_desc.update_coord_y(
                tile_offset_m + cooperative_helper_A_t::get_offset_y(sg_idx));
        args.matA_base_desc.update_coord_x(
                cooperative_helper_A_t::get_offset_x(sg_idx));
        args.matB_base_desc.update_coord_x(
                tile_offset_n + cooperative_helper_B_t::get_offset_x(sg_idy));
        args.matB_base_desc.update_coord_y(
                cooperative_helper_B_t::get_offset_y(sg_idy));
    }
};

} // namespace gpu::xetla::group
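Usage sketch (not part of the header above): the aliases below show how this
specialization is typically selected, by building a compute_policy_unaligned_xmx
from a compute attribute and a performance-tuning knob and handing it to gemm_t
together with a workgroup tile shape and the two memory descriptors. The helper
names tile_shape_t, mem_desc_t, compute_attr_t, perf_tuning_knob_t and
pre_processing_default_t are assumed to come from the surrounding group/gemm
headers; exact names, parameter orders and defaults may differ between XeTLA
releases, so treat this as an illustrative sketch only.

    using namespace gpu::xetla;

    // Workgroup/subgroup tile shape: wg 256x256, sg 64x32 (n-major order assumed).
    using tile_shape = group::tile_shape_t<256, 256, 64, 32>;
    // Row-major bf16 A and B in global memory, fp32 accumulation.
    using mem_desc_a_t = mem_desc_t<bf16, mem_layout::row_major, mem_space::global>;
    using mem_desc_b_t = mem_desc_t<bf16, mem_layout::row_major, mem_space::global>;
    using compute_attr = group::compute_attr_t<bf16, bf16, float>;
    // k_stride / prefetch stages / periodic-sync interval (argument order assumed).
    using perf_tuning_knob = group::perf_tuning_knob_t<32, 3, 8>;
    using compute_policy = group::compute_policy_unaligned_xmx<compute_attr,
            perf_tuning_knob, gpu_arch::Xe>;
    using gemm_op_t = group::gemm_t<compute_policy, tile_shape, mem_desc_a_t,
            mem_desc_b_t, group::pre_processing_default_t<tile_shape, gpu_arch::Xe>>;

The caller is responsible for reserving gemm_op_t::slm_size bytes of shared
local memory and gemm_op_t::barrier_count named barriers for the kernel, and
for sizing arguments_t::inner_loop_count so that inner_loop_count * k_stride
covers the K dimension.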