XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
batch_gemm.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
22#include "xetla.hpp"
23
24namespace gpu::xetla::kernel {
25
26template <typename gemm_t_, typename epilogue_t_, gpu_arch arch_tag_>
// NOTE(review): in this extracted listing the class declaration line
// (documentation line 27 — presumably `class batch_gemm_t {` per the member
// index at the bottom of the file) is missing; the lines below are the
// class-scope member aliases and compile-time constants of that class.
// -- Template-parameter aliases and the argument types they expose. --
28 using gemm_t = gemm_t_;
29 using epilogue_t = epilogue_t_;
30 using gemm_args_t = typename gemm_t::arguments_t;
31 using epilogue_args_t = typename epilogue_t::arguments_t;
32
// -- Tile-shape geometry: workgroup/subgroup tile sizes (elements) and the
// workgroup layout in subgroups, all taken from the gemm's tile_shape. --
33 using tile_shape = typename gemm_t::tile_shape;
34 static constexpr uint32_t wg_tile_m = tile_shape::wg_tile_size_y;
35 static constexpr uint32_t wg_tile_n = tile_shape::wg_tile_size_x;
36 static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
37 static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x;
38 static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
39 static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
// Workgroup tile actually covered by the subgroup grid (sg tile * sg count).
40 static constexpr uint32_t real_wg_tile_m = sg_tile_m * wg_size_y;
41 static constexpr uint32_t real_wg_tile_n = sg_tile_n * wg_size_x;
42
// k advance per gemm inner-loop iteration, and the cooperating work group.
43 static constexpr uint32_t k_stride = gemm_t::k_stride;
44 using work_group_t = typename gemm_t::work_group_t;
45
// -- Compile-time consistency checks: the kernel, its gemm and its epilogue
// must agree on the target architecture and on the tile shape. --
46 static constexpr gpu_arch arch_tag = arch_tag_;
47 static_assert(arch_tag == gemm_t::arch_tag, "arch_tag should be the same");
48 static_assert(
49 arch_tag == epilogue_t::arch_tag, "arch_tag should be the same");
50 static_assert(std::is_same<typename gemm_t::tile_shape,
51 typename epilogue_t::tile_shape>::value,
52 "tile_shape should be the same");
53
// -- Memory-descriptor, base-pointer and element-type aliases for A/B/C,
// plus the accumulator tile type produced by the gemm stage. --
54 using mem_desc_a_t = typename gemm_t::mem_desc_a_t;
55 using mem_desc_b_t = typename gemm_t::mem_desc_b_t;
56 using mem_desc_c_t = typename epilogue_t::mem_desc_c_t;
57 using matA_base_t = typename mem_desc_a_t::base_t;
58 using matB_base_t = typename mem_desc_b_t::base_t;
59 using matC_base_t = typename mem_desc_c_t::base_t;
60 using dtype_a = typename mem_desc_a_t::dtype;
61 using dtype_b = typename mem_desc_b_t::dtype;
62 using dtype_c = typename mem_desc_c_t::dtype;
63 using matAcc_t = typename gemm_t::matAcc_t;
64
// Only row-major A and B layouts are supported by this kernel.
65 static_assert((!gemm_t::is_col_major_a) && (!gemm_t::is_col_major_b),
66 "only support row-major");
67
68public:
71 struct arguments_t {
73 uint32_t batch_size;
75 uint32_t matrix_m;
77 uint32_t matrix_k;
79 uint32_t matrix_n;
81 uint32_t matA_ld;
83 uint32_t matB_ld;
85 uint32_t matC_ld;
87 matA_base_t matA_base;
89 matB_base_t matB_base;
91 matC_base_t matC_base;
93 epilogue_args_t epilogue_args;
94
96 inline arguments_t() = default;
97
99 static constexpr bool host_callable = true;
100
113 inline arguments_t(uint32_t batch_size_, uint32_t matrix_m_,
114 uint32_t matrix_k_, uint32_t matrix_n_, matA_base_t matA_base_,
115 uint32_t matA_ld_, matB_base_t matB_base_, uint32_t matB_ld_,
116 matC_base_t matC_base_, uint32_t matC_ld_,
117 epilogue_args_t epilogue_args_ = {})
118 : batch_size(batch_size_)
119 , matrix_m(matrix_m_)
120 , matrix_k(matrix_k_)
121 , matrix_n(matrix_n_)
122 , matA_ld(matA_ld_)
123 , matB_ld(matB_ld_)
124 , matC_ld(matC_ld_)
125 , matA_base(matA_base_)
126 , matB_base(matB_base_)
127 , matC_base(matC_base_)
128 , epilogue_args(epilogue_args_) {}
129 // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor)
130 // Please check if you need to add self-define destructor
131 // inline ~arguments_t(){}
132 inline arguments_t(const arguments_t &args)
133 : batch_size(args.batch_size)
134 , matrix_m(args.matrix_m)
135 , matrix_k(args.matrix_k)
136 , matrix_n(args.matrix_n)
137 , matA_ld(args.matA_ld)
138 , matB_ld(args.matB_ld)
139 , matC_ld(args.matC_ld)
140 , matA_base(args.matA_base)
141 , matB_base(args.matB_base)
142 , matC_base(args.matC_base)
144 inline arguments_t &operator=(const arguments_t &args) {
145 this->batch_size = args.batch_size;
146 this->matrix_m = args.matrix_m;
147 this->matrix_k = args.matrix_k;
148 this->matrix_n = args.matrix_n;
149 this->matA_base = args.matA_base;
150 this->matA_ld = args.matA_ld;
151 this->matB_base = args.matB_base;
152 this->matB_ld = args.matB_ld;
153 this->matC_base = args.matC_base;
154 this->matC_ld = args.matC_ld;
155 this->epilogue_args = args.epilogue_args;
156 return *this;
157 }
158 };
159
163 __XETLA_API static constexpr uint32_t get_barrier_count() {
164 constexpr uint32_t count
165 = gemm_t::barrier_count + epilogue_t::barrier_count;
166 static_assert(
167 count <= 32, "The named_barrier count should be less than 32!");
168 return count;
169 }
170
174 __XETLA_API static constexpr uint32_t get_slm_size() {
175 constexpr uint32_t size = gemm_t::slm_size + epilogue_t::slm_size;
176 static_assert(size <= (128 * 1024),
177 "The local memory size should be less than 128KB!");
178 return size;
179 };
180
183 static cl::sycl::range<3> get_local_range() {
184 uint32_t local_range_m = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
185 uint32_t local_range_n = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
186 std::cout << "Local range: {" << 1 << ", " << local_range_m << ", "
187 << local_range_n << "} \n";
188 assert(local_range_m * local_range_n <= 32);
189 return cl::sycl::range<3> {1, local_range_m, local_range_n};
190 };
191
196 static cl::sycl::range<3> get_group_range(
197 uint32_t batch_size, uint32_t matrix_m, uint32_t matrix_n) {
198 uint32_t group_range_m = (matrix_m + wg_tile_m - 1) / wg_tile_m;
199 uint32_t group_range_n = (matrix_n + wg_tile_n - 1) / wg_tile_n;
200 std::cout << "Group range: {" << batch_size << ", " << group_range_m
201 << ", " << group_range_n << "} \n";
202 return cl::sycl::range<3> {batch_size, group_range_m, group_range_n};
203 };
204
208 static cl::sycl::nd_range<3> get_nd_range(arguments_t &args) {
209 cl::sycl::range<3> local_range = get_local_range();
210 cl::sycl::range<3> group_range = get_group_range(
211 args.batch_size, args.matrix_m, args.matrix_n);
212 return cl::sycl::nd_range<3> {group_range * local_range, local_range};
213 };
214
218 static bool can_implement(arguments_t &args) {
219 bool implementable = true;
220 if (gemm_t::msg_type_a != msg_type::unaligned_2d) {
221 if (gemm_t::msg_type_a == msg_type::block_2d) {
222 implementable &= kernel::block_2d<gpu_arch::Xe,
223 dtype_a>::check_tensor((uint64_t)(args.matA_base.base),
224 args.matrix_k, args.matrix_m * args.batch_size,
225 args.matA_ld);
226 } else {
227 implementable &= kernel::general_1d<gpu_arch::Xe,
228 dtype_a>::check_alignment(args.matA_base.base,
229 args.matA_ld);
230 }
231 }
232 if (gemm_t::msg_type_b != msg_type::unaligned_2d) {
233 if (gemm_t::msg_type_b == msg_type::block_2d) {
234 implementable &= kernel::block_2d<gpu_arch::Xe,
235 dtype_b>::check_tensor((uint64_t)(args.matB_base.base),
236 args.matrix_n, args.matrix_k * args.batch_size,
237 args.matB_ld);
238 } else {
239 implementable &= kernel::general_1d<gpu_arch::Xe,
240 dtype_b>::check_alignment(args.matB_base.base,
241 args.matB_ld);
242 }
243 }
244 if (epilogue_t::msg_type_c != msg_type::unaligned_2d) {
245 if (epilogue_t::msg_type_c == msg_type::block_2d) {
246 implementable &= kernel::block_2d<gpu_arch::Xe,
247 dtype_c>::check_tensor((uint64_t)(args.matC_base.base),
248 args.matrix_n, args.matrix_m * args.batch_size,
249 args.matC_ld);
250 } else {
251 implementable &= kernel::general_1d<gpu_arch::Xe,
252 dtype_c>::check_alignment(args.matC_base.base,
253 args.matC_ld);
254 }
255 }
256
257 return implementable;
258 }
259
266 __XETLA_API KERNEL_FUNC void operator()(sycl::nd_item<3> &item,
267 const arguments_t &args, uint32_t slm_base = 0,
268 uint32_t nbarrier_base = 0) {
269 // set up workgroup level coordinates and boundaries
270 int batch_id = item.get_group(0);
271 int start_n = item.get_group(2) * wg_tile_n;
272 int start_m = item.get_group(1) * wg_tile_m + batch_id * args.matrix_m;
273 int start_k_a = 0;
274 int start_k_b = batch_id * args.matrix_k;
275 uint32_t wg_tile_k = args.matrix_k;
276 uint32_t boundary_n = (start_n + wg_tile_n) > args.matrix_n
277 ? args.matrix_n
278 : (start_n + wg_tile_n);
279 uint32_t boundary_m = (start_m + wg_tile_m)
280 > (args.matrix_m + batch_id * args.matrix_m)
281 ? (args.matrix_m + batch_id * args.matrix_m)
282 : (start_m + wg_tile_m);
283 uint32_t boundary_k_a = wg_tile_k;
284 uint32_t boundary_k_b = wg_tile_k + batch_id * args.matrix_k;
285
286 uint32_t gemm_slm_base = slm_base;
287 uint32_t gemm_nbarr_base = nbarrier_base;
288 uint32_t epilogue_slm_base = gemm_slm_base + gemm_t::slm_size;
289 uint32_t epilogue_nbarr_base = gemm_nbarr_base + gemm_t::barrier_count;
290
291 // set up arguments
292 work_group_t g;
293 g.init(item.get_local_linear_id());
294 mem_desc_a_t mem_desc_a;
295 mem_desc_b_t mem_desc_b;
296 mem_desc_c_t mem_desc_c;
297 //setup for matA
298 mem_desc_a.init(args.matA_base,
299 {boundary_k_a, boundary_m, args.matA_ld}, {start_k_a, start_m});
300
301 //setup for matB
302 mem_desc_b.init(args.matB_base,
303 {boundary_n, boundary_k_b, args.matB_ld}, {start_n, start_k_b});
304
305 //setup for matC
306 mem_desc_c.init(args.matC_base, {boundary_n, boundary_m, args.matC_ld},
307 {start_n, start_m});
308
309 uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride;
310 gemm_args_t gemm_args(mem_desc_a, mem_desc_b, inner_loop_count);
311 gemm_t gemm;
312 epilogue_t epilogue;
313
314 matAcc_t matAcc(0);
315 gemm(g, matAcc, gemm_args, gemm_slm_base, gemm_nbarr_base);
316 epilogue(g, matAcc, mem_desc_c, args.epilogue_args, epilogue_slm_base,
317 epilogue_nbarr_base);
318 }
319};
320
322
323} // namespace gpu::xetla::kernel
Definition batch_gemm.hpp:27
__XETLA_API KERNEL_FUNC void operator()(sycl::nd_item< 3 > &item, const arguments_t &args, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Main execution function for BATCH_GEMM.
Definition batch_gemm.hpp:266
static bool can_implement(arguments_t &args)
Check if the arguments can be implemented.
Definition batch_gemm.hpp:218
static cl::sycl::nd_range< 3 > get_nd_range(arguments_t &args)
Host helper function to get the expected nd_range under the current BATCH_GEMM config.
Definition batch_gemm.hpp:208
static cl::sycl::range< 3 > get_local_range()
Host helper function to get the expected local range under the current BATCH_GEMM config.
Definition batch_gemm.hpp:183
static __XETLA_API constexpr uint32_t get_barrier_count()
Gets named_barrier id consumption count.
Definition batch_gemm.hpp:163
static __XETLA_API constexpr uint32_t get_slm_size()
Gets local memory size consumption.
Definition batch_gemm.hpp:174
static cl::sycl::range< 3 > get_group_range(uint32_t batch_size, uint32_t matrix_m, uint32_t matrix_n)
Host helper function to get the expected group range under the current BATCH_GEMM config.
Definition batch_gemm.hpp:196
Definition limitation.hpp:738
Definition limitation.hpp:736
#define __XETLA_API
Definition common.hpp:43
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:734
gpu_arch
Definition common.hpp:73
BATCH_GEMM arguments.
Definition batch_gemm.hpp:71
static constexpr bool host_callable
Set for device copyable.
Definition batch_gemm.hpp:99
uint32_t matrix_n
Is the size of the n dimension of the matrix multiplication (m x k x n).
Definition batch_gemm.hpp:79
uint32_t matrix_m
Is the size of the m dimension of the matrix multiplication (m x k x n).
Definition batch_gemm.hpp:75
uint32_t batch_size
Is the number of total batches.
Definition batch_gemm.hpp:73
uint32_t matA_ld
Is the leading dimension (pitch) size of the matrix A in memory.
Definition batch_gemm.hpp:81
matC_base_t matC_base
Is the base address of matrix C.
Definition batch_gemm.hpp:91
matA_base_t matA_base
Is the base address of matrix A.
Definition batch_gemm.hpp:87
epilogue_args_t epilogue_args
Is the epilogue arguments.
Definition batch_gemm.hpp:93
uint32_t matC_ld
Is the leading dimension (pitch) size of the matrix C in memory.
Definition batch_gemm.hpp:85
uint32_t matrix_k
Is the size of the k dimension of the matrix multiplication (m x k x n).
Definition batch_gemm.hpp:77
matB_base_t matB_base
Is the base address of matrix B.
Definition batch_gemm.hpp:89
uint32_t matB_ld
Is the leading dimension (pitch) size of the matrix B in memory.
Definition batch_gemm.hpp:83
arguments_t(const arguments_t &args)
Definition batch_gemm.hpp:132
arguments_t & operator=(const arguments_t &args)
Definition batch_gemm.hpp:144
arguments_t()=default
Constructs arguments with default method.
arguments_t(uint32_t batch_size_, uint32_t matrix_m_, uint32_t matrix_k_, uint32_t matrix_n_, matA_base_t matA_base_, uint32_t matA_ld_, matB_base_t matB_base_, uint32_t matB_ld_, matC_base_t matC_base_, uint32_t matC_ld_, epilogue_args_t epilogue_args_={})
Constructs arguments with initialization list.
Definition batch_gemm.hpp:113
C++ API.