27template <
typename matAcc_dst_t_,
typename matAcc_src_t_,
typename matB_t_,
28 typename matA_t_,
gpu_arch arch_tag_>
29struct tile_mma_t<matAcc_dst_t_, matAcc_src_t_, matB_t_, matA_t_,
31 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
43 static constexpr uint32_t a_tile_size_y = matA_t::tile_size_y;
44 static constexpr uint32_t a_tile_size_x = matA_t::tile_size_x;
45 static constexpr uint32_t a_tile_elems = matA_t::tile_elems;
46 static constexpr uint32_t a_block_size_y = matA_t::block_size_y;
47 static constexpr uint32_t a_block_size_x = matA_t::block_size_x;
48 static constexpr uint32_t a_block_elems = matA_t::block_elems;
50 static constexpr uint32_t b_tile_size_x = matB_t::tile_size_x;
51 static constexpr uint32_t b_tile_size_y = matB_t::tile_size_y;
52 static constexpr uint32_t b_tile_elems = matB_t::tile_elems;
53 static constexpr uint32_t b_block_size_x = matB_t::block_size_x;
54 static constexpr uint32_t b_block_size_y = matB_t::block_size_y;
55 static constexpr uint32_t b_block_elems = matB_t::block_elems;
57 static constexpr uint32_t tile_size_m = matDst_t::tile_size_y;
58 static constexpr uint32_t tile_size_k = a_tile_size_x;
59 static constexpr uint32_t tile_size_n = matDst_t::tile_size_x;
60 static constexpr uint32_t tile_elems = tile_size_m * tile_size_n;
61 static constexpr uint32_t block_size_n = matDst_t::block_size_x;
62 static constexpr uint32_t block_size_k = a_block_size_x;
63 static constexpr uint32_t block_size_m = matDst_t::block_size_y;
64 static constexpr uint32_t block_elems = block_size_m * block_size_n;
66 static_assert(tile_size_m == matA_t::tile_size_y,
67 "matAcc tile m should match with matA tile m");
68 static_assert(a_tile_size_x == b_tile_size_y,
69 "matA tile k should match with matB tile k");
70 static_assert(tile_size_n == matB_t::tile_size_x,
71 "matAcc tile n should match with matB tile n");
72 static_assert(block_size_m == a_block_size_y,
73 "matAcc block m should match with matA block m");
74 static_assert(block_size_n == b_block_size_x,
75 "matAcc block n should match with matB block n");
76 static_assert(a_block_size_x == b_block_size_y,
77 "matA block w should match with matB block h");
78 static_assert((tile_size_k % block_size_k) == 0,
79 "matAcc tile_size_k should be a multiple of block_size_k");
80 static_assert((block_size_k == 32 /
sizeof(
dtype_a)),
81 "DPAS depth only support the value of 32 / sizeof(dtype_a). "
82 "Currently we don't support the "
83 "splitting of block when call the DPAS");
85 static constexpr int32_t num_block_n = matDst_t::num_block_x;
86 static constexpr int32_t num_block_m = matDst_t::num_block_y;
87 static constexpr int32_t num_block_k = tile_size_k / block_size_k;
89 static constexpr int32_t mma_m = mma_attr::mma_m_in_elem;
90 static constexpr int32_t mma_k
91 = mma_attr::mma_k_in_bytes /
sizeof(uint32_t);
92 static_assert(tile_size_m % mma_m == 0,
93 "tile_size_m shoud be a multiple of mma_m");
97 constexpr int32_t a_mma_elems = mma_m * a_block_size_x;
98 constexpr int32_t c_mma_elems = mma_m * block_size_n;
100 for (uint32_t j = 0; j < num_block_n; j++) {
102 for (uint32_t i = 0; i < tile_size_m / block_size_m; i++) {
103 auto src_block = src.reg.xetla_select<block_elems, 1>(
104 (i * num_block_n + j) * block_elems);
105 auto dst_block = dst.reg.xetla_select<block_elems, 1>(
106 (i * num_block_n + j) * block_elems);
108 for (uint32_t mma_i = 0; mma_i < block_size_m / mma_m;
110 auto src_sub_blk = src_block.xetla_select<c_mma_elems, 1>(
111 mma_i * c_mma_elems);
112 auto dst_sub_blk = dst_block.xetla_select<c_mma_elems, 1>(
113 mma_i * c_mma_elems);
115 auto a_block = a.reg.xetla_select<a_block_elems, 1>(
116 (i * num_block_k) * a_block_elems);
117 auto a_sub_blk = a_block.xetla_select<a_mma_elems, 1>(
118 mma_i * a_mma_elems);
119 auto b_sub_blk = b.reg.xetla_select<b_block_elems, 1>(
126 mma_k, mma_m,
dtype_src, uint32_t, uint32_t,
129 / (
sizeof(uint32_t) /
sizeof(
dtype_b)),
131 / (
sizeof(uint32_t) /
sizeof(
dtype_a))>(
132 src_sub_blk, b_sub_blk.xetla_format<uint32_t>(),
133 a_sub_blk.xetla_format<uint32_t>());
137 for (uint32_t k = 1; k < num_block_k; k++) {
138 auto a_block = a.reg.xetla_select<a_block_elems, 1>(
139 (i * num_block_k + k) * a_block_elems);
140 auto a_sub_blk = a_block.xetla_select<a_mma_elems, 1>(
141 mma_i * a_mma_elems);
142 auto b_sub_blk = b.reg.xetla_select<b_block_elems, 1>(
143 (j + k * num_block_n) * b_block_elems);
149 mma_k, mma_m,
dtype_src, uint32_t, uint32_t,
152 / (
sizeof(uint32_t) /
sizeof(
dtype_b)),
154 / (
sizeof(uint32_t) /
sizeof(
dtype_a))>(
155 dst_sub_blk, b_sub_blk.xetla_format<uint32_t>(),
156 a_sub_blk.xetla_format<uint32_t>());
160 if constexpr ((tile_size_m % block_size_m) != 0) {
161 constexpr uint32_t tail_block_size_m
162 = tile_size_m % block_size_m;
163 constexpr uint32_t tail_block_elems
164 = block_size_n * tail_block_size_m;
165 constexpr uint32_t a_tail_block_elems
166 = tail_block_size_m * a_block_size_x;
167 constexpr uint32_t tail_m_start
168 = tile_size_m / block_size_m * block_size_m;
169 constexpr uint32_t tail_elems_start
170 = tail_m_start * tile_size_n;
171 constexpr uint32_t a_tail_elems_start
172 = tail_m_start * a_tile_size_x;
173 auto src_block = src.reg.xetla_select<tail_block_elems, 1>(
174 tail_elems_start + j * tail_block_elems);
175 auto dst_block = dst.reg.xetla_select<tail_block_elems, 1>(
176 tail_elems_start + j * tail_block_elems);
178 for (uint32_t mma_i = 0; mma_i < tail_block_size_m / mma_m;
180 auto src_sub_blk = src_block.xetla_select<c_mma_elems, 1>(
181 mma_i * c_mma_elems);
182 auto dst_sub_blk = dst_block.xetla_select<c_mma_elems, 1>(
183 mma_i * c_mma_elems);
186 = a.reg.xetla_select<a_tail_block_elems, 1>(
188 auto a_sub_blk = a_block.xetla_select<a_mma_elems, 1>(
189 mma_i * a_mma_elems);
190 auto b_sub_blk = b.reg.xetla_select<b_block_elems, 1>(
197 mma_k, mma_m,
dtype_src, uint32_t, uint32_t,
200 / (
sizeof(uint32_t) /
sizeof(
dtype_b)),
202 / (
sizeof(uint32_t) /
sizeof(
dtype_a))>(
203 src_sub_blk, b_sub_blk.xetla_format<uint32_t>(),
204 a_sub_blk.xetla_format<uint32_t>());
207 for (uint32_t k = 1; k < num_block_k; k++) {
209 = a.reg.xetla_select<a_tail_block_elems, 1>(
211 + k * a_tail_block_elems);
212 auto a_sub_blk = a_block.xetla_select<a_mma_elems, 1>(
213 mma_i * a_mma_elems);
214 auto b_sub_blk = b.reg.xetla_select<b_block_elems, 1>(
215 (j + k * num_block_n) * b_block_elems);
221 mma_k, mma_m,
dtype_src, uint32_t, uint32_t,
224 / (
sizeof(uint32_t) /
sizeof(
dtype_b)),
226 / (
sizeof(uint32_t) /
sizeof(
dtype_a))>(
227 dst_sub_blk, b_sub_blk.xetla_format<uint32_t>(),
228 a_sub_blk.xetla_format<uint32_t>());
233 if constexpr (num_block_k > 1) {
234 xetla_wait(dst.reg.xetla_format<uint16_t>()[0]);
#define __XETLA_API
Definition common.hpp:43
__XETLA_API xetla_vector< T, N > xetla_mma(xetla_vector< T, N > src0, xetla_vector< T1, N1 > src1, xetla_vector< T2, N2 > src2, Sat sat={})
Description of xetla_mma: performs a matrix multiply-add operation.
Definition math_mma.hpp:144
constexpr gpu::xetla::argument_type mma_argument_type()
Converts a normal data type to its DPAS argument type.
Definition math_mma.hpp:35
Definition limitation.hpp:457
mma_engine
Definition common.hpp:225
gpu_arch
Definition common.hpp:73
void xetla_wait(uint16_t val)
Definition common.hpp:229
Definition arch_config.hpp:72
matAcc_dst_t_ matDst_t
Definition mma_xe.hpp:35
matA_t_ matA_t
Definition mma_xe.hpp:32
matAcc_src_t_ matSrc_t
Definition mma_xe.hpp:34
typename matB_t::dtype dtype_b
Definition mma_xe.hpp:37
typename matA_t::dtype dtype_a
Definition mma_xe.hpp:36
typename arch_attr_t< arch_tag_ >::mma_attr mma_attr
Definition mma_xe.hpp:41
typename matSrc_t::dtype dtype_src
Definition mma_xe.hpp:38
matB_t_ matB_t
Definition mma_xe.hpp:33
typename matDst_t::dtype dtype_dst
Definition mma_xe.hpp:39
static __XETLA_API void mma(matDst_t &dst, matSrc_t &src, matB_t &b, matA_t &a)
Definition mma_xe.hpp:95
The definition of the xetla tile mma operation API.
Definition api.hpp:36