22#include "subgroup/subgroup.hpp"
27 template <
typename dtype,
int vec_len>
30 return vec_data - data;
35 template <
typename dtype,
int vec_len>
38 return vec_data / data;
42template <
typename op,
typename matAcc_t>
45 static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
46 static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
47 static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
48 static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
49 static constexpr uint32_t block_elems = matAcc_t::block_elems;
50 static constexpr int32_t num_block_x = matAcc_t::num_block_x;
51 using dtype =
typename matAcc_t::dtype;
53 for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
55 for (uint32_t j = 0; j < num_block_x; j++) {
56 auto acc_reg = (matAcc.reg)
57 .xetla_select<block_elems, 1>(
58 (i * num_block_x + j) * block_elems);
60 = acc_reg.xetla_format<dtype, block_size_y, block_size_x>();
62 for (uint32_t row_i = 0; row_i < block_size_y; row_i++) {
63 acc_reg_2d.row(row_i) = op::template func<dtype, block_size_x>(
64 acc_reg_2d.row(row_i), data[block_size_y * i + row_i]);
69 if constexpr ((tile_size_y % block_size_y) != 0) {
70 constexpr uint32_t tail_start_y
71 = tile_size_y / block_size_y * block_size_y;
72 constexpr uint32_t tail_size_y = tile_size_y % block_size_y;
73 constexpr uint32_t tail_block_elems = tail_size_y * block_size_x;
75 for (uint32_t j = 0; j < num_block_x; j++) {
76 auto acc_reg = (matAcc.reg)
77 .xetla_select<tail_block_elems, 1>(
78 tail_start_y * tile_size_x
79 + j * tail_block_elems);
81 = acc_reg.xetla_format<dtype, tail_size_y, block_size_x>();
83 for (uint32_t row_i = 0; row_i < tail_size_y; row_i++) {
84 acc_reg_2d.row(row_i) = op::template func<dtype, block_size_x>(
85 acc_reg_2d.row(row_i), data[tail_start_y + row_i]);
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
Definition limitation.hpp:457
void tile_broadcast_op(matAcc_t &matAcc, xetla_vector< typename matAcc_t::dtype, matAcc_t::tile_size_y > data)
Definition tile_broadcast_op.hpp:43
Definition tile_broadcast_op.hpp:34
static xetla_vector< dtype, vec_len > func(xetla_vector< dtype, vec_len > vec_data, dtype data)
Definition tile_broadcast_op.hpp:36
Definition tile_broadcast_op.hpp:26
static xetla_vector< dtype, vec_len > func(xetla_vector< dtype, vec_len > vec_data, dtype data)
Definition tile_broadcast_op.hpp:28