XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
common.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (c) 2022-2023 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17#pragma once
18
19#include "common/common.hpp"
20#include "group/group.hpp"
21#include "subgroup/subgroup.hpp"
22
23namespace gpu::xetla {
24
25// enum
26
27enum class tune_key : uint8_t {
56};
57
58enum class tune_key_value : uint8_t {
66};
67
68// parameter optimizer
69
70enum class param_optimizer_tag : uint8_t { kernel, work_group };
71
72template <param_optimizer_tag tag_, typename dict_t_>
74
76 template <typename T, typename U>
78 static constexpr bool value = []() constexpr {
79 bool valid = true;
80 valid &= std::is_same<typename T::template find_elem_t<
82 typename U::template find_elem_t<
84 valid &= T::template find_elem_v<tune_key::
85 memory_layout_a> == U::template find_elem_v<tune_key::memory_layout_a>;
86 valid &= T::template find_elem_v<tune_key::
87 memory_alignment_a> == U::template find_elem_v<tune_key::memory_alignment_a>;
88 valid &= std::is_same<typename T::template find_elem_t<
90 typename U::template find_elem_t<
92 valid &= T::template find_elem_v<tune_key::
93 memory_layout_b> == U::template find_elem_v<tune_key::memory_layout_b>;
94 valid &= T::template find_elem_v<tune_key::
95 memory_alignment_b> == U::template find_elem_v<tune_key::memory_alignment_b>;
96 valid &= std::is_same<typename T::template find_elem_t<
98 typename U::template find_elem_t<
100 valid &= T::template find_elem_v<tune_key::
101 memory_layout_c> == U::template find_elem_v<tune_key::memory_layout_c>;
102 valid &= T::template find_elem_v<tune_key::
103 memory_alignment_c> == U::template find_elem_v<tune_key::memory_alignment_c>;
104 valid &= T::template find_elem_v<tune_key::
105 gpu_arch> == U::template find_elem_v<tune_key::gpu_arch>;
106 return valid;
107 }
108 ();
109 };
110};
111
112// parameter adaptor
113
114enum class param_adaptor_tag : uint8_t {
115 kernel,
118};
119
120template <param_adaptor_tag tag_, typename dict_t_>
122
123template <typename dict_t_>
125 using dtype_acc = typename dict_t_::template find_elem_t<
127 using wg_tile_shape = typename dict_t_::template find_elem_t<
129 static constexpr uint32_t wg_tile_n = wg_tile_shape::template dim<0>();
130 static constexpr uint32_t wg_tile_m = wg_tile_shape::template dim<1>();
131 static constexpr uint32_t wg_tile_k
132 = dict_t_::template find_elem_v<tune_key::wg_tile_k>;
133 using sg_tile_shape = typename dict_t_::template find_elem_t<
135 static constexpr uint32_t sg_tile_n = sg_tile_shape::template dim<0>();
136 static constexpr uint32_t sg_tile_m = sg_tile_shape::template dim<1>();
137 static constexpr uint32_t prefetch_distance
138 = dict_t_::template find_elem_v<tune_key::prefetch_distance>;
139 static constexpr uint32_t periodic_sync_interval
140 = dict_t_::template find_elem_v<tune_key::periodic_sync_interval>;
141 static constexpr auto mma_engine_tag
142 = dict_t_::template find_elem_v<tune_key::mma_engine>;
143 static constexpr auto gpu_arch_tag
144 = dict_t_::template find_elem_v<tune_key::gpu_arch>;
145
146 // Org the compute shape for sub-matrix
147 using tile_shape = group::tile_shape_t<wg_tile_n, // workgroup size in dim0
148 wg_tile_m, // workgroup size in dim1
149 sg_tile_n, // subgroup size in dim0
150 sg_tile_m>; // subgroup size in dim1
151};
152
153} // namespace gpu::xetla
154
Definition arch_config.hpp:24
param_adaptor_tag
Definition common.hpp:114
mma_engine
Definition common.hpp:225
tune_key
Definition common.hpp:27
gpu_arch
Definition common.hpp:73
tune_key_value
Definition common.hpp:58
param_optimizer_tag
Definition common.hpp:70
Workgroup level tile shape description.
Definition tile_shape.hpp:34
Definition common.hpp:124
typename dict_t_::template find_elem_t< tune_key::sg_tile_shape >::type sg_tile_shape
Definition common.hpp:134
static constexpr uint32_t periodic_sync_interval
Definition common.hpp:140
static constexpr uint32_t sg_tile_n
Definition common.hpp:135
typename dict_t_::template find_elem_t< tune_key::data_type_acc >::type dtype_acc
Definition common.hpp:126
static constexpr auto mma_engine_tag
Definition common.hpp:142
static constexpr auto gpu_arch_tag
Definition common.hpp:144
static constexpr uint32_t sg_tile_m
Definition common.hpp:136
static constexpr uint32_t wg_tile_n
Definition common.hpp:129
typename dict_t_::template find_elem_t< tune_key::wg_tile_shape >::type wg_tile_shape
Definition common.hpp:128
static constexpr uint32_t prefetch_distance
Definition common.hpp:138
static constexpr uint32_t wg_tile_m
Definition common.hpp:130
static constexpr uint32_t wg_tile_k
Definition common.hpp:132
Definition common.hpp:121
static constexpr bool value
Definition common.hpp:78
Definition common.hpp:75
Definition common.hpp:73