DPC++ Runtime
Runtime libraries for oneAPI DPC++
imf_half_trivial.hpp
Go to the documentation of this file.
1 //==------------- imf_half_trivial.hpp - trivial half utils ----------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // Trival half util functions.
9 //===----------------------------------------------------------------------===//
10 
11 #pragma once
12 
13 #include <sycl/builtins.hpp>
14 #include <sycl/half_type.hpp>
15 
16 namespace sycl {
17 inline namespace _V1 {
18 namespace ext::intel::math {
19 sycl::half hadd(sycl::half x, sycl::half y) { return x + y; }
20 
22  return sycl::clamp((x + y), sycl::half(0.f), sycl::half(1.0f));
23 }
24 
26  return sycl::fma(x, y, z);
27 }
28 
30  return sycl::clamp(sycl::fma(x, y, z), sycl::half(0.f), sycl::half(1.0f));
31 }
32 
33 sycl::half hmul(sycl::half x, sycl::half y) { return x * y; }
34 
36  return sycl::clamp((x * y), sycl::half(0.f), sycl::half(1.0f));
37 }
38 
39 sycl::half hneg(sycl::half x) { return -x; }
40 
41 sycl::half hsub(sycl::half x, sycl::half y) { return x - y; }
42 
44  return sycl::clamp((x - y), sycl::half(0.f), sycl::half(1.0f));
45 }
46 
47 sycl::half hdiv(sycl::half x, sycl::half y) { return x / y; }
48 
49 bool heq(sycl::half x, sycl::half y) { return x == y; }
50 
52  if (sycl::isnan(x) || sycl::isnan(y))
53  return true;
54  else
55  return x == y;
56 }
57 
58 bool hge(sycl::half x, sycl::half y) { return x >= y; }
59 
61  if (sycl::isnan(x) || sycl::isnan(y))
62  return true;
63  else
64  return x >= y;
65 }
66 
67 bool hgt(sycl::half x, sycl::half y) { return x > y; }
68 
70  if (sycl::isnan(x) || sycl::isnan(y))
71  return true;
72  else
73  return x > y;
74 }
75 
76 bool hle(sycl::half x, sycl::half y) { return x <= y; }
77 
79  if (sycl::isnan(x) || sycl::isnan(y))
80  return true;
81  else
82  return x <= y;
83 }
84 
85 bool hlt(sycl::half x, sycl::half y) { return x < y; }
86 
88  if (sycl::isnan(x) || sycl::isnan(y))
89  return true;
90  return x < y;
91 }
92 
94  if (sycl::isnan(x) || sycl::isnan(y))
95  return false;
96  return x != y;
97 }
98 
100  if (sycl::isnan(x) || sycl::isnan(y))
101  return true;
102  else
103  return x != y;
104 }
105 
106 bool hisinf(sycl::half x) { return sycl::isinf(x); }
107 bool hisnan(sycl::half y) { return sycl::isnan(y); }
108 
109 sycl::half2 hadd2(sycl::half2 x, sycl::half2 y) { return x + y; }
110 
111 sycl::half2 hadd2_sat(sycl::half2 x, sycl::half2 y) {
112  return sycl::clamp((x + y), sycl::half2{0.f, 0.f}, sycl::half2{1.f, 1.f});
113 }
114 
115 sycl::half2 hfma2(sycl::half2 x, sycl::half2 y, sycl::half2 z) {
116  return sycl::fma(x, y, z);
117 }
118 
119 sycl::half2 hfma2_sat(sycl::half2 x, sycl::half2 y, sycl::half2 z) {
120  return sycl::clamp(sycl::fma(x, y, z), sycl::half2{0.f, 0.f},
121  sycl::half2{1.f, 1.f});
122 }
123 
124 sycl::half2 hmul2(sycl::half2 x, sycl::half2 y) { return x * y; }
125 
126 sycl::half2 hmul2_sat(sycl::half2 x, sycl::half2 y) {
127  return sycl::clamp((x * y), sycl::half2{0.f, 0.f}, sycl::half2{1.f, 1.f});
128 }
129 
130 sycl::half2 h2div(sycl::half2 x, sycl::half2 y) { return x / y; }
131 
132 sycl::half2 hneg2(sycl::half2 x) { return -x; }
133 
134 sycl::half2 hsub2(sycl::half2 x, sycl::half2 y) { return x - y; }
135 
136 sycl::half2 hsub2_sat(sycl::half2 x, sycl::half2 y) {
137  return sycl::clamp((x - y), sycl::half2{0.f, 0.f}, sycl::half2{1.f, 1.f});
138 }
139 
140 bool hbeq2(sycl::half2 x, sycl::half2 y) {
141  return heq(x.s0(), y.s0()) && heq(x.s1(), y.s1());
142 }
143 
144 bool hbequ2(sycl::half2 x, sycl::half2 y) {
145  return hequ(x.s0(), y.s0()) && hequ(x.s1(), y.s1());
146 }
147 
148 bool hbge2(sycl::half2 x, sycl::half2 y) {
149  return hge(x.s0(), y.s0()) && hge(x.s1(), y.s1());
150 }
151 
152 bool hbgeu2(sycl::half2 x, sycl::half2 y) {
153  return hgeu(x.s0(), y.s0()) && hgeu(x.s1(), y.s1());
154 }
155 
156 bool hbgt2(sycl::half2 x, sycl::half2 y) {
157  return hgt(x.s0(), y.s0()) && hgt(x.s1(), y.s1());
158 }
159 
160 bool hbgtu2(sycl::half2 x, sycl::half2 y) {
161  return hgtu(x.s0(), y.s0()) && hgtu(x.s1(), y.s1());
162 }
163 
164 bool hble2(sycl::half2 x, sycl::half2 y) {
165  return hle(x.s0(), y.s0()) && hle(x.s1(), y.s1());
166 }
167 
168 bool hbleu2(sycl::half2 x, sycl::half2 y) {
169  return hleu(x.s0(), y.s0()) && hleu(x.s1(), y.s1());
170 }
171 
172 bool hblt2(sycl::half2 x, sycl::half2 y) {
173  return hlt(x.s0(), y.s0()) && hlt(x.s1(), y.s1());
174 }
175 
176 bool hbltu2(sycl::half2 x, sycl::half2 y) {
177  return hltu(x.s0(), y.s0()) && hltu(x.s1(), y.s1());
178 }
179 
180 bool hbne2(sycl::half2 x, sycl::half2 y) {
181  return hne(x.s0(), y.s0()) && hne(x.s1(), y.s1());
182 }
183 
184 bool hbneu2(sycl::half2 x, sycl::half2 y) {
185  return hneu(x.s0(), y.s0()) && hneu(x.s1(), y.s1());
186 }
187 
188 sycl::half2 heq2(sycl::half2 x, sycl::half2 y) {
189  return sycl::half2{(heq(x.s0(), y.s0()) ? 1.0f : 0.f),
190  (heq(x.s1(), y.s1()) ? 1.0f : 0.f)};
191 }
192 
193 sycl::half2 hequ2(sycl::half2 x, sycl::half2 y) {
194  return sycl::half2{(hequ(x.s0(), y.s0()) ? 1.0f : 0.f),
195  (hequ(x.s1(), y.s1()) ? 1.0f : 0.f)};
196 }
197 
198 sycl::half2 hge2(sycl::half2 x, sycl::half2 y) {
199  return sycl::half2{(hge(x.s0(), y.s0()) ? 1.0f : 0.f),
200  (hge(x.s1(), y.s1()) ? 1.0f : 0.f)};
201 }
202 
203 sycl::half2 hgeu2(sycl::half2 x, sycl::half2 y) {
204  return sycl::half2{(hgeu(x.s0(), y.s0()) ? 1.0f : 0.f),
205  (hgeu(x.s1(), y.s1()) ? 1.0f : 0.f)};
206 }
207 
208 sycl::half2 hgt2(sycl::half2 x, sycl::half2 y) {
209  return sycl::half2{(hgt(x.s0(), y.s0()) ? 1.0f : 0.f),
210  (hgt(x.s1(), y.s1()) ? 1.0f : 0.f)};
211 }
212 
213 sycl::half2 hgtu2(sycl::half2 x, sycl::half2 y) {
214  return sycl::half2{(hgtu(x.s0(), y.s0()) ? 1.0f : 0.f),
215  (hgtu(x.s1(), y.s1()) ? 1.0f : 0.f)};
216 }
217 
218 sycl::half2 hle2(sycl::half2 x, sycl::half2 y) {
219  return sycl::half2{(hle(x.s0(), y.s0()) ? 1.0f : 0.f),
220  (hle(x.s1(), y.s1()) ? 1.0f : 0.f)};
221 }
222 
223 sycl::half2 hleu2(sycl::half2 x, sycl::half2 y) {
224  return sycl::half2{(hleu(x.s0(), y.s0()) ? 1.0f : 0.f),
225  (hleu(x.s1(), y.s1()) ? 1.0f : 0.f)};
226 }
227 
228 sycl::half2 hlt2(sycl::half2 x, sycl::half2 y) {
229  return sycl::half2{(hlt(x.s0(), y.s0()) ? 1.0f : 0.f),
230  (hlt(x.s1(), y.s1()) ? 1.0f : 0.f)};
231 }
232 
233 sycl::half2 hltu2(sycl::half2 x, sycl::half2 y) {
234  return sycl::half2{(hltu(x.s0(), y.s0()) ? 1.0f : 0.f),
235  (hltu(x.s1(), y.s1()) ? 1.0f : 0.f)};
236 }
237 
238 sycl::half2 hisnan2(sycl::half2 x) {
239  return sycl::half2{(hisnan(x.s0()) ? 1.0f : 0.f),
240  (hisnan(x.s1()) ? 1.0f : 0.f)};
241 }
242 
243 sycl::half2 hne2(sycl::half2 x, sycl::half2 y) {
244  return sycl::half2{(hne(x.s0(), y.s0()) ? 1.0f : 0.f),
245  (hne(x.s1(), y.s1()) ? 1.0f : 0.f)};
246 }
247 
248 sycl::half2 hneu2(sycl::half2 x, sycl::half2 y) {
249  return sycl::half2{(hneu(x.s0(), y.s0()) ? 1.0f : 0.f),
250  (hneu(x.s1(), y.s1()) ? 1.0f : 0.f)};
251 }
252 
253 sycl::half hmax(sycl::half x, sycl::half y) { return sycl::fmax(x, y); }
254 
256  if (hisnan(x) || hisnan(y))
257  return sycl::half(std::numeric_limits<float>::quiet_NaN());
258  else
259  return sycl::fmax(x, y);
260 }
261 
262 sycl::half2 hmax2(sycl::half2 x, sycl::half2 y) {
263  return sycl::half2{hmax(x.s0(), y.s0()), hmax(x.s1(), y.s1())};
264 }
265 
266 sycl::half2 hmax2_nan(sycl::half2 x, sycl::half2 y) {
267  return sycl::half2{hmax_nan(x.s0(), y.s0()), hmax_nan(x.s1(), y.s1())};
268 }
269 
270 sycl::half hmin(sycl::half x, sycl::half y) { return sycl::fmin(x, y); }
271 
273  if (hisnan(x) || hisnan(y))
274  return sycl::half(std::numeric_limits<float>::quiet_NaN());
275  else
276  return sycl::fmin(x, y);
277 }
278 
279 sycl::half2 hmin2(sycl::half2 x, sycl::half2 y) {
280  return sycl::half2{hmin(x.s0(), y.s0()), hmin(x.s1(), y.s1())};
281 }
282 
283 sycl::half2 hmin2_nan(sycl::half2 x, sycl::half2 y) {
284  return sycl::half2{hmin_nan(x.s0(), y.s0()), hmin_nan(x.s1(), y.s1())};
285 }
286 
287 sycl::half2 hcmadd(sycl::half2 x, sycl::half2 y, sycl::half2 z) {
288  return sycl::half2{x.s0() * y.s0() - x.s1() * y.s1() + z.s0(),
289  x.s0() * y.s1() + x.s1() * y.s0() + z.s1()};
290 }
291 
293  sycl::half r = sycl::fma(x, y, z);
294  if (!hisnan(r)) {
295  if (r < 0.f)
296  return sycl::half{0.f};
297  else
298  return r;
299  }
300  return r;
301 }
302 
303 sycl::half2 hfma2_relu(sycl::half2 x, sycl::half2 y, sycl::half2 z) {
304  sycl::half2 r = sycl::fma(x, y, z);
305  if (!hisnan(r.s0()) && r.s0() < 0.f)
306  r.s0() = 0.f;
307  if (!hisnan(r.s1()) && r.s1() < 0.f)
308  r.s1() = 0.f;
309  return r;
310 }
311 
313 
314 sycl::half2 habs2(sycl::half2 x) { return sycl::fabs(x); }
315 } // namespace ext::intel::math
316 } // namespace _V1
317 } // namespace sycl
sycl::half habs(sycl::half x)
sycl::half2 hfma2_relu(sycl::half2 x, sycl::half2 y, sycl::half2 z)
sycl::half hfma_sat(sycl::half x, sycl::half y, sycl::half z)
bool hbleu2(sycl::half2 x, sycl::half2 y)
bool hbgtu2(sycl::half2 x, sycl::half2 y)
sycl::half2 hlt2(sycl::half2 x, sycl::half2 y)
bool hequ(sycl::half x, sycl::half y)
sycl::half hfma(sycl::half x, sycl::half y, sycl::half z)
sycl::half2 hne2(sycl::half2 x, sycl::half2 y)
bool hleu(sycl::half x, sycl::half y)
sycl::half hmul(sycl::half x, sycl::half y)
sycl::half hmul_sat(sycl::half x, sycl::half y)
sycl::half2 hmin2(sycl::half2 x, sycl::half2 y)
sycl::half2 hadd2(sycl::half2 x, sycl::half2 y)
sycl::half2 hgeu2(sycl::half2 x, sycl::half2 y)
sycl::half2 hmax2(sycl::half2 x, sycl::half2 y)
bool hble2(sycl::half2 x, sycl::half2 y)
sycl::half2 hisnan2(sycl::half2 x)
bool hbneu2(sycl::half2 x, sycl::half2 y)
bool hbgt2(sycl::half2 x, sycl::half2 y)
sycl::half2 hsub2(sycl::half2 x, sycl::half2 y)
sycl::half hsub_sat(sycl::half x, sycl::half y)
bool hbltu2(sycl::half2 x, sycl::half2 y)
sycl::half2 hfma2(sycl::half2 x, sycl::half2 y, sycl::half2 z)
sycl::half2 hmin2_nan(sycl::half2 x, sycl::half2 y)
bool hgt(sycl::half x, sycl::half y)
sycl::half2 hneg2(sycl::half2 x)
sycl::half hmax_nan(sycl::half x, sycl::half y)
sycl::half hsub(sycl::half x, sycl::half y)
sycl::half2 hfma2_sat(sycl::half2 x, sycl::half2 y, sycl::half2 z)
sycl::half2 hmul2(sycl::half2 x, sycl::half2 y)
sycl::half2 hleu2(sycl::half2 x, sycl::half2 y)
sycl::half hmax(sycl::half x, sycl::half y)
bool hbequ2(sycl::half2 x, sycl::half2 y)
bool heq(sycl::half x, sycl::half y)
sycl::half2 hle2(sycl::half2 x, sycl::half2 y)
bool hltu(sycl::half x, sycl::half y)
sycl::half2 hequ2(sycl::half2 x, sycl::half2 y)
sycl::half2 hgtu2(sycl::half2 x, sycl::half2 y)
sycl::half hadd_sat(sycl::half x, sycl::half y)
sycl::half2 hltu2(sycl::half2 x, sycl::half2 y)
sycl::half2 hgt2(sycl::half2 x, sycl::half2 y)
sycl::half2 hsub2_sat(sycl::half2 x, sycl::half2 y)
bool hle(sycl::half x, sycl::half y)
sycl::half hneg(sycl::half x)
bool hbge2(sycl::half2 x, sycl::half2 y)
sycl::half hfma_relu(sycl::half x, sycl::half y, sycl::half z)
sycl::half2 hmax2_nan(sycl::half2 x, sycl::half2 y)
bool hgtu(sycl::half x, sycl::half y)
sycl::half2 hge2(sycl::half2 x, sycl::half2 y)
sycl::half2 h2div(sycl::half2 x, sycl::half2 y)
bool hlt(sycl::half x, sycl::half y)
bool hgeu(sycl::half x, sycl::half y)
sycl::half hmin_nan(sycl::half x, sycl::half y)
sycl::half2 hmul2_sat(sycl::half2 x, sycl::half2 y)
bool hblt2(sycl::half2 x, sycl::half2 y)
bool hbeq2(sycl::half2 x, sycl::half2 y)
sycl::half hadd(sycl::half x, sycl::half y)
sycl::half2 hadd2_sat(sycl::half2 x, sycl::half2 y)
sycl::half hmin(sycl::half x, sycl::half y)
bool hbne2(sycl::half2 x, sycl::half2 y)
sycl::half2 hcmadd(sycl::half2 x, sycl::half2 y, sycl::half2 z)
sycl::half hdiv(sycl::half x, sycl::half y)
sycl::half2 heq2(sycl::half2 x, sycl::half2 y)
sycl::half2 hneu2(sycl::half2 x, sycl::half2 y)
bool hne(sycl::half x, sycl::half y)
sycl::half2 habs2(sycl::half2 x)
bool hge(sycl::half x, sycl::half y)
bool hneu(sycl::half x, sycl::half y)
bool hbgeu2(sycl::half2 x, sycl::half2 y)
std::enable_if_t< std::is_same_v< T, bfloat16 >, bool > isnan(T x)
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fabs(T x)
auto auto autodecltype(x) z
sycl::detail::half_impl::half half
Definition: aliases.hpp:101
autodecltype(x) x
Definition: access.hpp:18