XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
misc.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
23
/// @brief Divides n by d and rounds the result up to the next integer.
/// Computed as quotient plus a remainder check rather than (n + d - 1) / d,
/// so the intermediate sum cannot wrap around for n close to UINT32_MAX.
/// @param n Dividend.
/// @param d Divisor; must be non-zero.
/// @return The smallest integer q such that q * d >= n.
__XETLA_API constexpr uint32_t div_round_up(uint32_t n, uint32_t d) {
    return n / d + (n % d != 0 ? 1u : 0u);
}
27
/// @brief Floored integer division: rounds the quotient towards negative
/// infinity, e.g. -2 / 3 ~ -0.666 -> -1.
/// Built from the truncated quotient and remainder so that no intermediate
/// value can overflow (the only remaining undefined case is INT_MIN / -1).
/// @param n Dividend.
/// @param d Divisor; must be non-zero.
/// @return floor(n / d).
__XETLA_API constexpr int div_round_down(int n, int d) {
    const int quotient = n / d;
    const int remainder = n % d;
    // Truncation rounded towards zero; that differs from the floor only when
    // the division is inexact and the true quotient is negative, i.e. when
    // the remainder and the divisor have opposite signs.
    const bool rounded_up = (remainder != 0) && ((remainder < 0) != (d < 0));
    return rounded_up ? quotient - 1 : quotient;
}
34
/// @brief Calculates the modulo based on the definition that uses floored
/// division, so a non-zero result has the same sign as d.
/// The truncated remainder is only shifted by d when its sign differs from
/// d's; in that case the two operands have opposite signs, so the addition
/// cannot overflow.
/// @param n Dividend.
/// @param d Divisor; must be non-zero.
/// @return n - d * floor(n / d).
__XETLA_API constexpr int modulo(int n, int d) {
    const int remainder = n % d;
    const bool wrong_sign = (remainder != 0) && ((remainder < 0) != (d < 0));
    return wrong_sign ? remainder + d : remainder;
}
40
/// @brief Pads the given allocation size up to the nearest multiple of the
/// cacheline size (256 here).
/// @param size Allocation size to pad.
/// @return The padded size. The arithmetic is done in size_t and the result
/// is then truncated to 32 bits, so sizes whose padded value exceeds
/// UINT32_MAX do not round-trip.
__XETLA_API constexpr uint32_t cacheline_align_up(size_t size) {
    constexpr size_t cacheline_size = 256;
    const size_t padded
            = (size + cacheline_size - 1) / cacheline_size * cacheline_size;
    // The return type is narrower than size_t; make the truncation explicit.
    return static_cast<uint32_t>(padded);
}
47
48namespace gpu::xetla {
49
52
58 xetla_vector<uint32_t, 4> time_stamp = 0;
59 return time_stamp;
60}
61
72template <typename Ty, int N>
74 xetla_vector<Ty, N> tmp(InitVal, Step);
75 return tmp;
76}
77
78template <uint32_t N>
81 tmp = sycl::ext::intel::esimd::unpack_mask<N>(mask_val);
82 return tmp;
83}
84
85template <typename dtype_acc, uint32_t N, uint32_t num_flag = 4,
86 typename dtype_mask = uint8_t>
88 xetla_vector<dtype_mask, N> mask, dtype_acc scale) {
89 xetla_vector<dtype_acc, N> out = in * scale;
90 constexpr uint32_t unroll_size = num_flag * 16;
91 SW_BARRIER();
92#pragma unroll
93 for (uint32_t i = 0; i < N / unroll_size; i++) {
95 = mask.xetla_select<unroll_size, 1>(i * unroll_size) > 0;
96 out.xetla_select<unroll_size, 1>(i * unroll_size)
97 .xetla_merge(0, mask_flag);
98 }
99 if constexpr (N % unroll_size != 0) {
100 constexpr uint32_t remain_len = N % unroll_size;
101 constexpr uint32_t remain_start = N / unroll_size * unroll_size;
102 xetla_mask<remain_len> mask_flag
103 = mask.xetla_select<remain_len, 1>(remain_start) > 0;
104 out.xetla_select<remain_len, 1>(remain_start).xetla_merge(0, mask_flag);
105 }
106 return out;
107}
108
109template <reduce_op reduce_kind, typename dtype, int size>
110__XETLA_API typename std::enable_if_t<reduce_kind == reduce_op::sum,
111 xetla_vector<dtype, size>>
113 return a + b;
114}
115
116template <reduce_op reduce_kind, typename dtype, int size>
117__XETLA_API typename std::enable_if_t<reduce_kind == reduce_op::prod,
118 xetla_vector<dtype, size>>
120 return a * b;
121}
122
123template <reduce_op reduce_kind, typename dtype, int size>
124__XETLA_API typename std::enable_if_t<reduce_kind == reduce_op::max,
125 xetla_vector<dtype, size>>
128 xetla_mask<size> mask = a > b;
129 out.xetla_merge(a, b, mask);
130 return out;
131}
132
133template <reduce_op reduce_kind, typename dtype, int size>
134__XETLA_API typename std::enable_if_t<reduce_kind == reduce_op::min,
135 xetla_vector<dtype, size>>
138 xetla_mask<size> mask = a < b;
139 out.xetla_merge(a, b, mask);
140 return out;
141}
142
143template <reduce_op reduce_kind, typename dtype, int N_x, int N_y>
144__XETLA_API typename std::enable_if_t<N_y == 1, xetla_vector<dtype, N_x>>
146 return in;
147}
148template <reduce_op reduce_kind, typename dtype, int N_x, int N_y>
149__XETLA_API typename std::enable_if_t<(N_y > 1), xetla_vector<dtype, N_x>>
151 static_assert(((N_y) & (N_y - 1)) == 0, "N_y should be power of 2");
152 xetla_vector<dtype, N_x * N_y / 2> temp;
153 temp = reduce_helper<reduce_kind, dtype, N_x * N_y / 2>(
154 in.xetla_select<N_x * N_y / 2, 1>(0),
155 in.xetla_select<N_x * N_y / 2, 1>(N_x * N_y / 2));
156
157 return recur_row_reduce<reduce_kind, dtype, N_x, N_y / 2>(temp);
158}
159
160template <reduce_op reduce_kind, typename dtype, int N_x, int N_y>
161__XETLA_API typename std::enable_if_t<N_x == 1, xetla_vector<dtype, N_y>>
163 return in;
164}
165template <reduce_op reduce_kind, typename dtype, int N_x, int N_y>
166__XETLA_API typename std::enable_if_t<(N_x > 1), xetla_vector<dtype, N_y>>
168 static_assert(((N_x) & (N_x - 1)) == 0, "N_x should be power of 2");
169 xetla_vector<dtype, N_x * N_y / 2> temp;
170 auto in_2d = in.xetla_format<dtype, N_y, N_x>();
171 temp = reduce_helper<reduce_kind, dtype, N_y * N_x / 2>(
172 in_2d.xetla_select<N_y, 1, N_x / 2, 1>(0, 0),
173 in_2d.xetla_select<N_y, 1, N_x / 2, 1>(0, N_x / 2));
174
175 return recur_col_reduce<reduce_kind, dtype, N_x / 2, N_y>(temp);
176}
177
180__XETLA_API uint32_t get_2d_group_linear_id(sycl::nd_item<3> &item) {
181 return item.get_group(2) + item.get_group(1) * item.get_group_range(2);
182}
184
185} // namespace gpu::xetla
#define SW_BARRIER()
SW_BARRIER, insert software scheduling barrier, for better code control.
Definition common.hpp:227
#define __XETLA_API
Definition common.hpp:43
C++ API.
#define xetla_merge
xetla merge.
Definition base_ops.hpp:60
__ESIMD_NS::simd_mask< N > xetla_mask_int
wrapper for xetla_mask_int.
Definition base_types.hpp:172
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
wrapper for xetla_mask.
Definition base_types.hpp:165
__XETLA_API std::enable_if_t< N_x==1, xetla_vector< dtype, N_y > > recur_col_reduce(xetla_vector< dtype, N_y > in)
Definition misc.hpp:162
__XETLA_API xetla_vector< dtype_acc, N > drop_out(xetla_vector< dtype_acc, N > in, xetla_vector< dtype_mask, N > mask, dtype_acc scale)
Definition misc.hpp:87
__XETLA_API xetla_vector< uint32_t, 4 > get_time_stamp()
Returns time stamp.
Definition misc.hpp:57
__XETLA_API std::enable_if_t< reduce_kind==reduce_op::sum, xetla_vector< dtype, size > > reduce_helper(xetla_vector< dtype, size > a, xetla_vector< dtype, size > b)
Definition misc.hpp:112
__XETLA_API std::enable_if_t< N_y==1, xetla_vector< dtype, N_x > > recur_row_reduce(xetla_vector< dtype, N_x > in)
Definition misc.hpp:145
__XETLA_API xetla_mask_int< N > xetla_mask_int_gen(uint32_t mask_val)
Definition misc.hpp:79
__XETLA_API xetla_vector< Ty, N > xetla_vector_gen(int InitVal, int Step)
xetla_vector generation.
Definition misc.hpp:73
__XETLA_API uint32_t get_2d_group_linear_id(sycl::nd_item< 3 > &item)
get linear group id of the last two dimensions.
Definition misc.hpp:180
Definition arch_config.hpp:24
__XETLA_API constexpr int div_round_down(int n, int d)
Definition misc.hpp:30
__XETLA_API constexpr int modulo(int n, int d)
Definition misc.hpp:37
__XETLA_API constexpr uint32_t div_round_up(uint32_t n, uint32_t d)
Definition misc.hpp:24
__XETLA_API constexpr uint32_t cacheline_align_up(size_t size)
Definition misc.hpp:42