XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
default_xe.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
22#include "kernel/gemm/api.hpp"
23#include "kernel/gemm/common.hpp"
24#include "kernel/gemm/dispatch_policy.hpp"
25
26namespace gpu::xetla::kernel {
27
30
35template <typename gemm_t_, typename epilogue_t_, typename group_swizzle_>
36class gemm_universal_t<dispatch_policy_default<group_swizzle_>, gemm_t_,
37 epilogue_t_,
38 std::enable_if_t<(group_swizzle_::arch_tag == gpu_arch::Xe)>> {
39 using gemm_t = gemm_t_;
40 using epilogue_t = epilogue_t_;
41 using gemm_args_t = typename gemm_t::arguments_t;
42 using epilogue_args_t = typename epilogue_t::arguments_t;
43 using tile_shape = typename gemm_t::tile_shape;
44 using group_swizzle_t = group_swizzle_;
45
46 static constexpr uint32_t wg_tile_m = tile_shape::wg_tile_size_y;
47 static constexpr uint32_t wg_tile_n = tile_shape::wg_tile_size_x;
48 static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
49 static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x;
50 static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
51 static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
52 static constexpr uint32_t real_wg_tile_m = sg_tile_m * wg_size_y;
53 static constexpr uint32_t real_wg_tile_n = sg_tile_n * wg_size_x;
54
55 static constexpr uint32_t k_stride = gemm_t::k_stride;
56 using work_group_t = typename gemm_t::work_group_t;
57
58 static constexpr gpu_arch arch_tag = group_swizzle_t::arch_tag;
59 static_assert(arch_tag == gemm_t::arch_tag, "arch_tag should be the same");
60 static_assert(
61 arch_tag == epilogue_t::arch_tag, "arch_tag should be the same");
62 static_assert(std::is_same<typename gemm_t::tile_shape,
63 typename epilogue_t::tile_shape>::value,
64 "tile_shape should be the same");
65
66 using mem_desc_a_t = typename gemm_t::mem_desc_a_t;
67 using mem_desc_b_t = typename gemm_t::mem_desc_b_t;
68 using mem_desc_c_t = typename epilogue_t::mem_desc_c_t;
69 using matA_base_t = typename mem_desc_a_t::base_t;
70 using matB_base_t = typename mem_desc_b_t::base_t;
71 using matC_base_t = typename mem_desc_c_t::base_t;
72 using dtype_a = typename mem_desc_a_t::dtype;
73 using dtype_b = typename mem_desc_b_t::dtype;
74 using dtype_c = typename mem_desc_c_t::dtype;
75 using matAcc_t = typename gemm_t::matAcc_t;
76
77public:
80 struct arguments_t {
82 uint32_t matrix_m;
84 uint32_t matrix_k;
86 uint32_t matrix_n;
88 uint32_t matA_ld;
90 uint32_t matB_ld;
92 uint32_t matC_ld;
94 matA_base_t matA_base;
96 matB_base_t matB_base;
98 matC_base_t matC_base;
100 epilogue_args_t epilogue_args;
101
103 inline arguments_t() = default;
104
106 static constexpr bool host_callable = true;
107
119 inline arguments_t(uint32_t matrix_m_, uint32_t matrix_k_,
120 uint32_t matrix_n_, matA_base_t matA_base_, uint32_t matA_ld_,
121 matB_base_t matB_base_, uint32_t matB_ld_,
122 matC_base_t matC_base_, uint32_t matC_ld_,
123 epilogue_args_t epilogue_args_ = {})
124 : matrix_m(matrix_m_)
125 , matrix_k(matrix_k_)
126 , matrix_n(matrix_n_)
127 , matA_ld(matA_ld_)
128 , matB_ld(matB_ld_)
129 , matC_ld(matC_ld_)
130 , matA_base(matA_base_)
131 , matB_base(matB_base_)
132 , matC_base(matC_base_)
133 , epilogue_args(epilogue_args_) {}
134 // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor)
135 // Please check if you need to add self-define destructor
136 // inline ~arguments_t(){}
137 inline arguments_t(const arguments_t &args)
138 : matrix_m(args.matrix_m)
139 , matrix_k(args.matrix_k)
140 , matrix_n(args.matrix_n)
141 , matA_ld(args.matA_ld)
142 , matB_ld(args.matB_ld)
143 , matC_ld(args.matC_ld)
144 , matA_base(args.matA_base)
145 , matB_base(args.matB_base)
146 , matC_base(args.matC_base)
147 , epilogue_args(args.epilogue_args) {}
148 inline arguments_t &operator=(const arguments_t &args) {
149 this->matrix_m = args.matrix_m;
150 this->matrix_k = args.matrix_k;
151 this->matrix_n = args.matrix_n;
152 this->matA_base = args.matA_base;
153 this->matA_ld = args.matA_ld;
154 this->matB_base = args.matB_base;
155 this->matB_ld = args.matB_ld;
156 this->matC_base = args.matC_base;
157 this->matC_ld = args.matC_ld;
158 this->epilogue_args = args.epilogue_args;
159 return *this;
160 }
161 };
162
166 __XETLA_API static constexpr uint32_t get_barrier_count() {
167 constexpr uint32_t count
168 = gemm_t::barrier_count + epilogue_t::barrier_count;
169 static_assert(
170 count <= 32, "The named_barrier count should be less than 32!");
171 return count;
172 }
173
177 __XETLA_API static constexpr uint32_t get_slm_size() {
178 constexpr uint32_t size = gemm_t::slm_size + epilogue_t::slm_size;
179 static_assert(size <= (128 * 1024),
180 "The local memory size should be less than 128KB!");
181 return size;
182 };
183
186 static cl::sycl::range<3> get_local_range() {
187 uint32_t local_range_m = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
188 uint32_t local_range_n = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
189 std::cout << "Local range: {" << 1 << ", " << local_range_m << ", "
190 << local_range_n << "} \n";
191 assert(local_range_m * local_range_n <= 32);
192 return cl::sycl::range<3> {1, local_range_m, local_range_n};
193 };
194
199 static cl::sycl::range<3> get_group_range(
200 uint32_t matrix_m, uint32_t matrix_n) {
201 uint32_t group_range_m = (matrix_m + wg_tile_m - 1) / wg_tile_m;
202 uint32_t group_range_n = (matrix_n + wg_tile_n - 1) / wg_tile_n;
203 group_swizzle_t::update_group_range(group_range_m, group_range_n);
204 std::cout << "Group range: {" << 1 << ", " << group_range_m << ", "
205 << group_range_n << "} \n";
206 return cl::sycl::range<3> {1, group_range_m, group_range_n};
207 };
208
212 static cl::sycl::nd_range<3> get_nd_range(arguments_t &args) {
213 cl::sycl::range<3> local_range = get_local_range();
214 cl::sycl::range<3> group_range
215 = get_group_range(args.matrix_m, args.matrix_n);
216 return cl::sycl::nd_range<3> {group_range * local_range, local_range};
217 };
218
222 static bool can_implement(arguments_t &args) {
223 bool implementable = true;
224 if (gemm_t::msg_type_a != msg_type::unaligned_2d) {
225 if (gemm_t::msg_type_a == msg_type::block_2d) {
226 implementable &= kernel::block_2d<gpu_arch::Xe,
227 dtype_a>::check_tensor((uint64_t)(args.matA_base.base),
228 gemm_t::is_col_major_a ? args.matrix_m : args.matrix_k,
229 gemm_t::is_col_major_a ? args.matrix_k : args.matrix_m,
230 args.matA_ld);
231 } else {
232 implementable &= kernel::general_1d<gpu_arch::Xe,
233 dtype_a>::check_alignment(args.matA_base.base,
234 args.matA_ld);
235 }
236 }
237 if (gemm_t::msg_type_b != msg_type::unaligned_2d) {
238 if (gemm_t::msg_type_b == msg_type::block_2d) {
239 implementable &= kernel::block_2d<gpu_arch::Xe,
240 dtype_b>::check_tensor((uint64_t)(args.matB_base.base),
241 gemm_t::is_col_major_b ? args.matrix_k : args.matrix_n,
242 gemm_t::is_col_major_b ? args.matrix_n : args.matrix_k,
243 args.matB_ld);
244 } else {
245 implementable &= kernel::general_1d<gpu_arch::Xe,
246 dtype_b>::check_alignment(args.matB_base.base,
247 args.matB_ld);
248 }
249 }
250 if (epilogue_t::msg_type_c != msg_type::unaligned_2d) {
251 if (epilogue_t::msg_type_c == msg_type::block_2d) {
252 implementable &= kernel::block_2d<gpu_arch::Xe,
253 dtype_c>::check_tensor((uint64_t)(args.matC_base.base),
254 args.matrix_n, args.matrix_m, args.matC_ld);
255 } else {
256 implementable &= kernel::general_1d<gpu_arch::Xe,
257 dtype_c>::check_alignment(args.matC_base.base,
258 args.matC_ld);
259 }
260 }
261
262 return implementable;
263 }
264
271 __XETLA_API KERNEL_FUNC void operator()(sycl::nd_item<3> &item,
272 const arguments_t &args, uint32_t slm_base = 0,
273 uint32_t nbarrier_base = 0) {
274 // set up workgroup level coordinates and boundaries
275 group_swizzle_t group_swizzle;
276 int start_m = group_swizzle.template get_tile_idx<1>(item) * wg_tile_m;
277 int start_n = group_swizzle.template get_tile_idx<2>(item) * wg_tile_n;
278 int start_k = 0;
279 uint32_t wg_tile_k = args.matrix_k;
280 uint32_t boundary_n = (start_n + wg_tile_n) > args.matrix_n
281 ? args.matrix_n
282 : (start_n + wg_tile_n);
283 uint32_t boundary_m = (start_m + wg_tile_m) > args.matrix_m
284 ? args.matrix_m
285 : (start_m + wg_tile_m);
286 uint32_t boundary_k = wg_tile_k;
287
288 uint32_t gemm_slm_base = slm_base;
289 uint32_t gemm_nbarr_base = nbarrier_base;
290 uint32_t epilogue_slm_base = gemm_slm_base + gemm_t::slm_size;
291 uint32_t epilogue_nbarr_base = gemm_nbarr_base + gemm_t::barrier_count;
292
293 // set up arguments
294 work_group_t g;
295 g.init(item.get_local_linear_id());
296 mem_desc_a_t mem_desc_a;
297 mem_desc_b_t mem_desc_b;
298 mem_desc_c_t mem_desc_c;
299 //setup for matA
300 if constexpr (mem_desc_a_t::is_local) {
301 mem_desc_a.init(args.matA_base,
302 {wg_tile_k, real_wg_tile_m, wg_tile_k}, {0, 0});
303 } else {
304 mem_desc_a.init(args.matA_base,
305 {boundary_k, boundary_m, args.matA_ld}, {start_k, start_m});
306 }
307 //setup for matB
308 if constexpr (mem_desc_b_t::is_local) {
309 mem_desc_b.init(args.matB_base,
310 {real_wg_tile_n, wg_tile_k, real_wg_tile_n}, {0, 0});
311 } else {
312 mem_desc_b.init(args.matB_base,
313 {boundary_n, boundary_k, args.matB_ld}, {start_n, start_k});
314 }
315 //setup for matC
316 if constexpr (mem_desc_c_t::is_local) {
317 mem_desc_c.init(args.matC_base,
318 {real_wg_tile_n, real_wg_tile_m, real_wg_tile_n}, {0, 0});
319 } else {
320 mem_desc_c.init(args.matC_base,
321 {boundary_n, boundary_m, args.matC_ld}, {start_n, start_m});
322 }
323 uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride;
324 gemm_args_t gemm_args(mem_desc_a, mem_desc_b, inner_loop_count);
325 gemm_t gemm;
326 epilogue_t epilogue;
327
328 matAcc_t matAcc(0);
329 gemm(g, matAcc, gemm_args, gemm_slm_base, gemm_nbarr_base);
330 epilogue(g, matAcc, mem_desc_c, args.epilogue_args, epilogue_slm_base,
331 epilogue_nbarr_base);
332 }
333};
334
336
337} // namespace gpu::xetla::kernel
Definition limitation.hpp:738
static cl::sycl::nd_range< 3 > get_nd_range(arguments_t &args)
Host helper function to get the expected nd_range under the current GEMM_UNIVERSAL config.
Definition default_xe.hpp:212
static __XETLA_API constexpr uint32_t get_barrier_count()
Gets named_barrier id consumption count.
Definition default_xe.hpp:166
static cl::sycl::range< 3 > get_group_range(uint32_t matrix_m, uint32_t matrix_n)
Host helper function to get the expected group range under the current GEMM_UNIVERSAL config.
Definition default_xe.hpp:199
static cl::sycl::range< 3 > get_local_range()
Host helper function to get the expected local range under the current GEMM_UNIVERSAL config.
Definition default_xe.hpp:186
__XETLA_API KERNEL_FUNC void operator()(sycl::nd_item< 3 > &item, const arguments_t &args, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Main execution function for GEMM_UNIVERSAL.
Definition default_xe.hpp:271
GEMM_UNIVERSAL functor.
Definition api.hpp:39
Definition limitation.hpp:736
#define __XETLA_API
Definition common.hpp:43
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
C++ API.
Definition limitation.hpp:734
gpu_arch
Definition common.hpp:73
Default GEMM_UNIVERSAL implementation.
Definition dispatch_policy.hpp:116
arguments_t(uint32_t matrix_m_, uint32_t matrix_k_, uint32_t matrix_n_, matA_base_t matA_base_, uint32_t matA_ld_, matB_base_t matB_base_, uint32_t matB_ld_, matC_base_t matC_base_, uint32_t matC_ld_, epilogue_args_t epilogue_args_={})
Constructs arguments with initialization list.
Definition default_xe.hpp:119