template <typename matAcc_dst_t_, typename matAcc_src_t_, typename matB_t_,
        typename matA_t_, gpu_arch arch_tag_>
struct tile_mma_t<matAcc_dst_t_, matAcc_src_t_, matB_t_, matA_t_,
        mma_engine::fpu, arch_tag_,
        std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
    using matA_t = matA_t_;
    using matB_t = matB_t_;
    using matSrc_t = matAcc_src_t_;
    using matDst_t = matAcc_dst_t_;
    using dtype_a = typename matA_t::dtype;
    using dtype_b = typename matB_t::dtype;
    using dtype_src = typename matSrc_t::dtype;
    using dtype_dst = typename matDst_t::dtype;

    using register_attr =
            typename arch_attr_t<arch_tag_>::template register_attr<>;
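    // The FPU path forms each accumulator row as scalar(A[m, k]) *
    // row(B[k, :]) FMAs, so matA must be stored col-major in registers
    // (m contiguous within a k column) and matB row-major.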
    static_assert(matA_t::reg_transpose,
            "For FMAOp GEMM, the register layout of matA should be col-major");
    static_assert(!matB_t::reg_transpose,
            "For FMAOp GEMM, the register layout of matB should be row-major");
    static constexpr uint32_t a_tile_size_y = matA_t::tile_size_y;
    static constexpr uint32_t a_tile_size_x = matA_t::tile_size_x;
    static constexpr uint32_t a_tile_elems = matA_t::tile_elems;
    static constexpr uint32_t a_block_size_w = matA_t::block_size_y;
    static constexpr uint32_t a_block_size_h = matA_t::block_size_x;
    static constexpr uint32_t a_block_elems = matA_t::block_elems;

    static constexpr uint32_t b_tile_size_x = matB_t::tile_size_x;
    static constexpr uint32_t b_tile_size_y = matB_t::tile_size_y;
    static constexpr uint32_t b_tile_elems = matB_t::tile_elems;
    static constexpr uint32_t b_block_size_x = matB_t::block_size_x;
    static constexpr uint32_t b_block_size_y = matB_t::block_size_y;
    static constexpr uint32_t b_block_elems = matB_t::block_elems;
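    // GEMM dimensions: m and n are taken from the accumulator tile; k is
    // taken from matA's x extent (matA is register-transposed, so x walks k).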
    static constexpr uint32_t tile_size_m = matDst_t::tile_size_y;
    static constexpr uint32_t tile_size_k = a_tile_size_x;
    static constexpr uint32_t tile_size_n = matDst_t::tile_size_x;
    static constexpr uint32_t tile_elems = tile_size_m * tile_size_n;
    static constexpr uint32_t block_size_n = matDst_t::block_size_x;
    static constexpr uint32_t block_size_k = a_block_size_h;
    static constexpr uint32_t block_size_m = matDst_t::block_size_y;
    static constexpr uint32_t block_elems = block_size_m * block_size_n;
    static_assert(tile_size_m == matA_t::tile_size_y,
            "matAcc tile m should match with matA tile m");
    static_assert(a_tile_size_x == b_tile_size_y,
            "matA tile k should match with matB tile k");
    static_assert(tile_size_n == matB_t::tile_size_x,
            "matAcc tile n should match with matB tile n");
    static_assert(block_size_m == a_block_size_w,
            "matAcc block m should match with matA block m");
    static_assert(block_size_n == b_block_size_x,
            "matAcc block n should match with matB block n");
    static_assert((tile_size_k % block_size_k) == 0,
            "matAcc tile_size_k should be a multiple of block_size_k");
    static constexpr int32_t num_block_n = matDst_t::num_block_x;
    static constexpr int32_t num_block_m = matDst_t::num_block_y;
    static constexpr int32_t num_block_k = tile_size_k / block_size_k;

    static constexpr int32_t mma_m = register_attr::acc_reg_in_bytes
            / (block_size_n * sizeof(dtype_dst));
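    // mma_core computes one accumulator block, dst = a_block x b_block + src,
    // in slices of mma_m rows so that each slice's dtype_dst accumulator fits
    // the architecture's accumulator register budget.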
    template <int blk_m, int blk_n, int blk_k>
    static __XETLA_API void mma_core(
            xetla_vector_ref<dtype_dst, blk_m * blk_n> __REF__ dst,
            xetla_vector_ref<dtype_src, blk_m * blk_n> __REF__ src,
            xetla_vector_ref<dtype_b, blk_k * blk_n> __REF__ b_block,
            xetla_vector_ref<dtype_a, blk_m * blk_k> __REF__ a_block) {
        auto dst_blk_2d = dst.xetla_format<dtype_dst, blk_m, blk_n>();
        auto b_blk_2d = b_block.xetla_format<dtype_dst, blk_k, blk_n>();
        auto src_blk_2d = src.xetla_format<dtype_src, blk_m, blk_n>();
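        // Three-phase k traversal: k = 0 seeds dst_tmp from src, the middle
        // ks accumulate into dst_tmp, and the last k writes through to dst.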
        for (uint32_t i = 0; i < blk_m / mma_m; i++) {
            xetla_vector<dtype_dst, mma_m * blk_n> dst_tmp;
            auto dst_tmp_2d = dst_tmp.xetla_format<dtype_dst, mma_m, blk_n>();
            for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) {
                dst_tmp_2d.row(i_acc)
                        = a_block[i_acc + i * mma_m] * b_blk_2d.row(0)
                        + src_blk_2d.row(i_acc + i * mma_m);
            }
            for (uint32_t k = 1; k < blk_k - 1; k++) {
                for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) {
                    int a_offset = k * blk_m + i_acc + i * mma_m;
                    dst_tmp_2d.row(i_acc)
                            += a_block[a_offset] * b_blk_2d.row(k);
                }
            }
            for (uint32_t i_acc = 0; i_acc < mma_m; i_acc++) {
                int a_offset = (blk_k - 1) * blk_m + i_acc + i * mma_m;
                dst_blk_2d.row(i_acc + i * mma_m)
                        = a_block[a_offset] * b_blk_2d.row(blk_k - 1)
                        + dst_tmp_2d.row(i_acc);
            }
        }
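        // Handle the remaining blk_m % mma_m rows with the same three-phase
        // scheme.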
        if constexpr ((blk_m % mma_m) != 0) {
            constexpr uint32_t tail_start_m = blk_m / mma_m * mma_m;
            constexpr uint32_t tail_m = blk_m % mma_m;
            xetla_vector<dtype_dst, tail_m * blk_n> dst_tmp;
            auto dst_tmp_2d = dst_tmp.xetla_format<dtype_dst, tail_m, blk_n>();
            for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) {
                dst_tmp_2d.row(i_acc)
                        = a_block[i_acc + tail_start_m] * b_blk_2d.row(0)
                        + src_blk_2d.row(i_acc + tail_start_m);
            }
            for (uint32_t k = 1; k < blk_k - 1; k++) {
                for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) {
                    int a_offset = k * blk_m + i_acc + tail_start_m;
                    dst_tmp_2d.row(i_acc)
                            += a_block[a_offset] * b_blk_2d.row(k);
                }
            }
            for (uint32_t i_acc = 0; i_acc < tail_m; i_acc++) {
                int a_offset = (blk_k - 1) * blk_m + i_acc + tail_start_m;
                dst_blk_2d.row(i_acc + tail_start_m)
                        = a_block[a_offset] * b_blk_2d.row(blk_k - 1)
                        + dst_tmp_2d.row(i_acc);
            }
        }
    }
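    // mma computes dst = a x b + src for the whole tile, one k slice at a
    // time. The first slice (k_i == 0) reads src as the accumulator input;
    // every later slice accumulates in place into dst.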
    static __XETLA_API void mma(
            matDst_t &dst, matSrc_t &src, matB_t &b, matA_t &a) {
        { // k_i = 0
            auto b_reg
                    = b.reg.xetla_select<b_block_size_y * b_tile_size_x, 1>(0);
            for (uint32_t i = 0; i < tile_size_m / block_size_m; i++) {
                auto a_block = a.reg.xetla_select<a_block_elems, 1>(
                        i * num_block_k * a_block_elems);
                for (uint32_t j = 0; j < num_block_n; j++) {
                    auto b_block = b_reg.xetla_select<b_block_elems, 1>(
                            j * b_block_elems);
                    auto src_block = src.reg.xetla_select<block_elems, 1>(
                            (i * num_block_n + j) * block_elems);
                    auto dst_block = dst.reg.xetla_select<block_elems, 1>(
                            (i * num_block_n + j) * block_elems);
                    mma_core<block_size_m, block_size_n, block_size_k>(
                            dst_block, src_block, b_block, a_block);
                }
            }
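            // m tail of the first k slice: the last tile_size_m %
            // block_size_m rows form a shorter accumulator block.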
            if constexpr ((tile_size_m % block_size_m) != 0) {
                constexpr uint32_t tail_start_m
                        = tile_size_m / block_size_m * block_size_m;
                constexpr uint32_t a_tail_blk_w = a_tile_size_y - tail_start_m;
                constexpr uint32_t a_tail_blk_elems
                        = a_block_size_h * a_tail_blk_w;
                constexpr uint32_t tail_size_m = tile_size_m - tail_start_m;
                constexpr uint32_t acc_tail_blk_elems
                        = tail_size_m * block_size_n;
                auto a_block = a.reg.xetla_select<a_tail_blk_elems, 1>(
                        a_tile_size_x * tail_start_m);
                for (uint32_t j = 0; j < num_block_n; j++) {
                    auto b_block = b_reg.xetla_select<b_block_elems, 1>(
                            j * b_block_elems);
                    auto src_block
                            = src.reg.xetla_select<acc_tail_blk_elems, 1>(
                                    (tail_start_m * tile_size_n)
                                    + j * acc_tail_blk_elems);
                    auto dst_block
                            = dst.reg.xetla_select<acc_tail_blk_elems, 1>(
                                    (tail_start_m * tile_size_n)
                                    + j * acc_tail_blk_elems);
                    mma_core<tail_size_m, block_size_n, block_size_k>(
                            dst_block, src_block, b_block, a_block);
                }
            }
        }
        // Software scheduling barrier between k slices, for better code
        // control.
        SW_BARRIER();
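        // Remaining k slices: src and dst alias the same block, so mma_core
        // accumulates in place.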
        for (uint32_t k_i = 1; k_i < num_block_k; k_i++) {
            auto b_reg = b.reg.xetla_select<b_block_size_y * b_tile_size_x, 1>(
                    k_i * b_block_size_y * b_tile_size_x);
            for (uint32_t i = 0; i < tile_size_m / block_size_m; i++) {
                auto a_block = a.reg.xetla_select<a_block_elems, 1>(
                        (i * num_block_k + k_i) * a_block_elems);
                for (uint32_t j = 0; j < num_block_n; j++) {
                    auto b_block = b_reg.xetla_select<b_block_elems, 1>(
                            j * b_block_elems);
                    auto dst_block = dst.reg.xetla_select<block_elems, 1>(
                            (i * num_block_n + j) * block_elems);
                    mma_core<block_size_m, block_size_n, block_size_k>(
                            dst_block, dst_block, b_block, a_block);
                }
            }
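            // m tail for this k slice, again accumulating in place.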
            if constexpr ((tile_size_m % block_size_m) != 0) {
                constexpr uint32_t tail_start_m
                        = tile_size_m / block_size_m * block_size_m;
                constexpr uint32_t a_tail_blk_w = a_tile_size_y - tail_start_m;
                constexpr uint32_t a_tail_blk_elems
                        = a_block_size_h * a_tail_blk_w;
                constexpr uint32_t tail_size_m = tile_size_m - tail_start_m;
                constexpr uint32_t acc_tail_blk_elems
                        = tail_size_m * block_size_n;
                auto a_block = a.reg.xetla_select<a_tail_blk_elems, 1>(
                        a_tile_size_x * tail_start_m + k_i * a_tail_blk_elems);
                for (uint32_t j = 0; j < num_block_n; j++) {
                    auto b_block = b_reg.xetla_select<b_block_elems, 1>(
                            j * b_block_elems);
                    auto dst_block
                            = dst.reg.xetla_select<acc_tail_blk_elems, 1>(
                                    (tail_start_m * tile_size_n)
                                    + j * acc_tail_blk_elems);
                    mma_core<tail_size_m, block_size_n, block_size_k>(
                            dst_block, dst_block, b_block, a_block);
                }
            }
        }
    }
};