DPC++ Runtime
Runtime libraries for oneAPI DPC++
dot_product.hpp
Go to the documentation of this file.
1 //==----------- dot_product.hpp ------- SYCL dot-product -------------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // DP4A extension
10 
11 #pragma once
12 
14 #include <sycl/types.hpp>
15 
16 namespace sycl {
18 namespace ext {
19 namespace oneapi {
20 
/// Byte overlay of a 32-bit signed word.
///
/// `i` is the whole word; `s` exposes its four bytes in memory order.
/// NOTE(review): `s` is plain `char`, whose signedness is
/// implementation-defined (it may be unsigned on some targets, e.g. Arm), so
/// reading lanes through `s` does not by itself yield signed 8-bit values.
union Us {
  char s[4]; // the four bytes of the word, as laid out in memory
  int32_t i; // the packed 32-bit word
};
/// Byte overlay of a 32-bit unsigned word.
///
/// `i` is the whole word; `s` exposes its four bytes, as `unsigned char`, in
/// memory order.
union Uu {
  unsigned char s[4]; // the four bytes of the word, as laid out in memory
  uint32_t i;         // the packed 32-bit word
};
29 
/// Signed x signed packed 4-way byte dot product with accumulate (DP4A).
///
/// Each 32-bit operand is treated as four packed two's-complement 8-bit
/// lanes; lane k is bits [8k, 8k+8). Corresponding lanes are multiplied and
/// the four products are added to `c`.
///
/// @param pa  four packed signed 8-bit lanes.
/// @param pb  four packed signed 8-bit lanes.
/// @param c   accumulator added to the sum of lane products.
/// @return c + sum over k of pa[k] * pb[k], wrapping modulo 2^32.
inline int32_t dot_acc(int32_t pa, int32_t pb, int32_t c) {
  // Extract lanes with shifts instead of reinterpret_cast-ing to a union.
  // That avoids relying on type-punning rules C++ does not guarantee, and on
  // the implementation-defined signedness of plain `char`. Both operands use
  // the same lane pairing, so the result does not depend on byte order.
  const auto s8 = [](uint32_t w, int k) -> int32_t {
    const int32_t b = static_cast<int32_t>((w >> (8 * k)) & 0xFFu);
    return b >= 0x80 ? b - 0x100 : b; // sign-extend the 8-bit lane
  };
  const uint32_t ua = static_cast<uint32_t>(pa);
  const uint32_t ub = static_cast<uint32_t>(pb);
  // |dot| <= 4 * 128 * 128, so this sum cannot overflow.
  const int32_t dot = s8(ua, 0) * s8(ub, 0) + s8(ua, 1) * s8(ub, 1) +
                      s8(ua, 2) * s8(ub, 2) + s8(ua, 3) * s8(ub, 3);
  // Add the accumulator with modulo-2^32 wraparound rather than signed
  // overflow (UB). The uint32_t -> int32_t conversion is modular on supported
  // compilers and guaranteed by C++20.
  return static_cast<int32_t>(static_cast<uint32_t>(dot) +
                              static_cast<uint32_t>(c));
}
36 
/// Unsigned x unsigned packed 4-way byte dot product with accumulate (DP4A).
///
/// Each 32-bit operand is treated as four packed unsigned 8-bit lanes; lane k
/// is bits [8k, 8k+8). Corresponding lanes are multiplied and the four
/// products are added to `c`.
///
/// @param pa  four packed unsigned 8-bit lanes.
/// @param pb  four packed unsigned 8-bit lanes.
/// @param c   accumulator added to the sum of lane products.
/// @return c + sum over k of pa[k] * pb[k], wrapping modulo 2^32.
inline int32_t dot_acc(uint32_t pa, uint32_t pb, int32_t c) {
  // Extract lanes with shifts instead of reinterpret_cast-ing to a union.
  // Both operands use the same lane pairing, so the result does not depend
  // on byte order.
  const auto u8 = [](uint32_t w, int k) -> int32_t {
    return static_cast<int32_t>((w >> (8 * k)) & 0xFFu);
  };
  // dot <= 4 * 255 * 255, so this sum cannot overflow.
  const int32_t dot = u8(pa, 0) * u8(pb, 0) + u8(pa, 1) * u8(pb, 1) +
                      u8(pa, 2) * u8(pb, 2) + u8(pa, 3) * u8(pb, 3);
  // Add the accumulator with modulo-2^32 wraparound rather than signed
  // overflow (UB). The uint32_t -> int32_t conversion is modular on supported
  // compilers and guaranteed by C++20.
  return static_cast<int32_t>(static_cast<uint32_t>(dot) +
                              static_cast<uint32_t>(c));
}
43 
/// Signed x unsigned packed 4-way byte dot product with accumulate (DP4A).
///
/// `pa` holds four packed two's-complement 8-bit lanes and `pb` four packed
/// unsigned 8-bit lanes; lane k is bits [8k, 8k+8). Corresponding lanes are
/// multiplied and the four products are added to `c`.
///
/// @param pa  four packed signed 8-bit lanes.
/// @param pb  four packed unsigned 8-bit lanes.
/// @param c   accumulator added to the sum of lane products.
/// @return c + sum over k of pa[k] * pb[k], wrapping modulo 2^32.
inline int32_t dot_acc(int32_t pa, uint32_t pb, int32_t c) {
  // Extract lanes with shifts instead of reinterpret_cast-ing to a union.
  // That avoids relying on type-punning rules C++ does not guarantee, and on
  // the implementation-defined signedness of plain `char`. Both operands use
  // the same lane pairing, so the result does not depend on byte order.
  const auto s8 = [](uint32_t w, int k) -> int32_t {
    const int32_t b = static_cast<int32_t>((w >> (8 * k)) & 0xFFu);
    return b >= 0x80 ? b - 0x100 : b; // sign-extend the 8-bit lane
  };
  const auto u8 = [](uint32_t w, int k) -> int32_t {
    return static_cast<int32_t>((w >> (8 * k)) & 0xFFu);
  };
  const uint32_t ua = static_cast<uint32_t>(pa);
  // |dot| <= 4 * 128 * 255, so this sum cannot overflow.
  const int32_t dot = s8(ua, 0) * u8(pb, 0) + s8(ua, 1) * u8(pb, 1) +
                      s8(ua, 2) * u8(pb, 2) + s8(ua, 3) * u8(pb, 3);
  // Add the accumulator with modulo-2^32 wraparound rather than signed
  // overflow (UB). The uint32_t -> int32_t conversion is modular on supported
  // compilers and guaranteed by C++20.
  return static_cast<int32_t>(static_cast<uint32_t>(dot) +
                              static_cast<uint32_t>(c));
}
50 
/// Unsigned x signed packed 4-way byte dot product with accumulate (DP4A).
///
/// `pa` holds four packed unsigned 8-bit lanes and `pb` four packed
/// two's-complement 8-bit lanes; lane k is bits [8k, 8k+8). Corresponding
/// lanes are multiplied and the four products are added to `c`.
///
/// @param pa  four packed unsigned 8-bit lanes.
/// @param pb  four packed signed 8-bit lanes.
/// @param c   accumulator added to the sum of lane products.
/// @return c + sum over k of pa[k] * pb[k], wrapping modulo 2^32.
inline int32_t dot_acc(uint32_t pa, int32_t pb, int32_t c) {
  // Extract lanes with shifts instead of reinterpret_cast-ing to a union.
  // That avoids relying on type-punning rules C++ does not guarantee, and on
  // the implementation-defined signedness of plain `char`. Both operands use
  // the same lane pairing, so the result does not depend on byte order.
  const auto u8 = [](uint32_t w, int k) -> int32_t {
    return static_cast<int32_t>((w >> (8 * k)) & 0xFFu);
  };
  const auto s8 = [](uint32_t w, int k) -> int32_t {
    const int32_t b = static_cast<int32_t>((w >> (8 * k)) & 0xFFu);
    return b >= 0x80 ? b - 0x100 : b; // sign-extend the 8-bit lane
  };
  const uint32_t ub = static_cast<uint32_t>(pb);
  // |dot| <= 4 * 255 * 128, so this sum cannot overflow.
  const int32_t dot = u8(pa, 0) * s8(ub, 0) + u8(pa, 1) * s8(ub, 1) +
                      u8(pa, 2) * s8(ub, 2) + u8(pa, 3) * s8(ub, 3);
  // Add the accumulator with modulo-2^32 wraparound rather than signed
  // overflow (UB). The uint32_t -> int32_t conversion is modular on supported
  // compilers and guaranteed by C++20.
  return static_cast<int32_t>(static_cast<uint32_t>(dot) +
                              static_cast<uint32_t>(c));
}
57 
58 int32_t dot_acc(vec<int8_t, 4> a, vec<int8_t, 4> b, int32_t c) {
59  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
60  c;
61 }
62 
63 int32_t dot_acc(vec<uint8_t, 4> a, vec<uint8_t, 4> b, int32_t c) {
64  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
65  c;
66 }
67 
68 int32_t dot_acc(vec<uint8_t, 4> a, vec<int8_t, 4> b, int32_t c) {
69  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
70  c;
71 }
72 
73 int32_t dot_acc(vec<int8_t, 4> a, vec<uint8_t, 4> b, int32_t c) {
74  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
75  c;
76 }
77 
78 } // namespace oneapi
79 } // namespace ext
80 
81 } // __SYCL_INLINE_VER_NAMESPACE(_V1)
82 } // namespace sycl
__SYCL_INLINE_VER_NAMESPACE
#define __SYCL_INLINE_VER_NAMESPACE(X)
Definition: defines_elementary.hpp:13
types.hpp
sycl::_V1::ext::oneapi::dot_acc
int32_t dot_acc(vec< int8_t, 4 > a, vec< uint8_t, 4 > b, int32_t c)
Definition: dot_product.hpp:73
sycl
---— Error handling, matching OpenCL plugin semantics.
Definition: access.hpp:13
sycl::_V1::ext::oneapi::Uu::i
uint32_t i
Definition: dot_product.hpp:27
sycl::_V1::ext::oneapi::Us
Definition: dot_product.hpp:21
defines_elementary.hpp
sycl::_V1::vec
Provides a cross-platform vector class template that works efficiently on SYCL devices as well as in host C++ code.
Definition: aliases.hpp:19
sycl::_V1::ext::oneapi::Us::i
int32_t i
Definition: dot_product.hpp:23
sycl::_V1::ext::oneapi::Uu
Definition: dot_product.hpp:25