XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
selector_xe.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
22#include "group/gemm/api.hpp"
23#include "group/gemm/compute_policy.hpp"
24
25namespace gpu::xetla::group {
26namespace detail {
27
28template <typename dtype_a, typename dtype_b, int alignment_a, int alignment_b,
29 gpu_arch arch_tag>
31 using load_store_attr = typename arch_attr_t<
32 arch_tag>::template load_store_attr<msg_type::block_2d>;
33 static constexpr int alignment_bytes = load_store_attr::alignment_in_bytes;
34 static constexpr int alignment_bytes_a = alignment_a * sizeof(dtype_a);
35 static constexpr int alignment_bytes_b = alignment_b * sizeof(dtype_b);
36
37public:
38 static constexpr bool value = (alignment_bytes_a % alignment_bytes == 0)
39 && (alignment_bytes_b % alignment_bytes == 0);
40};
41
42} // namespace detail
43
46
48template <typename dtype_a, typename dtype_b, mem_layout mem_layout_a,
49 mem_layout mem_layout_b, mem_space mem_space_a, mem_space mem_space_b,
50 int alignment_a, int alignment_b, typename dtype_acc,
51 typename tile_shape, int k_stride, gpu_arch arch_tag, int stages,
52 int sync_freq>
53class gemm_selector_t<dtype_a, dtype_b, mem_layout_a, mem_layout_b, mem_space_a,
54 mem_space_b, alignment_a, alignment_b, dtype_acc, tile_shape, k_stride,
55 mma_engine::xmx, arch_tag, stages, sync_freq,
56 std::enable_if_t<detail::check_2d_block_pitch_alignment<dtype_a,
57 dtype_b, alignment_a, alignment_b, arch_tag>::value>> {
58 using mem_desc_a
60 using mem_desc_b
65 perf_tuning_knob, arch_tag>;
67
68public:
71};
72
74template <typename dtype_a, typename dtype_b, mem_layout mem_layout_a,
75 mem_layout mem_layout_b, mem_space mem_space_a, mem_space mem_space_b,
76 int alignment_a, int alignment_b, typename dtype_acc,
77 typename tile_shape, int k_stride, gpu_arch arch_tag, int stages,
78 int sync_freq>
79class gemm_selector_t<dtype_a, dtype_b, mem_layout_a, mem_layout_b, mem_space_a,
80 mem_space_b, alignment_a, alignment_b, dtype_acc, tile_shape, k_stride,
81 mma_engine::xmx, arch_tag, stages, sync_freq,
82 std::enable_if_t<!detail::check_2d_block_pitch_alignment<dtype_a,
83 dtype_b, alignment_a, alignment_b, arch_tag>::value>> {
84 using mem_desc_a
86 using mem_desc_b
91 perf_tuning_knob, arch_tag>;
93
94public:
97};
98
100template <typename dtype_a, typename dtype_b, mem_layout mem_layout_a,
101 mem_layout mem_layout_b, mem_space mem_space_a, mem_space mem_space_b,
102 int alignment_a, int alignment_b, typename dtype_acc,
103 typename tile_shape, int k_stride, gpu_arch arch_tag, int stages,
104 int sync_freq>
105class gemm_selector_t<dtype_a, dtype_b, mem_layout_a, mem_layout_b, mem_space_a,
106 mem_space_b, alignment_a, alignment_b, dtype_acc, tile_shape, k_stride,
107 mma_engine::fpu, arch_tag, stages, sync_freq,
108 std::enable_if_t<detail::check_2d_block_pitch_alignment<dtype_a,
109 dtype_b, alignment_a, alignment_b, arch_tag>::value>> {
110 static_assert(std::is_same<dtype_a, dtype_acc>::value
111 && std::is_same<dtype_b, dtype_acc>::value,
112 "When use gemm_selector, dtype_a and dtype_b in fpu based gemm"
113 "should be the same as dtype_acc");
114 using mem_desc_a
116 using mem_desc_b
121 perf_tuning_knob, arch_tag>;
123
124public:
127};
128
130} // namespace gpu::xetla::group
static constexpr bool value
Definition selector_xe.hpp:38
Gemm selection functor.
Definition api.hpp:75
Gemm functor.
Definition api.hpp:52
C++ API.
Definition limitation.hpp:607
mem_space
Definition common.hpp:77
mma_engine
Definition common.hpp:225
gpu_arch
Definition common.hpp:73
mem_layout
Definition common.hpp:76
Definition arch_config.hpp:72
Compute attribute for gemm.
Definition common.hpp:32
Compute policy for fpu engine.
Definition compute_policy.hpp:105
Compute policy for xmx engine.
Definition compute_policy.hpp:35
Compute policy for unaligned shape and xmx engine.
Definition compute_policy.hpp:70
Fine-tune knobs for gemm.
Definition common.hpp:43
Gemm default pre_processing functor.
Definition api.hpp:33
Definition memory_descriptor.hpp:139