DPC++ Runtime
Runtime libraries for oneAPI DPC++
query-types.hpp
Go to the documentation of this file.
1 //==--------------- query-types.hpp - SYCL matrix --------------*- C++ -*---==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 // ===--------------------------------------------------------------------=== //
8 
9 #pragma once
10 
11 #include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
13 
14 namespace sycl {
15 inline namespace _V1 {
16 namespace ext::oneapi::experimental::matrix {
17 
/// Element types that a joint_matrix combination can be described with.
/// The detail::convertTypeToMatrixTypeString helpers in this header spell
/// each enumerator as "matrix_type::<name>".
enum class matrix_type {
  // Floating-point formats.
  bf16, ///< sycl::ext::oneapi::bfloat16
  fp16, ///< sycl::half
  tf32, ///< TF32 precision (NOTE(review): the mapped C++ tag type is not fully visible here)
  fp32, ///< float
  fp64, ///< double
  // Signed fixed-width integers.
  sint8,  ///< int8_t
  sint16, ///< int16_t
  sint32, ///< int32_t
  sint64, ///< int64_t
  // Unsigned fixed-width integers.
  uint8,  ///< uint8_t
  uint16, ///< uint16_t
  uint32, ///< uint32_t
  uint64  ///< uint64_t
};
33 
34 struct combination {
35  size_t max_msize;
36  size_t max_nsize;
37  size_t max_ksize;
38  size_t msize;
39  size_t nsize;
40  size_t ksize;
45 };
46 
47 } // namespace ext::oneapi::experimental::matrix
48 
49 // Type-to-matrix_type string conversion, evaluated at compile time
50 namespace detail {
/// Maps a C++ element type to the spelling of its matrix_type enumerator
/// (e.g. "matrix_type::fp32"). This primary template is the fallback for any
/// type without an explicit specialization and yields an empty string.
/// NOTE(review): presumably callers treat "" as "no matching matrix_type";
/// confirm at the call sites.
template <typename ElemT>
constexpr auto convertTypeToMatrixTypeString() -> const char * {
  return "";
}
54 template <>
55 constexpr const char *
56 convertTypeToMatrixTypeString<sycl::ext::oneapi::bfloat16>() {
57  return "matrix_type::bf16";
58 }
59 template <> constexpr const char *convertTypeToMatrixTypeString<sycl::half>() {
60  return "matrix_type::fp16";
61 }
62 template <>
63 constexpr const char *convertTypeToMatrixTypeString<
65  return "matrix_type::tf32";
66 }
67 template <> constexpr const char *convertTypeToMatrixTypeString<float>() {
68  return "matrix_type::fp32";
69 }
70 template <> constexpr const char *convertTypeToMatrixTypeString<double>() {
71  return "matrix_type::fp64";
72 }
73 template <> constexpr const char *convertTypeToMatrixTypeString<int8_t>() {
74  return "matrix_type::sint8";
75 }
76 template <> constexpr const char *convertTypeToMatrixTypeString<int16_t>() {
77  return "matrix_type::sint16";
78 }
79 template <> constexpr const char *convertTypeToMatrixTypeString<int32_t>() {
80  return "matrix_type::sint32";
81 }
82 template <> constexpr const char *convertTypeToMatrixTypeString<int64_t>() {
83  return "matrix_type::sint64";
84 }
85 template <> constexpr const char *convertTypeToMatrixTypeString<uint8_t>() {
86  return "matrix_type::uint8";
87 }
88 template <> constexpr const char *convertTypeToMatrixTypeString<uint16_t>() {
89  return "matrix_type::uint16";
90 }
91 template <> constexpr const char *convertTypeToMatrixTypeString<uint32_t>() {
92  return "matrix_type::uint32";
93 }
94 template <> constexpr const char *convertTypeToMatrixTypeString<uint64_t>() {
95  return "matrix_type::uint64";
96 }
97 } // namespace detail
98 } // namespace _V1
99 } // namespace sycl
constexpr const char * convertTypeToMatrixTypeString< int16_t >()
Definition: query-types.hpp:76
constexpr const char * convertTypeToMatrixTypeString< int64_t >()
Definition: query-types.hpp:82
constexpr const char * convertTypeToMatrixTypeString< int32_t >()
Definition: query-types.hpp:79
constexpr const char * convertTypeToMatrixTypeString< uint64_t >()
Definition: query-types.hpp:94
constexpr const char * convertTypeToMatrixTypeString< uint16_t >()
Definition: query-types.hpp:88
constexpr const char * convertTypeToMatrixTypeString< uint8_t >()
Definition: query-types.hpp:85
constexpr const char * convertTypeToMatrixTypeString< float >()
Definition: query-types.hpp:67
constexpr const char * convertTypeToMatrixTypeString< double >()
Definition: query-types.hpp:70
constexpr const char * convertTypeToMatrixTypeString< int8_t >()
Definition: query-types.hpp:73
constexpr const char * convertTypeToMatrixTypeString< uint32_t >()
Definition: query-types.hpp:91
constexpr const char * convertTypeToMatrixTypeString()
Definition: query-types.hpp:51
Definition: access.hpp:18