DPC++ Runtime
Runtime libraries for oneAPI DPC++
imf_half_trivial.hpp
Go to the documentation of this file.
1 //==------------- imf_half_trivial.hpp - trivial half utils ----------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // Trival half util functions.
9 //===----------------------------------------------------------------------===//
10 
11 #pragma once
12 #include <sycl/half_type.hpp>
13 
14 namespace sycl {
16 namespace ext::intel::math {
17 sycl::half hadd(sycl::half x, sycl::half y) { return x + y; }
18 
20  return sycl::clamp((x + y), sycl::half(0.f), sycl::half(1.0f));
21 }
22 
24  return sycl::fma(x, y, z);
25 }
26 
28  return sycl::clamp(sycl::fma(x, y, z), sycl::half(0.f), sycl::half(1.0f));
29 }
30 
31 sycl::half hmul(sycl::half x, sycl::half y) { return x * y; }
32 
34  return sycl::clamp((x * y), sycl::half(0.f), sycl::half(1.0f));
35 }
36 
37 sycl::half hneg(sycl::half x) { return -x; }
38 
39 sycl::half hsub(sycl::half x, sycl::half y) { return x - y; }
40 
42  return sycl::clamp((x - y), sycl::half(0.f), sycl::half(1.0f));
43 }
44 
45 sycl::half hdiv(sycl::half x, sycl::half y) { return x / y; }
46 
47 bool heq(sycl::half x, sycl::half y) { return x == y; }
48 
50  if (sycl::isnan(x) || sycl::isnan(y))
51  return true;
52  else
53  return x == y;
54 }
55 
56 bool hge(sycl::half x, sycl::half y) { return x >= y; }
57 
59  if (sycl::isnan(x) || sycl::isnan(y))
60  return true;
61  else
62  return x >= y;
63 }
64 
65 bool hgt(sycl::half x, sycl::half y) { return x > y; }
66 
68  if (sycl::isnan(x) || sycl::isnan(y))
69  return true;
70  else
71  return x > y;
72 }
73 
74 bool hle(sycl::half x, sycl::half y) { return x <= y; }
75 
77  if (sycl::isnan(x) || sycl::isnan(y))
78  return true;
79  else
80  return x <= y;
81 }
82 
83 bool hlt(sycl::half x, sycl::half y) { return x < y; }
84 
86  if (sycl::isnan(x) || sycl::isnan(y))
87  return true;
88  return x < y;
89 }
90 
92  if (sycl::isnan(x) || sycl::isnan(y))
93  return false;
94  return x != y;
95 }
96 
98  if (sycl::isnan(x) || sycl::isnan(y))
99  return true;
100  else
101  return x != y;
102 }
103 
104 bool hisinf(sycl::half x) { return sycl::isinf(x); }
105 bool hisnan(sycl::half y) { return sycl::isnan(y); }
106 
107 sycl::half2 hadd2(sycl::half2 x, sycl::half2 y) { return x + y; }
108 
109 sycl::half2 hadd2_sat(sycl::half2 x, sycl::half2 y) {
110  return sycl::clamp((x + y), sycl::half2{0.f, 0.f}, sycl::half2{1.f, 1.f});
111 }
112 
113 sycl::half2 hfma2(sycl::half2 x, sycl::half2 y, sycl::half2 z) {
114  return sycl::fma(x, y, z);
115 }
116 
117 sycl::half2 hfma2_sat(sycl::half2 x, sycl::half2 y, sycl::half2 z) {
118  return sycl::clamp(sycl::fma(x, y, z), sycl::half2{0.f, 0.f},
119  sycl::half2{1.f, 1.f});
120 }
121 
122 sycl::half2 hmul2(sycl::half2 x, sycl::half2 y) { return x * y; }
123 
124 sycl::half2 hmul2_sat(sycl::half2 x, sycl::half2 y) {
125  return sycl::clamp((x * y), sycl::half2{0.f, 0.f}, sycl::half2{1.f, 1.f});
126 }
127 
128 sycl::half2 h2div(sycl::half2 x, sycl::half2 y) { return x / y; }
129 
130 sycl::half2 hneg2(sycl::half2 x) { return -x; }
131 
132 sycl::half2 hsub2(sycl::half2 x, sycl::half2 y) { return x - y; }
133 
134 sycl::half2 hsub2_sat(sycl::half2 x, sycl::half2 y) {
135  return sycl::clamp((x - y), sycl::half2{0.f, 0.f}, sycl::half2{1.f, 1.f});
136 }
137 
138 bool hbeq2(sycl::half2 x, sycl::half2 y) {
139  return heq(x.s0(), y.s0()) && heq(x.s1(), y.s1());
140 }
141 
142 bool hbequ2(sycl::half2 x, sycl::half2 y) {
143  return hequ(x.s0(), y.s0()) && hequ(x.s1(), y.s1());
144 }
145 
146 bool hbge2(sycl::half2 x, sycl::half2 y) {
147  return hge(x.s0(), y.s0()) && hge(x.s1(), y.s1());
148 }
149 
150 bool hbgeu2(sycl::half2 x, sycl::half2 y) {
151  return hgeu(x.s0(), y.s0()) && hgeu(x.s1(), y.s1());
152 }
153 
154 bool hbgt2(sycl::half2 x, sycl::half2 y) {
155  return hgt(x.s0(), y.s0()) && hgt(x.s1(), y.s1());
156 }
157 
158 bool hbgtu2(sycl::half2 x, sycl::half2 y) {
159  return hgtu(x.s0(), y.s0()) && hgtu(x.s1(), y.s1());
160 }
161 
162 bool hble2(sycl::half2 x, sycl::half2 y) {
163  return hle(x.s0(), y.s0()) && hle(x.s1(), y.s1());
164 }
165 
166 bool hbleu2(sycl::half2 x, sycl::half2 y) {
167  return hleu(x.s0(), y.s0()) && hleu(x.s1(), y.s1());
168 }
169 
170 bool hblt2(sycl::half2 x, sycl::half2 y) {
171  return hlt(x.s0(), y.s0()) && hlt(x.s1(), y.s1());
172 }
173 
174 bool hbltu2(sycl::half2 x, sycl::half2 y) {
175  return hltu(x.s0(), y.s0()) && hltu(x.s1(), y.s1());
176 }
177 
178 bool hbne2(sycl::half2 x, sycl::half2 y) {
179  return hne(x.s0(), y.s0()) && hne(x.s1(), y.s1());
180 }
181 
182 bool hbneu2(sycl::half2 x, sycl::half2 y) {
183  return hneu(x.s0(), y.s0()) && hneu(x.s1(), y.s1());
184 }
185 
186 sycl::half2 heq2(sycl::half2 x, sycl::half2 y) {
187  return sycl::half2{(heq(x.s0(), y.s0()) ? 1.0f : 0.f),
188  (heq(x.s1(), y.s1()) ? 1.0f : 0.f)};
189 }
190 
191 sycl::half2 hequ2(sycl::half2 x, sycl::half2 y) {
192  return sycl::half2{(hequ(x.s0(), y.s0()) ? 1.0f : 0.f),
193  (hequ(x.s1(), y.s1()) ? 1.0f : 0.f)};
194 }
195 
196 sycl::half2 hge2(sycl::half2 x, sycl::half2 y) {
197  return sycl::half2{(hge(x.s0(), y.s0()) ? 1.0f : 0.f),
198  (hge(x.s1(), y.s1()) ? 1.0f : 0.f)};
199 }
200 
201 sycl::half2 hgeu2(sycl::half2 x, sycl::half2 y) {
202  return sycl::half2{(hgeu(x.s0(), y.s0()) ? 1.0f : 0.f),
203  (hgeu(x.s1(), y.s1()) ? 1.0f : 0.f)};
204 }
205 
206 sycl::half2 hgt2(sycl::half2 x, sycl::half2 y) {
207  return sycl::half2{(hgt(x.s0(), y.s0()) ? 1.0f : 0.f),
208  (hgt(x.s1(), y.s1()) ? 1.0f : 0.f)};
209 }
210 
211 sycl::half2 hgtu2(sycl::half2 x, sycl::half2 y) {
212  return sycl::half2{(hgtu(x.s0(), y.s0()) ? 1.0f : 0.f),
213  (hgtu(x.s1(), y.s1()) ? 1.0f : 0.f)};
214 }
215 
216 sycl::half2 hle2(sycl::half2 x, sycl::half2 y) {
217  return sycl::half2{(hle(x.s0(), y.s0()) ? 1.0f : 0.f),
218  (hle(x.s1(), y.s1()) ? 1.0f : 0.f)};
219 }
220 
221 sycl::half2 hleu2(sycl::half2 x, sycl::half2 y) {
222  return sycl::half2{(hleu(x.s0(), y.s0()) ? 1.0f : 0.f),
223  (hleu(x.s1(), y.s1()) ? 1.0f : 0.f)};
224 }
225 
226 sycl::half2 hlt2(sycl::half2 x, sycl::half2 y) {
227  return sycl::half2{(hlt(x.s0(), y.s0()) ? 1.0f : 0.f),
228  (hlt(x.s1(), y.s1()) ? 1.0f : 0.f)};
229 }
230 
231 sycl::half2 hltu2(sycl::half2 x, sycl::half2 y) {
232  return sycl::half2{(hltu(x.s0(), y.s0()) ? 1.0f : 0.f),
233  (hltu(x.s1(), y.s1()) ? 1.0f : 0.f)};
234 }
235 
236 sycl::half2 hisnan2(sycl::half2 x) {
237  return sycl::half2{(hisnan(x.s0()) ? 1.0f : 0.f),
238  (hisnan(x.s1()) ? 1.0f : 0.f)};
239 }
240 
241 sycl::half2 hne2(sycl::half2 x, sycl::half2 y) {
242  return sycl::half2{(hne(x.s0(), y.s0()) ? 1.0f : 0.f),
243  (hne(x.s1(), y.s1()) ? 1.0f : 0.f)};
244 }
245 
246 sycl::half2 hneu2(sycl::half2 x, sycl::half2 y) {
247  return sycl::half2{(hneu(x.s0(), y.s0()) ? 1.0f : 0.f),
248  (hneu(x.s1(), y.s1()) ? 1.0f : 0.f)};
249 }
250 
251 sycl::half hmax(sycl::half x, sycl::half y) { return sycl::fmax(x, y); }
252 
254  if (hisnan(x) || hisnan(y))
255  return sycl::half(NAN);
256  else
257  return sycl::fmax(x, y);
258 }
259 
260 sycl::half2 hmax2(sycl::half2 x, sycl::half2 y) {
261  return sycl::half2{hmax(x.s0(), y.s0()), hmax(x.s1(), y.s1())};
262 }
263 
264 sycl::half2 hmax2_nan(sycl::half2 x, sycl::half2 y) {
265  return sycl::half2{hmax_nan(x.s0(), y.s0()), hmax_nan(x.s1(), y.s1())};
266 }
267 
268 sycl::half hmin(sycl::half x, sycl::half y) { return sycl::fmin(x, y); }
269 
271  if (hisnan(x) || hisnan(y))
272  return sycl::half(NAN);
273  else
274  return sycl::fmin(x, y);
275 }
276 
277 sycl::half2 hmin2(sycl::half2 x, sycl::half2 y) {
278  return sycl::half2{hmin(x.s0(), y.s0()), hmin(x.s1(), y.s1())};
279 }
280 
281 sycl::half2 hmin2_nan(sycl::half2 x, sycl::half2 y) {
282  return sycl::half2{hmin_nan(x.s0(), y.s0()), hmin_nan(x.s1(), y.s1())};
283 }
284 
285 sycl::half2 hcmadd(sycl::half2 x, sycl::half2 y, sycl::half2 z) {
286  return sycl::half2{x.s0() * y.s0() - x.s1() * y.s1() + z.s0(),
287  x.s0() * y.s1() + x.s1() * y.s0() + z.s1()};
288 }
289 
291  sycl::half r = sycl::fma(x, y, z);
292  if (!hisnan(r)) {
293  if (r < 0.f)
294  return sycl::half{0.f};
295  else
296  return r;
297  }
298  return r;
299 }
300 
301 sycl::half2 hfma2_relu(sycl::half2 x, sycl::half2 y, sycl::half2 z) {
302  sycl::half2 r = sycl::fma(x, y, z);
303  if (!hisnan(r.s0()) && r.s0() < 0.f)
304  r.s0() = 0.f;
305  if (!hisnan(r.s1()) && r.s1() < 0.f)
306  r.s1() = 0.f;
307  return r;
308 }
309 
311 
312 sycl::half2 habs2(sycl::half2 x) { return sycl::fabs(x); }
313 } // namespace ext::intel::math
314 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
315 } // namespace sycl
#define __SYCL_INLINE_VER_NAMESPACE(X)
sycl::half habs(sycl::half x)
sycl::half2 hfma2_relu(sycl::half2 x, sycl::half2 y, sycl::half2 z)
sycl::half hfma_sat(sycl::half x, sycl::half y, sycl::half z)
bool hbleu2(sycl::half2 x, sycl::half2 y)
bool hbgtu2(sycl::half2 x, sycl::half2 y)
sycl::half2 hlt2(sycl::half2 x, sycl::half2 y)
bool hequ(sycl::half x, sycl::half y)
sycl::half hfma(sycl::half x, sycl::half y, sycl::half z)
sycl::half2 hne2(sycl::half2 x, sycl::half2 y)
bool hleu(sycl::half x, sycl::half y)
sycl::half hmul(sycl::half x, sycl::half y)
sycl::half hmul_sat(sycl::half x, sycl::half y)
sycl::half2 hmin2(sycl::half2 x, sycl::half2 y)
sycl::half2 hadd2(sycl::half2 x, sycl::half2 y)
sycl::half2 hgeu2(sycl::half2 x, sycl::half2 y)
sycl::half2 hmax2(sycl::half2 x, sycl::half2 y)
bool hble2(sycl::half2 x, sycl::half2 y)
sycl::half2 hisnan2(sycl::half2 x)
bool hbneu2(sycl::half2 x, sycl::half2 y)
bool hbgt2(sycl::half2 x, sycl::half2 y)
sycl::half2 hsub2(sycl::half2 x, sycl::half2 y)
sycl::half hsub_sat(sycl::half x, sycl::half y)
bool hbltu2(sycl::half2 x, sycl::half2 y)
sycl::half2 hfma2(sycl::half2 x, sycl::half2 y, sycl::half2 z)
sycl::half2 hmin2_nan(sycl::half2 x, sycl::half2 y)
bool hgt(sycl::half x, sycl::half y)
sycl::half2 hneg2(sycl::half2 x)
sycl::half hmax_nan(sycl::half x, sycl::half y)
sycl::half hsub(sycl::half x, sycl::half y)
sycl::half2 hfma2_sat(sycl::half2 x, sycl::half2 y, sycl::half2 z)
sycl::half2 hmul2(sycl::half2 x, sycl::half2 y)
sycl::half2 hleu2(sycl::half2 x, sycl::half2 y)
sycl::half hmax(sycl::half x, sycl::half y)
bool hbequ2(sycl::half2 x, sycl::half2 y)
bool heq(sycl::half x, sycl::half y)
sycl::half2 hle2(sycl::half2 x, sycl::half2 y)
bool hltu(sycl::half x, sycl::half y)
sycl::half2 hequ2(sycl::half2 x, sycl::half2 y)
sycl::half2 hgtu2(sycl::half2 x, sycl::half2 y)
sycl::half hadd_sat(sycl::half x, sycl::half y)
sycl::half2 hltu2(sycl::half2 x, sycl::half2 y)
sycl::half2 hgt2(sycl::half2 x, sycl::half2 y)
sycl::half2 hsub2_sat(sycl::half2 x, sycl::half2 y)
bool hle(sycl::half x, sycl::half y)
sycl::half hneg(sycl::half x)
bool hbge2(sycl::half2 x, sycl::half2 y)
sycl::half hfma_relu(sycl::half x, sycl::half y, sycl::half z)
sycl::half2 hmax2_nan(sycl::half2 x, sycl::half2 y)
bool hgtu(sycl::half x, sycl::half y)
sycl::half2 hge2(sycl::half2 x, sycl::half2 y)
sycl::half2 h2div(sycl::half2 x, sycl::half2 y)
bool hlt(sycl::half x, sycl::half y)
bool hgeu(sycl::half x, sycl::half y)
sycl::half hmin_nan(sycl::half x, sycl::half y)
sycl::half2 hmul2_sat(sycl::half2 x, sycl::half2 y)
bool hblt2(sycl::half2 x, sycl::half2 y)
bool hbeq2(sycl::half2 x, sycl::half2 y)
sycl::half hadd(sycl::half x, sycl::half y)
sycl::half2 hadd2_sat(sycl::half2 x, sycl::half2 y)
sycl::half hmin(sycl::half x, sycl::half y)
bool hbne2(sycl::half2 x, sycl::half2 y)
sycl::half2 hcmadd(sycl::half2 x, sycl::half2 y, sycl::half2 z)
sycl::half hdiv(sycl::half x, sycl::half y)
sycl::half2 heq2(sycl::half2 x, sycl::half2 y)
sycl::half2 hneu2(sycl::half2 x, sycl::half2 y)
bool hne(sycl::half x, sycl::half y)
sycl::half2 habs2(sycl::half2 x)
bool hge(sycl::half x, sycl::half y)
bool hneu(sycl::half x, sycl::half y)
bool hbgeu2(sycl::half2 x, sycl::half2 y)
std::enable_if_t< std::is_same< T, bfloat16 >::value, bool > isnan(T x)
std::enable_if_t< detail::is_bf16_storage_type< T >::value, T > fabs(T x)
sycl::detail::half_impl::half half
Definition: aliases.hpp:99
---— Error handling, matching OpenCL plugin semantics.
Definition: access.hpp:14