DPC++ Runtime
Runtime libraries for oneAPI DPC++
dot_product.hpp
Go to the documentation of this file.
1 //==----------- dot_product.hpp ------- SYCL dot-product -------------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // DP4A extension
10 
11 #pragma once
12 
14 #include <sycl/types.hpp>
15 
16 namespace sycl {
17 inline namespace _V1 {
18 namespace ext::oneapi {
19 
/// Overlay of a 32-bit signed integer with its four constituent bytes, each
/// viewed as a signed 8-bit value.
///
/// The byte array is declared `signed char` rather than plain `char`: the
/// signedness of plain `char` is implementation-defined, and on ABIs where it
/// is unsigned the lanes would be read as 0..255 instead of -128..127.
union Us {
  signed char s[4]; ///< The four bytes, in memory order.
  int32_t i;        ///< The same storage viewed as one 32-bit signed integer.
};
/// Overlay of a 32-bit unsigned integer with its four constituent bytes, each
/// viewed as an unsigned 8-bit value.
union Uu {
  unsigned char s[4]; ///< The four bytes, in memory order.
  uint32_t i;         ///< The same storage viewed as one 32-bit unsigned integer.
};
28 
/// Four-lane dot product with accumulate on packed signed bytes.
///
/// \param pa Four signed 8-bit lanes packed into one 32-bit value.
/// \param pb Four signed 8-bit lanes packed into one 32-bit value.
/// \param c  Accumulator added to the sum of the lane products.
/// \return   a0*b0 + a1*b1 + a2*b2 + a3*b3 + c, where lane k of each operand
///           is bits [8k, 8k+7] interpreted as a signed 8-bit integer.
///
/// Lanes are paired by bit position, so the result does not depend on the
/// byte order of the host. No saturation or overflow handling is performed
/// when \p c is added.
///
/// Declared `inline` because the definition lives in a header.
inline int32_t dot_acc(int32_t pa, int32_t pb, int32_t c) {
  // Shift the unsigned bit patterns so that every shift is well defined.
  const uint32_t bits_a = static_cast<uint32_t>(pa);
  const uint32_t bits_b = static_cast<uint32_t>(pb);

  int32_t sum = 0;
  for (int lane = 0; lane < 4; ++lane) {
    const int shift = 8 * lane;
    // Truncate to one byte, then reinterpret that byte as signed. This avoids
    // relying on the implementation-defined signedness of plain `char`.
    const int32_t ea =
        static_cast<int8_t>(static_cast<uint8_t>(bits_a >> shift));
    const int32_t eb =
        static_cast<int8_t>(static_cast<uint8_t>(bits_b >> shift));
    sum += ea * eb;
  }
  // Products are summed first and the accumulator is added last.
  return sum + c;
}
35 
/// Four-lane dot product with accumulate on packed unsigned bytes.
///
/// \param pa Four unsigned 8-bit lanes packed into one 32-bit value.
/// \param pb Four unsigned 8-bit lanes packed into one 32-bit value.
/// \param c  Accumulator added to the sum of the lane products.
/// \return   a0*b0 + a1*b1 + a2*b2 + a3*b3 + c, where lane k of each operand
///           is bits [8k, 8k+7] interpreted as an unsigned 8-bit integer.
///
/// Lanes are paired by bit position, so the result does not depend on the
/// byte order of the host. No saturation or overflow handling is performed
/// when \p c is added.
///
/// Declared `inline` because the definition lives in a header.
inline int32_t dot_acc(uint32_t pa, uint32_t pb, int32_t c) {
  int32_t sum = 0;
  for (int lane = 0; lane < 4; ++lane) {
    const int shift = 8 * lane;
    // Truncating to uint8_t keeps exactly the byte at this lane position.
    const int32_t ea = static_cast<uint8_t>(pa >> shift);
    const int32_t eb = static_cast<uint8_t>(pb >> shift);
    sum += ea * eb;
  }
  // Products are summed first and the accumulator is added last.
  return sum + c;
}
42 
/// Four-lane dot product with accumulate: signed bytes times unsigned bytes.
///
/// \param pa Four signed 8-bit lanes packed into one 32-bit value.
/// \param pb Four unsigned 8-bit lanes packed into one 32-bit value.
/// \param c  Accumulator added to the sum of the lane products.
/// \return   a0*b0 + a1*b1 + a2*b2 + a3*b3 + c, where lane k is bits
///           [8k, 8k+7], read as signed for \p pa and unsigned for \p pb.
///
/// Lanes are paired by bit position, so the result does not depend on the
/// byte order of the host. No saturation or overflow handling is performed
/// when \p c is added.
///
/// Declared `inline` because the definition lives in a header.
inline int32_t dot_acc(int32_t pa, uint32_t pb, int32_t c) {
  // Shift the unsigned bit pattern so that every shift is well defined.
  const uint32_t bits_a = static_cast<uint32_t>(pa);

  int32_t sum = 0;
  for (int lane = 0; lane < 4; ++lane) {
    const int shift = 8 * lane;
    // Signed lane: truncate to a byte, then reinterpret it as signed. This
    // avoids relying on the implementation-defined signedness of `char`.
    const int32_t ea =
        static_cast<int8_t>(static_cast<uint8_t>(bits_a >> shift));
    // Unsigned lane: plain truncation to a byte.
    const int32_t eb = static_cast<uint8_t>(pb >> shift);
    sum += ea * eb;
  }
  // Products are summed first and the accumulator is added last.
  return sum + c;
}
49 
/// Four-lane dot product with accumulate: unsigned bytes times signed bytes.
///
/// \param pa Four unsigned 8-bit lanes packed into one 32-bit value.
/// \param pb Four signed 8-bit lanes packed into one 32-bit value.
/// \param c  Accumulator added to the sum of the lane products.
/// \return   a0*b0 + a1*b1 + a2*b2 + a3*b3 + c, where lane k is bits
///           [8k, 8k+7], read as unsigned for \p pa and signed for \p pb.
///
/// Lanes are paired by bit position, so the result does not depend on the
/// byte order of the host. No saturation or overflow handling is performed
/// when \p c is added.
///
/// Declared `inline` because the definition lives in a header.
inline int32_t dot_acc(uint32_t pa, int32_t pb, int32_t c) {
  // Shift the unsigned bit pattern so that every shift is well defined.
  const uint32_t bits_b = static_cast<uint32_t>(pb);

  int32_t sum = 0;
  for (int lane = 0; lane < 4; ++lane) {
    const int shift = 8 * lane;
    // Unsigned lane: plain truncation to a byte.
    const int32_t ea = static_cast<uint8_t>(pa >> shift);
    // Signed lane: truncate to a byte, then reinterpret it as signed. This
    // avoids relying on the implementation-defined signedness of `char`.
    const int32_t eb =
        static_cast<int8_t>(static_cast<uint8_t>(bits_b >> shift));
    sum += ea * eb;
  }
  // Products are summed first and the accumulator is added last.
  return sum + c;
}
56 
57 int32_t dot_acc(vec<int8_t, 4> a, vec<int8_t, 4> b, int32_t c) {
58  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
59  c;
60 }
61 
62 int32_t dot_acc(vec<uint8_t, 4> a, vec<uint8_t, 4> b, int32_t c) {
63  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
64  c;
65 }
66 
67 int32_t dot_acc(vec<uint8_t, 4> a, vec<int8_t, 4> b, int32_t c) {
68  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
69  c;
70 }
71 
72 int32_t dot_acc(vec<int8_t, 4> a, vec<uint8_t, 4> b, int32_t c) {
73  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
74  c;
75 }
76 
77 } // namespace ext::oneapi
78 
79 } // namespace _V1
80 } // namespace sycl
class sycl::vec — Provides a cross-platform vector class template that works e...
Definition: vector.hpp:361
int32_t dot_acc(int32_t pa, int32_t pb, int32_t c)
Definition: dot_product.hpp:29
auto autodecltype(a) b
Definition: access.hpp:18