XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
common.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
22#include "common/common.hpp"
23
24namespace gpu::xetla::subgroup {
25
26namespace detail {
27
/// @brief Compute next power of 2 of a constexpr with guaranteed compile-time
/// evaluation.
///
/// Primary template: declared only. One of the two partial specializations
/// below is selected through @p K_gt_eq_N.
/// @tparam N Value to round up.
/// @tparam K Current candidate; doubled on every recursion step.
/// @tparam K_gt_eq_N True once the candidate K is >= N.
template <uint32_t N, uint32_t K, bool K_gt_eq_N>
struct NextPowerOf2;

/// @brief Terminal case: K is already >= N, so K is the result.
template <uint32_t N, uint32_t K>
struct NextPowerOf2<N, K, true> {
    static constexpr uint32_t get() { return K; }
};

/// @brief Recursive case: K < N, so retry with the candidate doubled.
template <uint32_t N, uint32_t K>
struct NextPowerOf2<N, K, false> {
private:
    // K * 2 is computed in 32-bit unsigned arithmetic. If K were 0, or had its
    // top bit set, the doubled candidate would be 0 and could never reach N,
    // so the recursion would not terminate.
    static constexpr bool can_double = (K != 0) && (K <= 0x7FFFFFFFu);

public:
    static_assert(can_double,
            "NextPowerOf2: the next power of 2 does not fit in uint32_t");

    static constexpr uint32_t get() {
        if constexpr (can_double) {
            return NextPowerOf2<N, K * 2, (K * 2 >= N)>::get();
        } else {
            // Only reached when the static_assert above has already failed;
            // stops the recursion so the diagnostic stays short.
            return 0;
        }
    }
};

/// @brief Get the smallest power of 2 that is greater than or equal to N.
/// @tparam N Value to round up.
/// @return The rounded-up value, evaluated at compile time.
template <uint32_t N>
constexpr uint32_t getNextPowerOf2() {
    return NextPowerOf2<N, 1, (1 >= N)>::get();
}

/// @brief Specialization for N == 0, which is defined to yield 0.
template <>
constexpr uint32_t getNextPowerOf2<0>() {
    return 0;
}
74
/// @brief Compile-time greatest common divisor of two unsigned values.
///
/// General step of the recursion: gcd(lhs, rhs) == gcd(rhs, lhs % rhs).
/// @tparam lhs First operand.
/// @tparam rhs Second operand; the specialization below handles rhs == 0.
template <uint32_t lhs, uint32_t rhs>
struct gcd {
private:
    using reduced = gcd<rhs, lhs % rhs>;

public:
    static constexpr uint32_t value = reduced::value;
};

/// @brief End of the recursion: gcd(lhs, 0) == lhs.
template <uint32_t lhs>
struct gcd<lhs, 0> {
    static constexpr uint32_t value = lhs;
};
91
/// @brief Selects what the load/store flavour of process_1d_tail does with
/// each chunk.
enum class process_flag : uint8_t {
    load = 0, ///< Read the chunk from memory into the tile registers.
    store = 1 ///< Write the chunk from the tile registers to memory.
};
93
94template <uint32_t remained_len, uint32_t base_len, process_flag flag,
95 cache_hint L1, cache_hint L2, typename payload_t, typename tile_t>
96__XETLA_API typename std::enable_if_t<base_len == 0> process_1d_tail(
97 [[maybe_unused]] tile_t &tile, [[maybe_unused]] payload_t &payload,
98 [[maybe_unused]] uint32_t offset) {}
99
100template <uint32_t remained_len, uint32_t base_len, process_flag flag,
101 cache_hint L1, cache_hint L2, typename payload_t, typename tile_t>
102__XETLA_API typename std::enable_if_t<base_len != 0
103 && payload_t::memory_space == mem_space::global>
104process_1d_tail(tile_t &tile, payload_t &payload, uint32_t offset) {
105 using dtype = typename payload_t::dtype;
106 using mem_dtype = typename payload_t::mem_dtype;
107 if constexpr (remained_len >= base_len) {
108 uint32_t address_offset = offset * sizeof(dtype);
109 auto reg_sub
110 = tile.reg.xetla_select<base_len * payload_t::scale_factor, 1>(
111 offset);
112 if constexpr (flag == process_flag::load) {
113 reg_sub.xetla_format<mem_dtype>() = xetla_load_global<mem_dtype,
114 base_len, data_size::default_size, L1, L2>(
115 payload.base_ptr, payload.base_offset + address_offset);
116 } else {
117 xetla_store_global<mem_dtype, base_len, data_size::default_size, L1,
118 L2>(payload.base_ptr, payload.base_offset + address_offset,
119 reg_sub.xetla_format<mem_dtype>());
120 }
121 process_1d_tail<remained_len - base_len, (base_len >> 1), flag, L1, L2>(
122 tile, payload, offset + base_len * payload_t::scale_factor);
123 } else {
124 process_1d_tail<remained_len, (base_len >> 1), flag, L1, L2>(
125 tile, payload, offset);
126 }
127}
128
129template <uint32_t remained_len, uint32_t base_len, process_flag flag,
130 cache_hint L1, cache_hint L2, typename payload_t, typename tile_t>
131__XETLA_API typename std::enable_if_t<base_len != 0
132 && payload_t::memory_space == mem_space::local>
133process_1d_tail(tile_t &tile, payload_t &payload, uint32_t offset) {
134 using mem_dtype = typename payload_t::mem_dtype;
135 if constexpr (remained_len >= base_len) {
136 auto reg_sub
137 = tile.reg.xetla_select<base_len * payload_t::scale_factor, 1>(
138 offset);
139 uint32_t address_offset = offset * sizeof(typename tile_t::dtype);
140 if constexpr (flag == process_flag::load) {
141 reg_sub.xetla_format<mem_dtype>() = xetla_load_local<mem_dtype,
142 base_len, data_size::default_size>(
143 payload.address + address_offset);
144 } else {
145 xetla_store_local<mem_dtype, base_len>(
146 payload.address + address_offset,
147 reg_sub.xetla_format<mem_dtype>());
148 }
149 process_1d_tail<remained_len - base_len, (base_len >> 1), flag, L1, L2,
150 payload_t, tile_t>(
151 tile, payload, offset + base_len * payload_t::scale_factor);
152 } else {
153 process_1d_tail<remained_len, (base_len >> 1), flag, L1, L2, payload_t,
154 tile_t>(tile, payload, offset);
155 }
156}
157
158// This will end up with base_len equal to 8 because we had made tile_size_x
159// divisible by 8/16/32, depends on dtype
160template <uint32_t remained_len, uint32_t base_len, cache_hint L1,
161 cache_hint L2, typename payload_t>
162__XETLA_API typename std::enable_if_t<(base_len < 8)> process_1d_tail(
163 payload_t &payload, uint32_t offset) {
164 using dtype = typename payload_t::dtype;
165 using prefetch_dtype = typename payload_t::prefetch_dtype;
166 uint32_t address_offset = offset * sizeof(dtype);
167 constexpr uint32_t prefetch_min_size = 64 / sizeof(prefetch_dtype);
168 if constexpr (remained_len > 0) {
169 xetla_prefetch_global<prefetch_dtype, prefetch_min_size,
171 payload.base_ptr, payload.base_offset + address_offset);
172 }
173}
174
175template <uint32_t remained_len, uint32_t base_len, cache_hint L1,
176 cache_hint L2, typename payload_t>
177__XETLA_API typename std::enable_if_t<base_len >= 8> process_1d_tail(
178 payload_t &payload, uint32_t offset) {
179 using dtype = typename payload_t::dtype;
180 using prefetch_dtype = typename payload_t::prefetch_dtype;
181 if constexpr (remained_len >= base_len) {
182 uint32_t address_offset = offset * sizeof(dtype);
183 xetla_prefetch_global<prefetch_dtype, base_len, data_size::default_size,
184 L1, L2>(payload.base_ptr, payload.base_offset + address_offset);
185 process_1d_tail<remained_len - base_len, (base_len >> 1), L1, L2,
186 payload_t>(
187 payload, offset + base_len * payload_t::scale_factor);
188 } else {
189 process_1d_tail<remained_len, (base_len >> 1), L1, L2, payload_t>(
190 payload, offset);
191 }
192}
193
194template <uint32_t num_tdesc, uint32_t size_x, uint32_t size_y,
195 uint32_t scale_factor, uint8_t arr_len, bool trans>
196__XETLA_API static void reset_tile_desc_core(
198#pragma unroll
199 for (uint32_t j = 0; j < num_tdesc; j++) {
200 constexpr uint8_t block_width
201 = trans ? (size_y / scale_factor) : (size_x / scale_factor);
202 constexpr uint8_t block_height = trans ? size_x : size_y;
203 constexpr uint32_t block_widthx_widthy_arrlen = (block_width - 1)
204 | ((block_height - 1) << 8) | ((arr_len - 1) << 16);
206 payload_row.row(j), block_widthx_widthy_arrlen);
207 }
208}
209
210} // namespace detail
211
212template <typename T_dst, typename T_src>
214 static constexpr bool value = (T_src::block_size_y == T_dst::block_size_y)
215 && (T_src::block_size_x == T_dst::block_size_x)
216 && (T_src::tile_size_y == T_dst::tile_size_y)
217 && (T_src::tile_size_x == T_dst::tile_size_x);
218};
219
220template <typename T_dst, typename T_src>
222 static constexpr bool value
225};
226
227template <typename tile_desc_, mem_space memory_space,
228 mem_layout memory_layout = mem_layout::row_major>
230 static constexpr msg_type value = memory_space == mem_space::global
231 ? (((tile_desc_::tile_size_y == 1)
232 && (memory_layout == mem_layout::row_major))
235 : (((tile_desc_::tile_size_y == 1)
236 && (memory_layout == mem_layout::row_major))
238 : msg_type::scatter);
239};
240
241template <typename tile_desc_, mem_space memory_space>
243
244template <typename dtype, uint32_t tile_size_x, uint32_t tile_size_y,
245 gpu_arch arch_tag, mem_layout mem_layout_ = mem_layout::row_major,
246 reg_layout reg_layout_ = reg_layout::tiled>
248
249template <typename dtype, uint32_t tile_size_x, uint32_t tile_size_y>
250struct get_load_block_size_auto<dtype, tile_size_x, tile_size_y, gpu_arch::Xe,
252private:
253 using load_store_attr = arch_attr_t<gpu_arch::Xe>::template load_store_attr<
255 static constexpr uint32_t max_load_height_in_elem
256 = load_store_attr::max_load_height_in_elem;
257 static constexpr uint32_t max_load_width_in_bytes
258 = load_store_attr::max_load_width_in_bytes;
259 static constexpr uint32_t max_load_width_in_elem
260 = max_load_width_in_bytes / sizeof(dtype);
261
262public:
263 // block_size_x should be power of 2 and tile_size_x should be divided by block_size_x
264 static constexpr uint32_t block_size_x
266 static constexpr uint32_t block_size_y
267 = max_load_height_in_elem > tile_size_y ? tile_size_y
268 : max_load_height_in_elem;
269};
270
271template <typename dtype, uint32_t tile_size_x, uint32_t tile_size_y,
272 gpu_arch arch_tag, mem_layout mem_layout_ = mem_layout::row_major,
273 reg_layout reg_layout_ = reg_layout::tiled>
275
276template <typename dtype, uint32_t tile_size_x, uint32_t tile_size_y>
277struct get_store_block_size_auto<dtype, tile_size_x, tile_size_y, gpu_arch::Xe,
279private:
280 using load_store_attr = arch_attr_t<gpu_arch::Xe>::template load_store_attr<
282 static constexpr uint32_t max_store_height_in_elem
283 = load_store_attr::max_store_height_in_elem;
284 static constexpr uint32_t max_store_width_in_bytes
285 = load_store_attr::max_store_width_in_bytes;
286 static constexpr uint32_t max_store_width_in_elem
287 = max_store_width_in_bytes / sizeof(dtype);
288
289public:
290 // block_size_x should be power of 2 and tile_size_x should be divided by block_size_x
291 static constexpr uint32_t block_size_x
293 static constexpr uint32_t block_size_y
294 = max_store_height_in_elem > tile_size_y ? tile_size_y
295 : max_store_height_in_elem;
296};
297
/// @brief Type tag selecting the "global atomic oob check on" behavior.
/// Derives from std::true_type, so its compile-time `value` is true.
struct global_atomic_oob_check_on_tag : public std::true_type {};
300
/// @brief Type tag selecting the "global atomic oob check off" behavior.
/// Derives from std::false_type, so its compile-time `value` is false.
struct global_atomic_oob_check_off_tag : public std::false_type {};
303
304} // namespace gpu::xetla::subgroup
#define __XETLA_API
Definition common.hpp:43
Workaround for ESIMD matrix(2D) ref type.
Definition base_types.hpp:202
#define __REF__
Workaround for ESIMD reference usage.
Definition base_types.hpp:177
__XETLA_API void xetla_prefetch_global(Ty *p, xetla_vector< uint32_t, N > offsets, xetla_mask< N > pred=1)
Stateless scattered prefetch.
Definition memory.hpp:187
__XETLA_API xetla_vector< Ty, N *NElts > xetla_load_global(Ty *p, xetla_vector< Toffset, N > offsets, xetla_mask< N > pred=1)
Stateless scattered load.
Definition memory.hpp:245
__XETLA_API xetla_vector< Ty, N *NElts > xetla_load_local(xetla_vector< uint32_t, N > offsets, xetla_mask< N > pred=1)
SLM scattered load.
Definition memory.hpp:464
__XETLA_API void xetla_store_global(Ty *p, xetla_vector< Toffset, N > offsets, xetla_vector< Ty, N *NElts > vals, xetla_mask< N > pred=1)
Stateless scattered store.
Definition memory.hpp:316
__XETLA_API void xetla_set_block_widthx_widthy_arrlen(xetla_tdescriptor_ref desc, uint32_t block_widthx_widthy_arrlen)
Definition tensor_descriptor.hpp:79
__XETLA_API uint32_t uint32_t size_y
Definition common.hpp:194
process_flag
Definition common.hpp:92
constexpr uint32_t getNextPowerOf2()
Get the Next Power Of2 object.
Definition common.hpp:62
__XETLA_API uint32_t uint32_t uint32_t uint8_t arr_len
Definition common.hpp:195
__XETLA_API uint32_t uint32_t uint32_t scale_factor
Definition common.hpp:195
__XETLA_API std::enable_if_t< base_len==0 > process_1d_tail(tile_t &tile, payload_t &payload, uint32_t offset)
Definition common.hpp:96
constexpr uint32_t getNextPowerOf2< 0 >()
Get the Next Power Of2<0> object.
Definition common.hpp:71
__XETLA_API uint32_t size_x
Definition common.hpp:194
Definition limitation.hpp:457
constexpr msg_type msg_type_v
Definition common.hpp:242
cache_hint
L1 or L2 cache hint kinds.
Definition common.hpp:89
reg_layout
tile layout in register linear: linear layout with one tile tiled: 2d block stacked in raster order v...
Definition common.hpp:209
@ tile
flush out to the local scope
mem_space
Definition common.hpp:77
gpu_arch
Definition common.hpp:73
msg_type
Definition common.hpp:78
mem_layout
Definition common.hpp:76
Definition arch_config.hpp:72
Used to check if the type is floating_point.
Definition base_types.hpp:75
Used to check if the type is floating_point.
Definition base_types.hpp:86
static constexpr uint32_t get()
Definition common.hpp:52
static constexpr uint32_t get()
Definition common.hpp:43
Compute next power of 2 of a constexpr with guaranteed compile-time evaluation.
Definition common.hpp:35
Definition common.hpp:80
static constexpr uint32_t value
Definition common.hpp:81
static constexpr bool value
Definition common.hpp:223
Definition common.hpp:213
static constexpr bool value
Definition common.hpp:214
Definition common.hpp:229
static constexpr msg_type value
Definition common.hpp:230
A struct that contains a register file.
Definition api.hpp:99
dtype_ dtype
Definition api.hpp:100