DPC++ Runtime
Runtime libraries for oneAPI DPC++
dot_product.hpp
Go to the documentation of this file.
1 //==----------- dot_product.hpp ------- SYCL dot-product -------------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // DP4A extension
10 
11 #pragma once
12 
14 #include <sycl/types.hpp>
15 
16 namespace sycl {
18 namespace ext::oneapi {
19 
/// Overlay of a 32-bit signed integer with four signed 8-bit lanes.
///
/// The lanes are declared as int8_t rather than plain `char`: whether plain
/// `char` is signed is implementation-defined (it is unsigned on several
/// ABIs, e.g. Arm), so arithmetic on `char` lanes would silently become
/// unsigned arithmetic on those targets.
union Us {
  int8_t s[4]; // four 8-bit lanes sharing storage with `i`
  int32_t i;   // the packed 32-bit value
};
/// Overlay of a 32-bit unsigned integer with four unsigned 8-bit lanes.
union Uu {
  uint8_t s[4]; // four 8-bit lanes sharing storage with `i`
  uint32_t i;   // the packed 32-bit value
};
28 
/// Four-way 8-bit dot product with accumulate, signed x signed.
///
/// Treats \p pa and \p pb as four packed signed 8-bit lanes each, multiplies
/// corresponding lanes and returns the sum of the four products plus \p c.
///
/// Declared `inline` because it is defined in a header; without it, including
/// this header from more than one translation unit yields duplicate symbols.
///
/// \param pa four packed signed 8-bit values.
/// \param pb four packed signed 8-bit values.
/// \param c  accumulator added to the sum of products.
/// \return a0*b0 + a1*b1 + a2*b2 + a3*b3 + c.
inline int32_t dot_acc(int32_t pa, int32_t pb, int32_t c) {
  // Lanes are extracted with shifts on the unsigned bit pattern instead of
  // punning through a union: this avoids the aliasing cast and does not depend
  // on the signedness of plain `char`. Both operands use the same lane order,
  // so the result does not depend on byte order.
  const uint32_t ua = static_cast<uint32_t>(pa);
  const uint32_t ub = static_cast<uint32_t>(pb);
  int32_t sum = 0;
  for (int lane = 0; lane < 4; ++lane) {
    const int shift = 8 * lane;
    const auto a = static_cast<int8_t>(static_cast<uint8_t>(ua >> shift));
    const auto b = static_cast<int8_t>(static_cast<uint8_t>(ub >> shift));
    sum += a * b; // operands are promoted to int before multiplying
  }
  // The accumulator is added last, after the products have been summed.
  return sum + c;
}
35 
/// Four-way 8-bit dot product with accumulate, unsigned x unsigned.
///
/// Treats \p pa and \p pb as four packed unsigned 8-bit lanes each, multiplies
/// corresponding lanes and returns the sum of the four products plus \p c.
///
/// Declared `inline` because it is defined in a header; without it, including
/// this header from more than one translation unit yields duplicate symbols.
///
/// \param pa four packed unsigned 8-bit values.
/// \param pb four packed unsigned 8-bit values.
/// \param c  accumulator added to the sum of products.
/// \return a0*b0 + a1*b1 + a2*b2 + a3*b3 + c.
inline int32_t dot_acc(uint32_t pa, uint32_t pb, int32_t c) {
  // Lanes are extracted with shifts rather than by punning through a union.
  // Both operands use the same lane order, so the result does not depend on
  // byte order.
  int32_t sum = 0;
  for (int lane = 0; lane < 4; ++lane) {
    const int shift = 8 * lane;
    const auto a = static_cast<uint8_t>(pa >> shift);
    const auto b = static_cast<uint8_t>(pb >> shift);
    sum += a * b; // operands are promoted to int before multiplying
  }
  // The accumulator is added last, after the products have been summed.
  return sum + c;
}
42 
/// Four-way 8-bit dot product with accumulate, signed x unsigned.
///
/// Treats \p pa as four packed signed 8-bit lanes and \p pb as four packed
/// unsigned 8-bit lanes, multiplies corresponding lanes and returns the sum of
/// the four products plus \p c.
///
/// Declared `inline` because it is defined in a header; without it, including
/// this header from more than one translation unit yields duplicate symbols.
///
/// \param pa four packed signed 8-bit values.
/// \param pb four packed unsigned 8-bit values.
/// \param c  accumulator added to the sum of products.
/// \return a0*b0 + a1*b1 + a2*b2 + a3*b3 + c.
inline int32_t dot_acc(int32_t pa, uint32_t pb, int32_t c) {
  // Lanes are extracted with shifts on the unsigned bit pattern instead of
  // punning through a union: this avoids the aliasing cast and does not depend
  // on the signedness of plain `char`. Both operands use the same lane order,
  // so the result does not depend on byte order.
  const uint32_t ua = static_cast<uint32_t>(pa);
  int32_t sum = 0;
  for (int lane = 0; lane < 4; ++lane) {
    const int shift = 8 * lane;
    const auto a = static_cast<int8_t>(static_cast<uint8_t>(ua >> shift));
    const auto b = static_cast<uint8_t>(pb >> shift);
    sum += a * b; // operands are promoted to int before multiplying
  }
  // The accumulator is added last, after the products have been summed.
  return sum + c;
}
49 
/// Four-way 8-bit dot product with accumulate, unsigned x signed.
///
/// Treats \p pa as four packed unsigned 8-bit lanes and \p pb as four packed
/// signed 8-bit lanes, multiplies corresponding lanes and returns the sum of
/// the four products plus \p c.
///
/// Declared `inline` because it is defined in a header; without it, including
/// this header from more than one translation unit yields duplicate symbols.
///
/// \param pa four packed unsigned 8-bit values.
/// \param pb four packed signed 8-bit values.
/// \param c  accumulator added to the sum of products.
/// \return a0*b0 + a1*b1 + a2*b2 + a3*b3 + c.
inline int32_t dot_acc(uint32_t pa, int32_t pb, int32_t c) {
  // Lanes are extracted with shifts on the unsigned bit pattern instead of
  // punning through a union: this avoids the aliasing cast and does not depend
  // on the signedness of plain `char`. Both operands use the same lane order,
  // so the result does not depend on byte order.
  const uint32_t ub = static_cast<uint32_t>(pb);
  int32_t sum = 0;
  for (int lane = 0; lane < 4; ++lane) {
    const int shift = 8 * lane;
    const auto a = static_cast<uint8_t>(pa >> shift);
    const auto b = static_cast<int8_t>(static_cast<uint8_t>(ub >> shift));
    sum += a * b; // operands are promoted to int before multiplying
  }
  // The accumulator is added last, after the products have been summed.
  return sum + c;
}
56 
57 int32_t dot_acc(vec<int8_t, 4> a, vec<int8_t, 4> b, int32_t c) {
58  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
59  c;
60 }
61 
62 int32_t dot_acc(vec<uint8_t, 4> a, vec<uint8_t, 4> b, int32_t c) {
63  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
64  c;
65 }
66 
67 int32_t dot_acc(vec<uint8_t, 4> a, vec<int8_t, 4> b, int32_t c) {
68  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
69  c;
70 }
71 
72 int32_t dot_acc(vec<int8_t, 4> a, vec<uint8_t, 4> b, int32_t c) {
73  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
74  c;
75 }
76 
77 } // namespace ext::oneapi
78 
79 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
80 } // namespace sycl
__SYCL_INLINE_VER_NAMESPACE
#define __SYCL_INLINE_VER_NAMESPACE(X)
Definition: defines_elementary.hpp:11
types.hpp
sycl::_V1::ext::oneapi::dot_acc
int32_t dot_acc(vec< int8_t, 4 > a, vec< uint8_t, 4 > b, int32_t c)
Definition: dot_product.hpp:72
sycl
---— Error handling, matching OpenCL plugin semantics.
Definition: access.hpp:14
sycl::_V1::ext::oneapi::Uu::i
uint32_t i
Definition: dot_product.hpp:26
sycl::_V1::ext::oneapi::Us
Definition: dot_product.hpp:20
defines_elementary.hpp
sycl::_V1::vec
Provides a cross-platform vector class template that works efficiently on SYCL devices as well as in h...
Definition: aliases.hpp:20
sycl::_V1::ext::oneapi::Us::i
int32_t i
Definition: dot_product.hpp:22
sycl::_V1::ext::oneapi::Uu
Definition: dot_product.hpp:24