XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
mma_xe.hpp
Go to the documentation of this file.
/*******************************************************************************
* Copyright (c) 2022-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
19
20#pragma once
21
22#include "subgroup/tile/api.hpp"
23
24namespace gpu::xetla::subgroup {
25
27template <typename matAcc_dst_t_, typename matAcc_src_t_, typename matB_t_,
28 typename matA_t_, gpu_arch arch_tag_>
29struct tile_mma_t<matAcc_dst_t_, matAcc_src_t_, matB_t_, matA_t_,
30 mma_engine::xmx, arch_tag_,
31 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
32 using matA_t = matA_t_;
33 using matB_t = matB_t_;
34 using matSrc_t = matAcc_src_t_;
35 using matDst_t = matAcc_dst_t_;
36 using dtype_a = typename matA_t::dtype;
37 using dtype_b = typename matB_t::dtype;
38 using dtype_src = typename matSrc_t::dtype;
39 using dtype_dst = typename matDst_t::dtype;
40
42
43 static constexpr uint32_t a_tile_size_y = matA_t::tile_size_y;
44 static constexpr uint32_t a_tile_size_x = matA_t::tile_size_x;
45 static constexpr uint32_t a_tile_elems = matA_t::tile_elems;
46 static constexpr uint32_t a_block_size_y = matA_t::block_size_y;
47 static constexpr uint32_t a_block_size_x = matA_t::block_size_x;
48 static constexpr uint32_t a_block_elems = matA_t::block_elems;
49
50 static constexpr uint32_t b_tile_size_x = matB_t::tile_size_x;
51 static constexpr uint32_t b_tile_size_y = matB_t::tile_size_y;
52 static constexpr uint32_t b_tile_elems = matB_t::tile_elems;
53 static constexpr uint32_t b_block_size_x = matB_t::block_size_x;
54 static constexpr uint32_t b_block_size_y = matB_t::block_size_y;
55 static constexpr uint32_t b_block_elems = matB_t::block_elems;
56
57 static constexpr uint32_t tile_size_m = matDst_t::tile_size_y;
58 static constexpr uint32_t tile_size_k = a_tile_size_x;
59 static constexpr uint32_t tile_size_n = matDst_t::tile_size_x;
60 static constexpr uint32_t tile_elems = tile_size_m * tile_size_n;
61 static constexpr uint32_t block_size_n = matDst_t::block_size_x;
62 static constexpr uint32_t block_size_k = a_block_size_x;
63 static constexpr uint32_t block_size_m = matDst_t::block_size_y;
64 static constexpr uint32_t block_elems = block_size_m * block_size_n;
65
66 static_assert(tile_size_m == matA_t::tile_size_y,
67 "matAcc tile m should match with matA tile m");
68 static_assert(a_tile_size_x == b_tile_size_y,
69 "matA tile k should match with matB tile k");
70 static_assert(tile_size_n == matB_t::tile_size_x,
71 "matAcc tile n should match with matB tile n");
72 static_assert(block_size_m == a_block_size_y,
73 "matAcc block m should match with matA block m");
74 static_assert(block_size_n == b_block_size_x,
75 "matAcc block n should match with matB block n");
76 static_assert(a_block_size_x == b_block_size_y,
77 "matA block w should match with matB block h");
78 static_assert((tile_size_k % block_size_k) == 0,
79 "matAcc tile_size_k should be a multiple of block_size_k");
80 static_assert((block_size_k == 32 / sizeof(dtype_a)),
81 "DPAS depth only support the value of 32 / sizeof(dtype_a). "
82 "Currently we don't support the "
83 "splitting of block when call the DPAS");
84
85 static constexpr int32_t num_block_n = matDst_t::num_block_x;
86 static constexpr int32_t num_block_m = matDst_t::num_block_y;
87 static constexpr int32_t num_block_k = tile_size_k / block_size_k;
88
89 static constexpr int32_t mma_m = mma_attr::mma_m_in_elem;
90 static constexpr int32_t mma_k
91 = mma_attr::mma_k_in_bytes / sizeof(uint32_t);
92 static_assert(tile_size_m % mma_m == 0,
93 "tile_size_m shoud be a multiple of mma_m");
94
95 __XETLA_API static void mma(
96 matDst_t &dst, matSrc_t &src, matB_t &b, matA_t &a) {
97 constexpr int32_t a_mma_elems = mma_m * a_block_size_x;
98 constexpr int32_t c_mma_elems = mma_m * block_size_n;
99#pragma unroll
100 for (uint32_t j = 0; j < num_block_n; j++) {
101#pragma unroll
102 for (uint32_t i = 0; i < tile_size_m / block_size_m; i++) {
103 auto src_block = src.reg.xetla_select<block_elems, 1>(
104 (i * num_block_n + j) * block_elems);
105 auto dst_block = dst.reg.xetla_select<block_elems, 1>(
106 (i * num_block_n + j) * block_elems);
107#pragma unroll
108 for (uint32_t mma_i = 0; mma_i < block_size_m / mma_m;
109 mma_i++) {
110 auto src_sub_blk = src_block.xetla_select<c_mma_elems, 1>(
111 mma_i * c_mma_elems);
112 auto dst_sub_blk = dst_block.xetla_select<c_mma_elems, 1>(
113 mma_i * c_mma_elems);
114 { //k=0
115 auto a_block = a.reg.xetla_select<a_block_elems, 1>(
116 (i * num_block_k) * a_block_elems);
117 auto a_sub_blk = a_block.xetla_select<a_mma_elems, 1>(
118 mma_i * a_mma_elems);
119 auto b_sub_blk = b.reg.xetla_select<b_block_elems, 1>(
120 j * b_block_elems);
121 dst_sub_blk = xetla_mma<
123 dtype_b>(),
125 dtype_a>(),
126 mma_k, mma_m, dtype_src, uint32_t, uint32_t,
127 c_mma_elems,
128 b_block_elems
129 / (sizeof(uint32_t) / sizeof(dtype_b)),
130 a_mma_elems
131 / (sizeof(uint32_t) / sizeof(dtype_a))>(
132 src_sub_blk, b_sub_blk.xetla_format<uint32_t>(),
133 a_sub_blk.xetla_format<uint32_t>());
134 }
135
136#pragma unroll
137 for (uint32_t k = 1; k < num_block_k; k++) {
138 auto a_block = a.reg.xetla_select<a_block_elems, 1>(
139 (i * num_block_k + k) * a_block_elems);
140 auto a_sub_blk = a_block.xetla_select<a_mma_elems, 1>(
141 mma_i * a_mma_elems);
142 auto b_sub_blk = b.reg.xetla_select<b_block_elems, 1>(
143 (j + k * num_block_n) * b_block_elems);
144 dst_sub_blk = xetla_mma<
146 dtype_b>(),
148 dtype_a>(),
149 mma_k, mma_m, dtype_src, uint32_t, uint32_t,
150 c_mma_elems,
151 b_block_elems
152 / (sizeof(uint32_t) / sizeof(dtype_b)),
153 a_mma_elems
154 / (sizeof(uint32_t) / sizeof(dtype_a))>(
155 dst_sub_blk, b_sub_blk.xetla_format<uint32_t>(),
156 a_sub_blk.xetla_format<uint32_t>());
157 }
158 }
159 }
160 if constexpr ((tile_size_m % block_size_m) != 0) {
161 constexpr uint32_t tail_block_size_m
162 = tile_size_m % block_size_m;
163 constexpr uint32_t tail_block_elems
164 = block_size_n * tail_block_size_m;
165 constexpr uint32_t a_tail_block_elems
166 = tail_block_size_m * a_block_size_x;
167 constexpr uint32_t tail_m_start
168 = tile_size_m / block_size_m * block_size_m;
169 constexpr uint32_t tail_elems_start
170 = tail_m_start * tile_size_n;
171 constexpr uint32_t a_tail_elems_start
172 = tail_m_start * a_tile_size_x;
173 auto src_block = src.reg.xetla_select<tail_block_elems, 1>(
174 tail_elems_start + j * tail_block_elems);
175 auto dst_block = dst.reg.xetla_select<tail_block_elems, 1>(
176 tail_elems_start + j * tail_block_elems);
177#pragma unroll
178 for (uint32_t mma_i = 0; mma_i < tail_block_size_m / mma_m;
179 mma_i++) {
180 auto src_sub_blk = src_block.xetla_select<c_mma_elems, 1>(
181 mma_i * c_mma_elems);
182 auto dst_sub_blk = dst_block.xetla_select<c_mma_elems, 1>(
183 mma_i * c_mma_elems);
184 { //k=0
185 auto a_block
186 = a.reg.xetla_select<a_tail_block_elems, 1>(
187 a_tail_elems_start);
188 auto a_sub_blk = a_block.xetla_select<a_mma_elems, 1>(
189 mma_i * a_mma_elems);
190 auto b_sub_blk = b.reg.xetla_select<b_block_elems, 1>(
191 j * b_block_elems);
192 dst_sub_blk = xetla_mma<
194 dtype_b>(),
196 dtype_a>(),
197 mma_k, mma_m, dtype_src, uint32_t, uint32_t,
198 c_mma_elems,
199 b_block_elems
200 / (sizeof(uint32_t) / sizeof(dtype_b)),
201 a_mma_elems
202 / (sizeof(uint32_t) / sizeof(dtype_a))>(
203 src_sub_blk, b_sub_blk.xetla_format<uint32_t>(),
204 a_sub_blk.xetla_format<uint32_t>());
205 }
206#pragma unroll
207 for (uint32_t k = 1; k < num_block_k; k++) {
208 auto a_block
209 = a.reg.xetla_select<a_tail_block_elems, 1>(
210 a_tail_elems_start
211 + k * a_tail_block_elems);
212 auto a_sub_blk = a_block.xetla_select<a_mma_elems, 1>(
213 mma_i * a_mma_elems);
214 auto b_sub_blk = b.reg.xetla_select<b_block_elems, 1>(
215 (j + k * num_block_n) * b_block_elems);
216 dst_sub_blk = xetla_mma<
218 dtype_b>(),
220 dtype_a>(),
221 mma_k, mma_m, dtype_src, uint32_t, uint32_t,
222 c_mma_elems,
223 b_block_elems
224 / (sizeof(uint32_t) / sizeof(dtype_b)),
225 a_mma_elems
226 / (sizeof(uint32_t) / sizeof(dtype_a))>(
227 dst_sub_blk, b_sub_blk.xetla_format<uint32_t>(),
228 a_sub_blk.xetla_format<uint32_t>());
229 }
230 }
231 }
232 }
233 if constexpr (num_block_k > 1) {
234 xetla_wait(dst.reg.xetla_format<uint16_t>()[0]);
235 }
236 }
237};
238
239} // namespace gpu::xetla::subgroup
#define __XETLA_API
Definition common.hpp:43
__XETLA_API xetla_vector< T, N > xetla_mma(xetla_vector< T, N > src0, xetla_vector< T1, N1 > src1, xetla_vector< T2, N2 > src2, Sat sat={})
Description of xetla_mma: performs a matrix multiply-add operation.
Definition math_mma.hpp:144
constexpr gpu::xetla::argument_type mma_argument_type()
Converts a normal data type to the corresponding DPAS argument type.
Definition math_mma.hpp:35
Definition limitation.hpp:457
mma_engine
Definition common.hpp:225
gpu_arch
Definition common.hpp:73
void xetla_wait(uint16_t val)
Definition common.hpp:229
Definition arch_config.hpp:72
Is the xetla tile mma operation definition API.
Definition api.hpp:36
C++ API.