DPC++ Runtime
Runtime libraries for oneAPI DPC++
dot_product.hpp
Go to the documentation of this file.
1 //==----------- dot_product.hpp ------- SYCL dot-product -------------------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // DP4A extension
10 
11 #pragma once
12 
14 namespace sycl {
15 namespace ext {
16 namespace oneapi {
17 
/// Overlays four signed 8-bit lanes on one 32-bit signed integer.
///
/// `s` and `i` share the same four bytes of storage. The lanes are declared
/// `signed char` rather than plain `char`: the signedness of plain `char` is
/// implementation-defined, so on targets where it is unsigned the lanes would
/// read back as 0..255 instead of -128..127.
union Us {
  signed char s[4]; ///< The four bytes viewed as signed 8-bit values.
  int32_t i;        ///< The same four bytes viewed as one 32-bit integer.
};
/// Overlays four unsigned 8-bit lanes on one 32-bit unsigned integer.
///
/// Both members occupy the same four bytes of storage.
union Uu {
  unsigned char s[4]; ///< The four bytes viewed as unsigned 8-bit values.
  uint32_t i;         ///< The same four bytes viewed as one 32-bit integer.
};
26 
/// Four-lane dot product of two packed 32-bit words with accumulation, where
/// every lane of both operands is a signed 8-bit value.
///
/// \param pa first operand; its four bytes are read as signed 8-bit lanes.
/// \param pb second operand; its four bytes are read as signed 8-bit lanes.
/// \param c  accumulator added after the four lane products are summed.
/// \return   pa[0]*pb[0] + pa[1]*pb[1] + pa[2]*pb[2] + pa[3]*pb[3] + c.
///
/// Defined `inline` because this is a header: without it, including the header
/// from two translation units produces duplicate definitions at link time.
inline int32_t dot_acc(int32_t pa, int32_t pb, int32_t c) {
  // Work on the unsigned bit patterns so the shifts below are well defined
  // for negative inputs.
  const uint32_t bits_a = static_cast<uint32_t>(pa);
  const uint32_t bits_b = static_cast<uint32_t>(pb);

  // Lanes are extracted with shift/mask instead of reading an inactive union
  // member, which is undefined behaviour in C++. Both operands use the same
  // lane order, so the sum does not depend on host byte order.
  int32_t sum = 0;
  for (int shift = 0; shift < 32; shift += 8) {
    // (x ^ 0x80) - 0x80 sign-extends bit 7 without relying on the signedness
    // of plain `char`.
    const int32_t lane_a =
        static_cast<int32_t>(((bits_a >> shift) & 0xFFu) ^ 0x80u) - 0x80;
    const int32_t lane_b =
        static_cast<int32_t>(((bits_b >> shift) & 0xFFu) ^ 0x80u) - 0x80;
    sum += lane_a * lane_b;
  }
  return sum + c;
}
33 
/// Four-lane dot product of two packed 32-bit words with accumulation, where
/// every lane of both operands is an unsigned 8-bit value.
///
/// \param pa first operand; its four bytes are read as unsigned 8-bit lanes.
/// \param pb second operand; its four bytes are read as unsigned 8-bit lanes.
/// \param c  accumulator added after the four lane products are summed.
/// \return   pa[0]*pb[0] + pa[1]*pb[1] + pa[2]*pb[2] + pa[3]*pb[3] + c.
///
/// Defined `inline` because this is a header: without it, including the header
/// from two translation units produces duplicate definitions at link time.
inline int32_t dot_acc(uint32_t pa, uint32_t pb, int32_t c) {
  // Lanes are extracted with shift/mask instead of reading an inactive union
  // member, which is undefined behaviour in C++. Both operands use the same
  // lane order, so the sum does not depend on host byte order.
  int32_t sum = 0;
  for (int shift = 0; shift < 32; shift += 8) {
    const int32_t lane_a = static_cast<int32_t>((pa >> shift) & 0xFFu);
    const int32_t lane_b = static_cast<int32_t>((pb >> shift) & 0xFFu);
    sum += lane_a * lane_b;
  }
  return sum + c;
}
40 
/// Four-lane dot product of two packed 32-bit words with accumulation, where
/// the first operand holds signed and the second unsigned 8-bit lanes.
///
/// \param pa first operand; its four bytes are read as signed 8-bit lanes.
/// \param pb second operand; its four bytes are read as unsigned 8-bit lanes.
/// \param c  accumulator added after the four lane products are summed.
/// \return   pa[0]*pb[0] + pa[1]*pb[1] + pa[2]*pb[2] + pa[3]*pb[3] + c.
///
/// Defined `inline` because this is a header: without it, including the header
/// from two translation units produces duplicate definitions at link time.
inline int32_t dot_acc(int32_t pa, uint32_t pb, int32_t c) {
  // Work on the unsigned bit pattern so the shifts below are well defined for
  // a negative `pa`.
  const uint32_t bits_a = static_cast<uint32_t>(pa);

  // Lanes are extracted with shift/mask instead of reading an inactive union
  // member, which is undefined behaviour in C++. Both operands use the same
  // lane order, so the sum does not depend on host byte order.
  int32_t sum = 0;
  for (int shift = 0; shift < 32; shift += 8) {
    // (x ^ 0x80) - 0x80 sign-extends bit 7 without relying on the signedness
    // of plain `char`.
    const int32_t lane_a =
        static_cast<int32_t>(((bits_a >> shift) & 0xFFu) ^ 0x80u) - 0x80;
    const int32_t lane_b = static_cast<int32_t>((pb >> shift) & 0xFFu);
    sum += lane_a * lane_b;
  }
  return sum + c;
}
47 
/// Four-lane dot product of two packed 32-bit words with accumulation, where
/// the first operand holds unsigned and the second signed 8-bit lanes.
///
/// \param pa first operand; its four bytes are read as unsigned 8-bit lanes.
/// \param pb second operand; its four bytes are read as signed 8-bit lanes.
/// \param c  accumulator added after the four lane products are summed.
/// \return   pa[0]*pb[0] + pa[1]*pb[1] + pa[2]*pb[2] + pa[3]*pb[3] + c.
///
/// Defined `inline` because this is a header: without it, including the header
/// from two translation units produces duplicate definitions at link time.
inline int32_t dot_acc(uint32_t pa, int32_t pb, int32_t c) {
  // Work on the unsigned bit pattern so the shifts below are well defined for
  // a negative `pb`.
  const uint32_t bits_b = static_cast<uint32_t>(pb);

  // Lanes are extracted with shift/mask instead of reading an inactive union
  // member, which is undefined behaviour in C++. Both operands use the same
  // lane order, so the sum does not depend on host byte order.
  int32_t sum = 0;
  for (int shift = 0; shift < 32; shift += 8) {
    const int32_t lane_a = static_cast<int32_t>((pa >> shift) & 0xFFu);
    // (x ^ 0x80) - 0x80 sign-extends bit 7 without relying on the signedness
    // of plain `char`.
    const int32_t lane_b =
        static_cast<int32_t>(((bits_b >> shift) & 0xFFu) ^ 0x80u) - 0x80;
    sum += lane_a * lane_b;
  }
  return sum + c;
}
54 
55 int32_t dot_acc(vec<int8_t, 4> a, vec<int8_t, 4> b, int32_t c) {
56  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
57  c;
58 }
59 
60 int32_t dot_acc(vec<uint8_t, 4> a, vec<uint8_t, 4> b, int32_t c) {
61  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
62  c;
63 }
64 
65 int32_t dot_acc(vec<uint8_t, 4> a, vec<int8_t, 4> b, int32_t c) {
66  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
67  c;
68 }
69 
70 int32_t dot_acc(vec<int8_t, 4> a, vec<uint8_t, 4> b, int32_t c) {
71  return a.s0() * b.s0() + a.s1() * b.s1() + a.s2() * b.s2() + a.s3() * b.s3() +
72  c;
73 }
74 
75 } // namespace oneapi
76 } // namespace ext
77 
78 } // namespace sycl
79 } // __SYCL_INLINE_NAMESPACE(cl)
cl::sycl::ext::oneapi::Uu::i
uint32_t i
Definition: dot_product.hpp:24
cl::sycl
Definition: access.hpp:14
sycl
Definition: invoke_simd.hpp:68
cl::sycl::ext::oneapi::Us
Definition: dot_product.hpp:18
cl::sycl::ext::oneapi::Us::i
int32_t i
Definition: dot_product.hpp:20
cl::sycl::ext::oneapi::dot_acc
int32_t dot_acc(vec< int8_t, 4 > a, vec< uint8_t, 4 > b, int32_t c)
Definition: dot_product.hpp:70
cl
We provide new interfaces for matrix multiply in this patch:
Definition: access.hpp:13
cl::sycl::image_channel_order::a
@ a
cl::sycl::vec
Provides a cross-platform vector class template that works efficiently on SYCL devices as well as in h...
Definition: aliases.hpp:18
cl::sycl::ext::oneapi::Uu
Definition: dot_product.hpp:22
__SYCL_INLINE_NAMESPACE
#define __SYCL_INLINE_NAMESPACE(X)
Definition: defines_elementary.hpp:12