XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
data_transformer_xe.hpp
/*******************************************************************************
* Copyright (c) 2022-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/// @file
/// C++ API

#pragma once

#include "kernel/data_transformer/api.hpp"
#include "kernel/data_transformer/common.hpp"
#include "kernel/data_transformer/config.hpp"

namespace gpu::xetla::kernel {

/// @brief Is the data_transformer functor, specialized for gpu_arch::Xe.
template <typename dtype_in_, typename dtype_out_, typename dtype_compute_,
        typename data_transformer_attr_, mem_layout mem_layout_in_,
        int need_fp8_op>
struct xetla_data_transformer<dtype_in_, dtype_out_, dtype_compute_,
        data_transformer_attr_, mem_layout_in_, need_fp8_op, gpu_arch::Xe> {
    using dtype_in = dtype_in_;
    using dtype_out = dtype_out_;
    using dtype_compute = dtype_compute_;
    using data_transformer_attr = data_transformer_attr_;

    static constexpr mem_layout mem_layout_in = mem_layout_in_;

    static constexpr bool is_col_major_in
            = mem_layout_in == mem_layout::col_major;

    static constexpr uint32_t wg_tile_m = data_transformer_attr::wg_tile_m;
    static constexpr uint32_t wg_tile_n = data_transformer_attr::wg_tile_n;
    static constexpr uint32_t sg_tile_m = data_transformer_attr::sg_tile_m;
    static constexpr uint32_t sg_tile_n = data_transformer_attr::sg_tile_n;

    static constexpr uint32_t tile_size_x = sg_tile_n;
    static constexpr uint32_t tile_size_y = sg_tile_m;

    static constexpr uint32_t wg_size_x
            = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
    static constexpr uint32_t wg_size_y
            = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
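
    // Example: with wg_tile_m/n = 256/128 and sg_tile_m/n = 32/16,
    // wg_size_x = (128 + 16 - 1) / 16 = 8 and
    // wg_size_y = (256 + 32 - 1) / 32 = 8, i.e. 64 subgroups per workgroup;
    // the ceil division also covers wg_tile sizes that are not exact
    // multiples of the sg_tile.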

    using load_store_attr = typename arch_attr_t<
            gpu_arch::Xe>::template load_store_attr<msg_type::block_2d>;
    static constexpr uint32_t max_load_height_in_elem
            = load_store_attr::max_load_height_in_elem;
    static constexpr uint32_t max_load_width_in_bytes
            = load_store_attr::max_load_width_in_bytes;
    static constexpr uint32_t max_store_width_in_bytes
            = load_store_attr::max_store_width_in_bytes;
    static constexpr uint32_t max_trans_block_width
            = load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype_in);
    static constexpr uint32_t max_load_width_in_elem
            = max_load_width_in_bytes / sizeof(dtype_in);
    static constexpr uint32_t max_store_width_in_elem
            = max_store_width_in_bytes / sizeof(dtype_out);

    // Check against HW limitations.
    static constexpr uint32_t load_size_x
            = gpu::xetla::subgroup::detail::gcd<tile_size_x,
                    max_load_width_in_elem>::value;
    static_assert(load_size_x >= 8,
            "if load_size_x is less than 8, the efficiency will be low. "
            "Please choose another tile_size_x");
    static constexpr uint32_t st_size_x = max_store_width_in_elem > tile_size_x
            ? tile_size_x
            : gpu::xetla::subgroup::detail::gcd<tile_size_x,
                    max_store_width_in_elem>::value;
    static_assert(st_size_x >= 8,
            "if st_size_x is less than 8, the efficiency will be low. "
            "Please choose another tile_size_x");
    static constexpr uint32_t block_size_x
            = load_size_x > st_size_x ? st_size_x : load_size_x;

    static constexpr uint32_t block_size_y_limit
            = is_col_major_in ? max_trans_block_width : max_load_height_in_elem;

    static constexpr uint32_t block_size_y = block_size_y_limit > tile_size_y
            ? tile_size_y
            : block_size_y_limit;
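
    // Together these clamps keep each register block within one legal 2D
    // block message: block_size_x respects the load/store width limits,
    // block_size_y the block-load height (or, for col-major input, the
    // transposed-load width).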

    static constexpr reg_layout in_reg_layout = reg_layout::tiled;

    using global_ld_tile_desc_t = subgroup::tile_desc_t<tile_size_x,
            tile_size_y, block_size_x, block_size_y, in_reg_layout>;
    using global_ld_t = subgroup::tile_t<dtype_in, global_ld_tile_desc_t>;
    using global_ld_payload_t = subgroup::mem_payload_t<dtype_in,
            global_ld_tile_desc_t,
            subgroup::msg_type_v<global_ld_tile_desc_t, mem_space::global>,
            mem_layout_in, mem_space::global, gpu_arch::Xe>;

    using global_st_tile_desc_t = subgroup::tile_desc_t<tile_size_x,
            tile_size_y, block_size_x, block_size_y, reg_layout::tiled>;
    using global_st_t = subgroup::tile_t<dtype_out, global_st_tile_desc_t>;
    using global_st_payload_t = subgroup::mem_payload_t<dtype_out,
            global_st_tile_desc_t,
            subgroup::msg_type_v<global_st_tile_desc_t, mem_space::global>,
            mem_layout::row_major, mem_space::global, gpu_arch::Xe>;

    using global_compute_tile_desc_t = subgroup::tile_desc_t<tile_size_x,
            tile_size_y, block_size_x, block_size_y, reg_layout::tiled>;
    using global_compute_t
            = subgroup::tile_t<dtype_compute, global_compute_tile_desc_t>;

    using wg_reduce_t
            = group::group_reduce_t<dtype_compute, tile_size_x * tile_size_y, 1,
                    reduce_op::max, wg_size_x * wg_size_y, true, gpu_arch::Xe>;
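    // (wg_reduce_t reduces each subgroup's tile_size_x * tile_size_y
    // compute-precision elements to a single maximum across the
    // wg_size_x * wg_size_y subgroups of the workgroup.)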

    /// @brief Arguments for the data_transformer kernel.
    struct arguments_t {
        uint32_t matrix_m;
        uint32_t matrix_n;
        uint32_t matrix_in_ld;
        uint32_t matrix_out_ld;
        dtype_in *mat_in_ptr;
        dtype_out *mat_out_ptr;
        dtype_compute *amax_ptr;
        dtype_compute *scale;
        uint32_t wg_ld_start_x;
        uint32_t wg_ld_start_y;
        uint32_t wg_st_start_x;
        uint32_t wg_st_start_y;
    };

    /// @brief Barrier count required by this kernel.
    struct get_barrier_count {
        static constexpr uint32_t count
                = (wg_size_x * wg_size_y > 1) ? wg_size_x * wg_size_y : 0;
    };

    /// @brief Shared local memory size (in bytes) required by this kernel:
    /// one dtype_compute slot per subgroup, used by the group reduction.
    struct get_slm_size {
        static constexpr uint32_t size = (wg_size_x * wg_size_y > 1)
                ? wg_size_x * wg_size_y * sizeof(dtype_compute)
                : 0;
    };
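
    // Kernels wrapping this functor are expected to reserve these resources
    // at entry, e.g. (sketch):
    //   xetla_nbarrier_init<get_barrier_count::count>();
    //   xetla_local_init<get_slm_size::size>();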

    /// @brief Main execution function for data_transformer.
    __XETLA_API static void call(sycl::nd_item<3> &item, arguments_t *args) {
        int tid_x = item.get_local_id(2);
        int tid_y = item.get_local_id(1);
        uint32_t sg_id = item.get_local_linear_id();

        global_ld_t mat_global_ld;
        global_ld_payload_t global_ld_payload;
        global_st_t mat_global_st;
        global_st_payload_t global_st_payload;
        global_compute_t mat_global_compute;

        // input and output starting points
        int global_ld_start_x;
        int global_ld_start_y;

        if constexpr (mem_layout_in == mem_layout::row_major) {
            global_ld_start_x = args->wg_ld_start_x + tid_x * sg_tile_n;
            global_ld_start_y = args->wg_ld_start_y + tid_y * sg_tile_m;
        } else {
            global_ld_start_x = args->wg_ld_start_x + tid_y * sg_tile_m;
            global_ld_start_y = args->wg_ld_start_y + tid_x * sg_tile_n;
        }

        int global_st_start_x = args->wg_st_start_x + tid_x * sg_tile_n;
        int global_st_start_y = args->wg_st_start_y + tid_y * sg_tile_m;

        if constexpr (mem_layout_in == mem_layout::row_major) {
            global_ld_payload.init(args->mat_in_ptr, args->matrix_n,
                    args->matrix_m, args->matrix_in_ld, global_ld_start_x,
                    global_ld_start_y);
        } else {
            global_ld_payload.init(args->mat_in_ptr, args->matrix_m,
                    args->matrix_n, args->matrix_in_ld, global_ld_start_x,
                    global_ld_start_y);
        }

        global_st_payload.init(args->mat_out_ptr, args->matrix_n,
                args->matrix_m, args->matrix_out_ld, global_st_start_x,
                global_st_start_y);

        subgroup::tile_load(mat_global_ld, global_ld_payload);
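
        // fp8 path: convert to dtype_compute, apply the scalar *scale,
        // convert to dtype_out, then reduce |x| to a workgroup-wide maximum
        // and merge it into *amax_ptr with an atomic fmax.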
        if constexpr (need_fp8_op) {
            subgroup::elemwise_cvt(mat_global_compute, mat_global_ld);

            static constexpr uint32_t simd = 16;
            uint64_t offset = 0;

            xetla_vector<dtype_compute, simd> local_scale
                    = xetla_load_global<dtype_compute, simd,
                            data_size::default_size, cache_hint::cached,
                            cache_hint::cached>(args->scale, offset);

            mat_global_compute.reg
                    = mat_global_compute.reg * (dtype_compute)(local_scale[0]);

            subgroup::elemwise_cvt(mat_global_st, mat_global_compute);

            wg_reduce_t wg_reduce;
            wg_reduce.init(sg_id, 0, 0);

            mat_global_compute.reg = xetla_abs<dtype_compute,
                    global_compute_t::tile_desc::tile_elems>(
                    mat_global_compute.reg);

            xetla_vector<dtype_compute, 1> local_wg_max
                    = wg_reduce(mat_global_compute.reg);

            xetla_mask<simd> pred(0);
            pred[0] = 1;

            xetla_vector<dtype_compute, simd> local_max(local_wg_max[0]);
            xetla_vector<uint32_t, simd> offsets
                    = xetla_vector_gen<uint32_t, simd>(0, 1);

            xetla_tatomic_store_global<dtype_compute, simd,
                    cache_hint::uncached, cache_hint::write_back,
                    atomic_op::fmax>((uint64_t)args->amax_ptr,
                    offsets * sizeof(dtype_compute), local_max, pred);
        } else {
            subgroup::elemwise_cvt(mat_global_st, mat_global_ld);
        }

        subgroup::tile_store<cache_hint::uncached>(
                mat_global_st, global_st_payload);
    }
};

} // namespace gpu::xetla::kernel
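
A minimal usage sketch follows. It is not part of this header: the attribute struct, tile sizes, the gpu::xetla::bf16/fp16 aliases, the umbrella include path, and the nd_range geometry are illustrative assumptions; only the template parameter list, the arguments_t fields, and the call() signature come from the listing above.

#include <sycl/sycl.hpp>
// XeTLA umbrella header (path assumed).
#include "xetla.hpp"

// Hypothetical attribute type: any struct exposing these four constants
// satisfies data_transformer_attr_.
struct attr_t {
    static constexpr uint32_t wg_tile_m = 256;
    static constexpr uint32_t wg_tile_n = 128;
    static constexpr uint32_t sg_tile_m = 32;
    static constexpr uint32_t sg_tile_n = 16;
};

// bf16 -> fp16 conversion, row-major input, no fp8 scaling (need_fp8_op == 0).
using transformer_t = gpu::xetla::kernel::xetla_data_transformer<
        gpu::xetla::bf16, gpu::xetla::fp16, float, attr_t,
        gpu::xetla::mem_layout::row_major, 0, gpu::xetla::gpu_arch::Xe>;

void run(sycl::queue &q, gpu::xetla::bf16 *in, gpu::xetla::fp16 *out,
        uint32_t m, uint32_t n) {
    // One subgroup per sg_tile, one workgroup per wg_tile (ceil division).
    sycl::range<3> local {1, attr_t::wg_tile_m / attr_t::sg_tile_m,
            attr_t::wg_tile_n / attr_t::sg_tile_n};
    sycl::range<3> global {1,
            ((m + attr_t::wg_tile_m - 1) / attr_t::wg_tile_m) * local[1],
            ((n + attr_t::wg_tile_n - 1) / attr_t::wg_tile_n) * local[2]};
    q.parallel_for(sycl::nd_range<3> {global, local},
            [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL {
                typename transformer_t::arguments_t args;
                args.matrix_m = m;
                args.matrix_n = n;
                args.matrix_in_ld = n;  // packed row-major input
                args.matrix_out_ld = n; // packed row-major output
                args.mat_in_ptr = in;
                args.mat_out_ptr = out;
                args.amax_ptr = nullptr; // unused when need_fp8_op == 0
                args.scale = nullptr;    // unused when need_fp8_op == 0
                args.wg_ld_start_x = item.get_group(2) * attr_t::wg_tile_n;
                args.wg_ld_start_y = item.get_group(1) * attr_t::wg_tile_m;
                args.wg_st_start_x = args.wg_ld_start_x;
                args.wg_st_start_y = args.wg_ld_start_y;
                transformer_t::call(item, &args);
            });
}

Note that when need_fp8_op is non-zero, the fp8 path block-loads simd = 16 scale elements even though only local_scale[0] is used, so scale should point at a buffer of at least 16 dtype_compute values; the workgroup maximum is merged into amax_ptr[0] via atomic fmax (only lane 0 of the 16-lane message is enabled).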