XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
default_fpu_xe.hpp
/*******************************************************************************
* Copyright (c) 2022-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include "group/gemm/api.hpp"
#include "group/gemm/compute_policy.hpp"

namespace gpu::xetla::group {
/// @brief Gemm functor, specialized for the default fpu compute policy on Xe.
template <typename compute_attr_, typename perf_tuning_knob_,
        typename tile_shape_, typename mem_desc_a_t_, typename mem_desc_b_t_,
        typename pre_processing_t_, gpu_arch arch_tag_>
class gemm_t<
        compute_policy_default_fpu<compute_attr_, perf_tuning_knob_, arch_tag_>,
        tile_shape_, // tile shape of workgroup-level gemm
        mem_desc_a_t_, // memory attribute of matA
        mem_desc_b_t_, // memory attribute of matB
        pre_processing_t_, // pre_processing functor
        std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
public:
    using mem_desc_a_t = mem_desc_a_t_;
    using mem_desc_b_t = mem_desc_b_t_;
    using tile_shape = tile_shape_;
    using pre_processing_t = pre_processing_t_;
    using compute_policy = compute_policy_default_fpu<compute_attr_,
            perf_tuning_knob_, arch_tag_>;
    static constexpr uint32_t k_stride = compute_policy::k_stride;
    static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
    static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x;
    static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
    static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
    using work_group_t = typename tile_shape::work_group_t;
    static constexpr gpu_arch arch_tag = compute_policy::arch_tag;

    static constexpr mem_layout mem_layout_a = mem_desc_a_t::layout;
    static constexpr mem_layout mem_layout_b = mem_desc_b_t::layout;
    static constexpr bool is_col_major_a
            = mem_layout_a == mem_layout::col_major;
    static constexpr bool is_col_major_b
            = mem_layout_b == mem_layout::col_major;
private:
    /******** set data type **********/
    using dtype_a = typename mem_desc_a_t::dtype;
    using dtype_b = typename mem_desc_b_t::dtype;
    using dtype_mma_acc = typename compute_policy::dtype_mma_acc;
    using dtype_mma_a = typename compute_policy::dtype_mma_a;
    using dtype_mma_b = typename compute_policy::dtype_mma_b;

    using check_dtype
            = group::gemm<gpu_arch::Xe>::default_fpu::check_dtype_default<
                    dtype_a, dtype_b, dtype_mma_a, dtype_mma_b, dtype_mma_acc>;

    /******** set memory attribute **********/
    static constexpr mem_space mem_space_a = mem_desc_a_t::space;
    static constexpr mem_space mem_space_b = mem_desc_b_t::space;

    static constexpr bool is_local_a = mem_space_a == mem_space::local;
    static constexpr bool is_local_b = mem_space_b == mem_space::local;
    static constexpr tdesc_update_dir update_dir_a = is_col_major_a
            ? tdesc_update_dir::y_dir
            : tdesc_update_dir::x_dir;
    static constexpr tdesc_update_dir update_dir_b = is_col_major_b
            ? tdesc_update_dir::x_dir
            : tdesc_update_dir::y_dir;

    using check_memory
            = group::gemm<gpu_arch::Xe>::default_fpu::check_memory_default<
                    mem_layout_a, mem_layout_b, mem_space_a, mem_space_b>;

    static constexpr uint32_t stages = compute_policy::stages;
    static constexpr uint32_t sync_freq = compute_policy::sync_freq;
    /******** set tile layout && worker scope **********/
    static constexpr uint32_t tile_size_x_a = k_stride;
    static constexpr uint32_t tile_size_y_a = sg_tile_m;
    static constexpr uint32_t tile_size_x_b = sg_tile_n;
    static constexpr uint32_t tile_size_y_b = k_stride;
    static constexpr uint32_t tile_size_x_c = sg_tile_n;
    static constexpr uint32_t tile_size_y_c = sg_tile_m;
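    // Per inner iteration each subgroup consumes a (sg_tile_m x k_stride)
    // slice of A and a (k_stride x sg_tile_n) slice of B, producing a
    // (sg_tile_m x sg_tile_n) slice of C.
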
    static constexpr uint32_t block_size_x_a
            = (compute_policy::block_size_x_a > tile_size_x_a)
            ? tile_size_x_a
            : compute_policy::block_size_x_a;
    static constexpr uint32_t block_size_y_a
            = (compute_policy::block_size_y_a > tile_size_y_a)
            ? tile_size_y_a
            : compute_policy::block_size_y_a;
    static constexpr uint32_t block_size_x_b
            = (compute_policy::block_size_x_b > tile_size_x_b)
            ? tile_size_x_b
            : compute_policy::block_size_x_b;
    static constexpr uint32_t block_size_y_b
            = (compute_policy::block_size_y_b > tile_size_y_b)
            ? tile_size_y_b
            : compute_policy::block_size_y_b;
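    // Register block sizes come from the compute policy, clamped so that a
    // block never exceeds the subgroup tile itself.
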
    using check_tile_size = group::gemm<
            gpu_arch::Xe>::default_fpu::check_tile_size_default<dtype_mma_acc,
            tile_size_x_a, tile_size_y_a, block_size_x_a, block_size_y_a,
            tile_size_x_b, tile_size_y_b, block_size_x_b, block_size_y_b>;
    /******** set tile **********/
    // transpose in registers for src suppression
    static constexpr reg_layout reg_layout_a = reg_layout::transpose_tiled;
    using matA_tile_desc_t = subgroup::tile_desc_t<tile_size_x_a, tile_size_y_a,
            block_size_x_a, block_size_y_a, reg_layout_a>;
    using matA_t = subgroup::tile_t<dtype_a, matA_tile_desc_t>;
    using matA_payload_t = subgroup::mem_payload_t<mem_desc_a_t,
            matA_tile_desc_t,
            subgroup::msg_type_v<matA_tile_desc_t, mem_space_a>, arch_tag>;
    // the tile size in registers may be bigger than in memory because of padding
    using matA_acc_t = subgroup::tile_t<dtype_mma_a, matA_tile_desc_t>;
    using matA_prefetch_payload_t = subgroup::prefetch_payload_t<mem_desc_a_t,
            matA_tile_desc_t, wg_size_x, arch_tag>;

    static constexpr reg_layout reg_layout_b = reg_layout::tiled;
    using matB_tile_desc_t = subgroup::tile_desc_t<tile_size_x_b, tile_size_y_b,
            block_size_x_b, block_size_y_b, reg_layout_b>;
    using matB_t = subgroup::tile_t<dtype_b, matB_tile_desc_t>;
    using matB_payload_t = subgroup::mem_payload_t<mem_desc_b_t,
            matB_tile_desc_t,
            subgroup::msg_type_v<matB_tile_desc_t, mem_space_b>, arch_tag>;
    using matB_acc_t = subgroup::tile_t<dtype_mma_b, matB_tile_desc_t>;
    using matB_prefetch_payload_t = subgroup::prefetch_payload_t<mem_desc_b_t,
            matB_tile_desc_t, wg_size_y, arch_tag>;
public:
    using matAcc_tile_desc_t = subgroup::tile_desc_t<tile_size_x_c,
            tile_size_y_c, block_size_x_b, block_size_y_a, reg_layout::tiled>;
    using matAcc_t = subgroup::tile_t<dtype_mma_acc, matAcc_tile_desc_t>;

private:
    using tile_mma = subgroup::tile_mma_t<matAcc_t, matAcc_t, matB_acc_t,
            matA_acc_t, mma_engine::fpu, arch_tag>;
    static constexpr bool enable_periodic_sync = (sync_freq != 0);
    static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0;
    static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0;
public:
    static constexpr uint32_t barrier_count
            = enable_periodic_sync ? barrier_count_x + barrier_count_y : 0;
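    // Barrier id layout: ids [0, barrier_count_y) synchronize the subgroups of
    // one row (A-side sync), ids [barrier_count_y, barrier_count_y +
    // barrier_count_x) synchronize the subgroups of one column (B-side sync).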
    // currently no SLM path
    static constexpr uint32_t slm_size = 0;

    static constexpr msg_type msg_type_a = matA_payload_t::message_type;
    static constexpr msg_type msg_type_b = matB_payload_t::message_type;

    using pre_processing_arg_t = typename pre_processing_t::arguments_t;
    /// @brief Is the gemm arguments.
    struct arguments_t {
        /// @brief Memory description of matA, including base, shape and coordinate.
        mem_desc_a_t matA_base_desc;
        /// @brief Memory description of matB, including base, shape and coordinate.
        mem_desc_b_t matB_base_desc;
        /// @brief Inner loop count needed to cover the K dimension.
        uint32_t inner_loop_count;
        /// @brief Arguments for the pre_processing functor.
        pre_processing_arg_t pre_processing_args;

        /// @brief Default constructor.
        inline arguments_t() = default;

        /// @brief Constructs a new arguments_t object.
        inline arguments_t(mem_desc_a_t matA_desc, mem_desc_b_t matB_desc,
                uint32_t loop_count, pre_processing_arg_t args = {})
            : matA_base_desc(matA_desc)
            , matB_base_desc(matB_desc)
            , inner_loop_count(loop_count)
            , pre_processing_args(args) {}
        // Be aware of the risks: rule of three (copy constructor, copy
        // assignment, destructor). Please check whether you need to add a
        // self-defined destructor.
        // inline ~arguments_t(){}
        inline arguments_t(const arguments_t &args)
            : matA_base_desc(args.matA_base_desc)
            , matB_base_desc(args.matB_base_desc)
            , inner_loop_count(args.inner_loop_count)
            , pre_processing_args(args.pre_processing_args) {}
        inline arguments_t &operator=(const arguments_t &args) {
            this->matA_base_desc = args.matA_base_desc;
            this->matB_base_desc = args.matB_base_desc;
            this->inner_loop_count = args.inner_loop_count;
            this->pre_processing_args = args.pre_processing_args;
            return *this;
        }

        /// @brief Explicit initialization function.
        inline void init(mem_desc_a_t matA_desc, mem_desc_b_t matB_desc,
                uint32_t loop_count, pre_processing_arg_t args = {}) {
            matA_base_desc = matA_desc;
            matB_base_desc = matB_desc;
            inner_loop_count = loop_count;
            pre_processing_args = args;
        }
    };
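    // Note: inner_loop_count is the number of k_stride-sized steps taken along
    // the K dimension, i.e. typically (K + k_stride - 1) / k_stride; each inner
    // iteration advances the A and B descriptors by k_stride.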

    /// @brief Gets the subgroup-level tile offset x.
    __XETLA_API static int32_t get_matC_offset_x(work_group_t &g) {
        int32_t sg_idx = g.get_id() % wg_size_x;
        return sg_idx * sg_tile_n;
    }

    /// @brief Gets the subgroup-level tile offset y.
    __XETLA_API static int32_t get_matC_offset_y(work_group_t &g) {
        int32_t sg_idy = g.get_id() / wg_size_x;
        return sg_idy * sg_tile_m;
    }

246 "This release function will wait until all the r/w and nbarrier "
247 "id used in this gemm have been committed. By default, it will "
248 "use barrier_id 0 to do the entire workgroup sync if wg_size > 1. "
249 "If you call this function, please set a free barrier id or make "
250 "sure barrier_id 0 is not being occupied and you need to allocate "
251 "one more barrier count in addition to the gemm barrier counts.")
252 __XETLA_API static void release(uint8_t nbarrier_id = 0) {
253 static constexpr bool need_local_fence
254 = (mem_space_a == mem_space::local)
255 || (mem_space_b == mem_space::local);
256 if constexpr (need_local_fence) {
257 xetla_fence<memory_kind::shared_local>();
258 }
259 xetla_fence<memory_kind::untyped_global>();
260 static constexpr uint32_t wg_size = wg_size_x * wg_size_y;
261 if constexpr (wg_size > 1) {
263 nbarrier.init_nbarrier(
265 nbarrier.arrive_wait();
266 }
267 }

    /// @brief Main execution function for gemm.
    /// The basic process is load data -> matrix multiply.
    /// @param g Is the workgroup of the current gemm.
    /// @param matAcc Is the reference of the accumulation buffer.
    /// @param args Is the gemm::arguments_t.
    /// @param slm_base Is the slm base address (unused; currently no SLM path).
    /// @param nbarrier_base Is the named barrier base id.
    __XETLA_API KERNEL_FUNC void operator()(work_group_t &g, matAcc_t &matAcc,
            arguments_t args, [[maybe_unused]] uint32_t slm_base = 0,
            uint32_t nbarrier_base = 0) {
        int32_t sg_idx = g.get_id() % wg_size_x;
        int32_t sg_idy = g.get_id() / wg_size_x;
        update_sg_tile_tdesc(args, sg_idx, sg_idy);
        pre_processing_t pre_processing;
        matA_t matA;
        matB_t matB;
        // >>>>>>>>>>>>>>>>>> pre_processing init
        pre_processing.init(g, args.pre_processing_args);
        matA_payload_t matA_payload(args.matA_base_desc);
        matB_payload_t matB_payload(args.matB_base_desc);
        matA_prefetch_payload_t matA_prefetch_payload(args.matA_base_desc, 0);
        matB_prefetch_payload_t matB_prefetch_payload(args.matB_base_desc, 0);
        xetla_nbarrier_t<wg_size_x, wg_size_x, arch_tag> nbarrier_a;
        nbarrier_a.init_nbarrier(
                sg_idy + nbarrier_base, nbarrier_role::producer_consumer);
        xetla_nbarrier_t<wg_size_y, wg_size_y, arch_tag> nbarrier_b;
        nbarrier_b.init_nbarrier(sg_idx + barrier_count_y + nbarrier_base,
                nbarrier_role::producer_consumer);
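        // Warm up the prefetch pipeline: issue `stages` prefetches for A and B
        // ahead of the first loads.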
#pragma unroll
        for (uint32_t i = 0; i < stages; i++) {
            subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
                    matA_prefetch_payload);
            subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
                    matB_prefetch_payload);
            matA_prefetch_payload.template update_tdesc<update_dir_a>(
                    matA_t::tile_size_x);
            matB_prefetch_payload.template update_tdesc<update_dir_b>(
                    matB_t::tile_size_y);
        }

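        // Main K loop: load A/B tiles, convert to the mma datatypes, apply the
        // pre_processing functor, then accumulate with the fpu tile_mma. Named
        // barriers fire every sync_freq iterations to keep subgroups that
        // share data in lockstep.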
        for (uint32_t i = 0; i < args.inner_loop_count; i++) {
            if constexpr (enable_periodic_sync) {
                if ((i % sync_freq) == 0) {
                    if constexpr (wg_size_x > 1) { nbarrier_a.arrive(); }
                    if constexpr (wg_size_y > 1) { nbarrier_b.arrive(); }
                }
            }
            SW_BARRIER();
            subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
                    matA, matA_payload);
            subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
                    matB, matB_payload);
            matA_payload.template update_tdesc<update_dir_a>(
                    matA_t::tile_size_x);
            matB_payload.template update_tdesc<update_dir_b>(
                    matB_t::tile_size_y);
            if constexpr (stages != 0) {
                subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
                        matA_prefetch_payload);
                subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
                        matB_prefetch_payload);
                matA_prefetch_payload.template update_tdesc<update_dir_a>(
                        matA_t::tile_size_x);
                matB_prefetch_payload.template update_tdesc<update_dir_b>(
                        matB_t::tile_size_y);
            }
            matA_acc_t matA_acc;
            matB_acc_t matB_acc;
            subgroup::elemwise_cvt(matA_acc, matA);
            subgroup::elemwise_cvt(matB_acc, matB);
            pre_processing(matA_acc, matB_acc, matA, matB);
            SW_BARRIER();
            tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc);
            SW_BARRIER();
            if constexpr (enable_periodic_sync) {
                if ((i % sync_freq) == 0) {
                    if constexpr (wg_size_x > 1) { nbarrier_a.wait(); }
                    if constexpr (wg_size_y > 1) { nbarrier_b.wait(); }
                }
            }
        }
        SW_BARRIER();
    }

private:
    /// @brief Updates the subgroup tile descriptors based on the subgroup index.
    __XETLA_API static void update_sg_tile_tdesc(
            arguments_t &args, int32_t sg_idx, int32_t sg_idy) {
        int32_t tile_offset_n = sg_idx * sg_tile_n;
        int32_t tile_offset_m = sg_idy * sg_tile_m;

        args.matA_base_desc.update_coord_y(tile_offset_m);
        args.matB_base_desc.update_coord_x(tile_offset_n);
    }
};

} // namespace gpu::xetla::group
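
Usage sketch (not part of this header): the snippet below shows how this specialization is typically composed from the group-level helpers documented elsewhere in this API reference (tile_shape_t, compute_attr_t, perf_tuning_knob_t, mem_desc_t, pre_processing_default_t). The tile sizes and tuning values are illustrative assumptions, not definitions from this file.

using namespace gpu::xetla;

// Workgroup computes a 256x256 slice of C; each subgroup computes 64x32.
using tile_shape = group::tile_shape_t<256 /*wg_n*/, 256 /*wg_m*/,
        64 /*sg_n*/, 32 /*sg_m*/>;
// fp32 in, fp32 compute, fp32 accumulate: the fpu path works on native floats.
using compute_attr = group::compute_attr_t<float, float, float>;
// k_stride = 32, 3 prefetch stages, named-barrier sync every 8 iterations.
using perf_tuning_knob = group::perf_tuning_knob_t<32, 3, 8>;
using compute_policy = group::compute_policy_default_fpu<compute_attr,
        perf_tuning_knob, gpu_arch::Xe>;

using mem_desc_a = mem_desc_t<float, mem_layout::row_major, mem_space::global>;
using mem_desc_b = mem_desc_t<float, mem_layout::row_major, mem_space::global>;
using pre_processing = group::pre_processing_default_t<tile_shape, gpu_arch::Xe>;

using gemm = group::gemm_t<compute_policy, tile_shape, mem_desc_a, mem_desc_b,
        pre_processing>;

// Inside the kernel, per workgroup (descriptor setup elided; mdA/mdB are
// hypothetical mem_desc instances pointing at the workgroup's A/B tiles):
//   gemm::work_group_t g(sg_id);
//   gemm::matAcc_t matAcc;
//   matAcc.init(0);
//   gemm::arguments_t args(mdA, mdB, (K + gemm::k_stride - 1) / gemm::k_stride);
//   gemm{}(g, matAcc, args);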