XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
math_general.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
25#pragma clang diagnostic push
26#pragma clang diagnostic ignored "-Wunused-parameter"
27namespace gpu::xetla {
28
31
38template <typename T0, typename T1, int SZ>
41 "The internal types are not yet supported!");
42 return __ESIMD_NS::abs<T0, T1, SZ>(src0);
43}
44
50template <typename T0, typename T1>
51std::enable_if_t<!std::is_same<remove_const_t<T0>, remove_const_t<T1>>::value,
55 "The internal types are not yet supported!");
56 return __ESIMD_NS::abs<T0, T1>(src0);
57}
58
66template <typename T1, int SZ>
68 static_assert(!(is_internal_type<T1>::value),
69 "The internal types are not yet supported!");
70 return __ESIMD_NS::abs<T1, SZ>(src0);
71}
72
79template <typename T1>
80__XETLA_API typename std::remove_const<T1>::type xetla_abs(T1 src0) {
81 static_assert(!(is_internal_type<T1>::value),
82 "The internal types are not yet supported!");
83 return __ESIMD_NS::abs<T1>(src0);
84}
85
95
96template <typename T, int SZ, typename Sat = xetla_saturation_off_tag>
98 xetla_vector<T, SZ> src0, xetla_vector<T, SZ> src1, Sat sat = {}) {
99 static_assert(!(is_internal_type<T>::value),
100 "The internal types are not yet supported!");
101 return __ESIMD_NS::max<T, SZ>(src0, src1, Sat::value);
102}
103
114template <typename T, int SZ, typename Sat = xetla_saturation_off_tag>
116 xetla_vector<T, SZ> src0, T src1, Sat sat = {}) {
117 static_assert(!(is_internal_type<T>::value),
118 "The internal types are not yet supported!");
119 return __ESIMD_NS::max<T, SZ>(src0, src1, Sat::value);
120}
121
132template <typename T, int SZ, typename Sat = xetla_saturation_off_tag>
134 T src0, xetla_vector<T, SZ> src1, Sat sat = {}) {
135 static_assert(!(is_internal_type<T>::value),
136 "The internal types are not yet supported!");
137 return __ESIMD_NS::max<T, SZ>(src0, src1, Sat::value);
138}
139
148template <typename T, typename Sat = xetla_saturation_off_tag>
149__XETLA_API T xetla_max(T src0, T src1, Sat sat = {}) {
150 static_assert(!(is_internal_type<T>::value),
151 "The internal types are not yet supported!");
152 return __ESIMD_NS::max<T>(src0, src1, Sat::value);
153}
154
164
165template <typename T, int SZ, typename Sat = xetla_saturation_off_tag>
167 xetla_vector<T, SZ> src0, xetla_vector<T, SZ> src1, Sat sat = {}) {
168 static_assert(!(is_internal_type<T>::value),
169 "The internal types are not yet supported!");
170 return __ESIMD_NS::min<T, SZ>(src0, src1, Sat::value);
171}
172
183template <typename T, int SZ, typename Sat = xetla_saturation_off_tag>
185 xetla_vector<T, SZ> src0, T src1, Sat sat = {}) {
186 static_assert(!(is_internal_type<T>::value),
187 "The internal types are not yet supported!");
188 return __ESIMD_NS::min<T, SZ>(src0, src1, Sat::value);
189}
190
201template <typename T, int SZ, typename Sat = xetla_saturation_off_tag>
203 T src0, xetla_vector<T, SZ> src1, Sat sat = {}) {
204 static_assert(!(is_internal_type<T>::value),
205 "The internal types are not yet supported!");
206 return __ESIMD_NS::min<T, SZ>(src0, src1, Sat::value);
207}
208
217template <typename T, typename Sat = xetla_saturation_off_tag>
218__XETLA_API T xetla_min(T src0, T src1, Sat sat = {}) {
219 static_assert(!(is_internal_type<T>::value),
220 "The internal types are not yet supported!");
221 return __ESIMD_NS::min<T>(src0, src1, Sat::value);
222}
223
231template <class T, int SZ, typename Sat = xetla_saturation_off_tag>
233 xetla_vector<T, SZ> src, Sat sat = {}) {
234 static_assert((std::is_same<remove_const_t<T>, float>::value)
235 || (std::is_same<remove_const_t<T>, fp16>::value),
236 "Only support fp32 and fp16");
237 return __ESIMD_NS::exp<T, SZ>(src, Sat::value);
238}
239
246template <class T, typename Sat = xetla_saturation_off_tag>
247__XETLA_API T xetla_exp(T src, Sat sat = {}) {
248 static_assert((std::is_same<remove_const_t<T>, float>::value)
249 || (std::is_same<remove_const_t<T>, fp16>::value),
250 "Only support fp32 and fp16");
251 return __ESIMD_NS::exp<T>(src, Sat::value);
252}
253
261template <class T, int SZ, typename Sat = xetla_saturation_off_tag>
263 xetla_vector<T, SZ> src, Sat sat = {}) {
264 static_assert((std::is_same<remove_const_t<T>, float>::value)
265 || (std::is_same<remove_const_t<T>, fp16>::value),
266 "Only support fp32 and fp16");
267 return __ESIMD_NS::exp2<T, SZ>(src, Sat::value);
268}
269
276template <class T, typename Sat = xetla_saturation_off_tag>
277__XETLA_API T xetla_exp2(T src, Sat sat = {}) {
278 static_assert((std::is_same<remove_const_t<T>, float>::value)
279 || (std::is_same<remove_const_t<T>, fp16>::value),
280 "Only support fp32 and fp16");
281 return __ESIMD_NS::exp2<T>(src, Sat::value);
282}
283
290template <typename T, int SZ, typename Sat = xetla_saturation_off_tag>
292 xetla_vector<T, SZ> src, Sat sat = {}) {
293 static_assert((std::is_same<remove_const_t<T>, float>::value)
294 || (std::is_same<remove_const_t<T>, fp16>::value),
295 "Only support fp32 and fp16");
296 return __ESIMD_NS::inv(src, Sat::value);
297}
298
304template <typename T, typename Sat = xetla_saturation_off_tag>
305__XETLA_API T xetla_inv(T src, Sat sat = {}) {
306 static_assert((std::is_same<remove_const_t<T>, float>::value)
307 || (std::is_same<remove_const_t<T>, fp16>::value),
308 "Only support fp32 and fp16");
309 return __ESIMD_NS::inv(src, Sat::value);
310}
311
318template <typename T, int SZ, typename Sat = xetla_saturation_off_tag>
320 xetla_vector<T, SZ> src, Sat sat = {}) {
321 static_assert((std::is_same<remove_const_t<T>, float>::value)
322 || (std::is_same<remove_const_t<T>, fp16>::value),
323 "Only support fp32 and fp16");
324 return __ESIMD_NS::sqrt(src, Sat::value);
325}
326
332template <typename T, typename Sat = xetla_saturation_off_tag>
333__XETLA_API T xetla_sqrt(T src, Sat sat = {}) {
334 static_assert((std::is_same<remove_const_t<T>, float>::value)
335 || (std::is_same<remove_const_t<T>, fp16>::value),
336 "Only support fp32 and fp16");
337 return __ESIMD_NS::sqrt(src, Sat::value);
338}
339
346template <typename T, int SZ, typename Sat = xetla_saturation_off_tag>
348 xetla_vector<T, SZ> src, Sat sat = {}) {
349 static_assert((std::is_same<remove_const_t<T>, float>::value)
350 || (std::is_same<remove_const_t<T>, double>::value),
351 "Only support fp32 and fp16");
352 return __ESIMD_NS::sqrt_ieee(src, Sat::value);
353}
354
360template <typename T, typename Sat = xetla_saturation_off_tag>
361__XETLA_API T xetla_sqrt_ieee(T src, Sat sat = {}) {
362 static_assert((std::is_same<remove_const_t<T>, float>::value)
363 || (std::is_same<remove_const_t<T>, double>::value),
364 "Only support fp32 and fp16");
365 return __ESIMD_NS::sqrt_ieee(src, Sat::value);
366}
367
374template <typename T, int SZ, typename Sat = xetla_saturation_off_tag>
376 xetla_vector<T, SZ> src, Sat sat = {}) {
377 static_assert((std::is_same<remove_const_t<T>, float>::value)
378 || (std::is_same<remove_const_t<T>, fp16>::value),
379 "Only support fp32 and fp16");
380 return __ESIMD_NS::rsqrt(src, Sat::value);
381}
382
388template <typename T, typename Sat = xetla_saturation_off_tag>
389__XETLA_API T xetla_rsqrt(T src, Sat sat = {}) {
390 static_assert((std::is_same<remove_const_t<T>, float>::value)
391 || (std::is_same<remove_const_t<T>, fp16>::value),
392 "Only support fp32 and fp16");
393 return __ESIMD_NS::rsqrt(src, Sat::value);
394}
395
401template <typename T, int SZ>
403 static_assert(std::is_same<remove_const_t<T>, float>::value,
404 "Only support fp32! ");
405 constexpr uint32_t flag_elems = 8 * 16;
407 if constexpr (SZ / flag_elems > 0) {
408#pragma unroll
409 for (uint32_t i = 0; i < SZ / flag_elems; i++) {
410 auto src_sub = src.xetla_select<flag_elems, 1>(i * flag_elems);
411 auto ret_sub = ret.xetla_select<flag_elems, 1>(i * flag_elems);
412 xetla_mask<flag_elems> mask = src_sub >= 10;
414 = xetla_exp<T, flag_elems>(src_sub * 2.f);
415 ret_sub = (exp2x - 1.f) / (exp2x + 1.f);
417 ret_sub.xetla_merge(ones, mask);
418 }
419 }
420
421 if constexpr (SZ % flag_elems != 0) {
422 constexpr uint32_t start_pos = SZ / flag_elems * flag_elems;
423 constexpr uint32_t remain_elems = SZ % flag_elems;
424
425 auto src_sub = src.xetla_select<remain_elems, 1>(start_pos);
426 auto ret_sub = ret.xetla_select<remain_elems, 1>(start_pos);
427 xetla_mask<remain_elems> mask = src_sub >= 10;
429 = xetla_exp<T, remain_elems>(src_sub * 2.f);
430 ret_sub = (exp2x - 1.f) / (exp2x + 1.f);
432 ret_sub.xetla_merge(ones, mask);
433 }
434
435 return ret;
436}
437
443template <typename T>
445 static_assert(std::is_same<remove_const_t<T>, float>::value,
446 "Only support fp32! ");
447 T exp2x = xetla_exp<T>(src * 2.f);
448 T ret = (exp2x - 1.f) / (exp2x + 1.f);
449 return (src >= 10) ? 1 : ret;
450}
451
459template <typename T, int SZ>
462 static_assert((std::is_same<remove_const_t<T>, uint32_t>::value),
463 "For addc, only uint32_t is supported");
464 xetla_vector<T, SZ> carry_tmp;
465 xetla_vector<T, SZ> out = __ESIMD_ENS::addc(carry_tmp, src0, src1);
466 carry = carry_tmp;
467 return out;
468}
469
477template <typename T, int SZ>
480 static_assert((std::is_same<remove_const_t<T>, uint32_t>::value),
481 "For addc, only uint32_t is supported");
482 xetla_vector<T, SZ> carry_tmp;
483 xetla_vector<T, SZ> out = __ESIMD_ENS::addc(carry_tmp, src0, src1);
484 carry = carry_tmp;
485 return out;
486}
487
497template <typename T0, typename T1, typename T2, int SZ>
499 xetla_vector<T1, SZ> src0, T2 src1) {
502 = __ESIMD_ENS::imul<T0, T1, T2, SZ>(lo_tmp, src0, src1);
503 lo = lo_tmp;
504 return hi_tmp;
505}
506
519template <typename T0, typename T1, int SZ, reduce_op BinaryOperation>
521 if constexpr (BinaryOperation == reduce_op::sum) {
522 return __ESIMD_NS::detail::sum<T0, T1, SZ>(v);
523 } else if constexpr (BinaryOperation == reduce_op::prod) {
524 return __ESIMD_NS::detail::prod<T0, T1, SZ>(v);
525 } else if constexpr (BinaryOperation == reduce_op::min) {
526 return __ESIMD_NS::hmin<T0, T1, SZ>(v);
527 } else if constexpr (BinaryOperation == reduce_op::max) {
528 return __ESIMD_NS::hmax<T0, T1, SZ>(v);
529 }
530}
531
537template <typename T, int SZ>
539 static_assert(!(is_internal_type<T>::value),
540 "The internal types are not yet supported!");
541 return __ESIMD_NS::rnde<T, SZ>(src0);
542}
543
554template <typename T1, typename T0, int SZ,
555 typename Sat = xetla_saturation_off_tag>
557 xetla_vector<T0, SZ> src0, xetla_vector<T0, SZ> src1, Sat sat = {}) {
558 static_assert(
560 "The internal types are not yet supported!");
562 xetla_vector<T0, SZ> temp = src0 + src1;
563 if constexpr (std::is_same_v<Sat, xetla_saturation_off_tag>)
564 return xetla_vector<T1, SZ>(temp);
565 else
566 return __ESIMD_NS::saturate<T1, T0, SZ>(temp);
567}
568
573template <typename T1, typename T0, int SZ>
575 static_assert(
577 "The internal types are not yet supported!");
578 return __ESIMD_NS::saturate<T1, T0, SZ>(src);
579}
580
582
583} // namespace gpu::xetla
584#pragma clang diagnostic pop
C++ API.
C++ API.
typename std::remove_const< T >::type remove_const_t
Definition common.hpp:26
#define __XETLA_API
Definition common.hpp:43
Workaround for ESIMD vector(1D) ref type.
Definition base_types.hpp:187
#define __REF__
Workaround for ESIMD reference usage.
Definition base_types.hpp:177
sycl::half fp16
xetla fp16 data type.
Definition base_types.hpp:43
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
wrapper for xetla_mask.
Definition base_types.hpp:165
__XETLA_API T0 xetla_reduce(xetla_vector< T1, SZ > v)
Performs reduction over elements of the input vector.
Definition math_general.hpp:520
__XETLA_API xetla_vector< T0, SZ > xetla_abs(xetla_vector< T1, SZ > src0)
Get absolute value (vector version)
Definition math_general.hpp:39
__XETLA_API xetla_vector< T, SZ > xetla_max(xetla_vector< T, SZ > src0, xetla_vector< T, SZ > src1, Sat sat={})
Selects component-wise the maximum of the two vectors.
Definition math_general.hpp:97
__XETLA_API xetla_vector< T, SZ > xetla_exp(xetla_vector< T, SZ > src, Sat sat={})
Calculate exponent value for each element of the input vector, the base is e.
Definition math_general.hpp:232
__XETLA_API xetla_vector< T, SZ > xetla_min(xetla_vector< T, SZ > src0, xetla_vector< T, SZ > src1, Sat sat={})
Selects component-wise the minimum of the two vectors.
Definition math_general.hpp:166
__XETLA_API xetla_vector< T1, SZ > xetla_sat(xetla_vector< T0, SZ > src)
Saturation function.
Definition math_general.hpp:574
__XETLA_API xetla_vector< T, SZ > xetla_sqrt(xetla_vector< T, SZ > src, Sat sat={})
Calculate the square root, i.e.
Definition math_general.hpp:319
__XETLA_API xetla_vector< T0, SZ > xetla_imul(xetla_vector_ref< T0, SZ > __REF__ lo, xetla_vector< T1, SZ > src0, T2 src1)
Multiply src0 with src1, return the hi part and in-place update the lo part.
Definition math_general.hpp:498
__XETLA_API xetla_vector< T, SZ > xetla_exp2(xetla_vector< T, SZ > src, Sat sat={})
Calculate exponent value for each element of the input vector, the base is 2.
Definition math_general.hpp:262
__XETLA_API xetla_vector< T, SZ > xetla_tanh(xetla_vector< T, SZ > src)
Calculate the tanh (vector version).
Definition math_general.hpp:402
__XETLA_API xetla_vector< T, SZ > xetla_rsqrt(xetla_vector< T, SZ > src, Sat sat={})
Calculate the inversion of square root, i.e.
Definition math_general.hpp:375
__XETLA_API xetla_vector< T, SZ > xetla_add_c(xetla_vector< T, SZ > src0, xetla_vector< T, SZ > src1, xetla_vector_ref< T, SZ > __REF__ carry)
Add two unsigned integer vectors, return the result and in-place update the carry.
Definition math_general.hpp:460
__XETLA_API xetla_vector< T, SZ > xetla_inv(xetla_vector< T, SZ > src, Sat sat={})
Calculate the inversion, i.e.
Definition math_general.hpp:291
__XETLA_API xetla_vector< T, SZ > xetla_sqrt_ieee(xetla_vector< T, SZ > src, Sat sat={})
Calculate the square root, i.e.
Definition math_general.hpp:347
__XETLA_API xetla_vector< T1, SZ > xetla_add(xetla_vector< T0, SZ > src0, xetla_vector< T0, SZ > src1, Sat sat={})
Adds two vectors with saturation. The source operands must both be of floating-point type.
Definition math_general.hpp:556
__XETLA_API xetla_vector< T, SZ > xetla_rnde(xetla_vector< T, SZ > src0)
Get rounded value.
Definition math_general.hpp:538
Definition arch_config.hpp:24
Used to check if the type is xetla internal data type.
Definition base_types.hpp:67
static constexpr bool value
Definition base_types.hpp:68