32template <
typename tile_op_t_,
typename tile_shape_,
typename mem_desc_c_t_,
35 mem_desc_c_t_, std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
41 static constexpr gpu_arch arch_tag = arch_tag_;
42 static constexpr uint32_t barrier_count = 0;
43 static constexpr uint32_t slm_size = mem_desc_c_t::is_local
44 ? tile_shape::wg_tile_size_x * tile_shape::wg_tile_size_y
/// Constructs an arguments_t from the supplied tile_op arguments; the value is
/// copied into the tile_op_args member.
59 inline arguments_t(
typename tile_op_t::arguments_t tile_op_args_)
60 : tile_op_args(tile_op_args_) {}
65 : tile_op_args(args.tile_op_args) {}
67 inline arguments_t &
operator=(
const arguments_t &args) {
68 this->tile_op_args = args.tile_op_args;
75 inline void init(
typename tile_op_t::arguments_t tile_op_args_) {
76 tile_op_args = tile_op_args_;
/// sg tile extents re-exported from tile_shape: sg_tile_m mirrors
/// sg_tile_size_y and sg_tile_n mirrors sg_tile_size_x.
82 static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
83 static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x;
/// wg sizes along x and y, forwarded unchanged from tile_shape.
84 static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
85 static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
/// Element type, layout and memory space of C, all read from mem_desc_c_t.
86 using dtype_c =
typename mem_desc_c_t::dtype;
87 static constexpr mem_layout mem_layout_c = mem_desc_c_t::layout;
88 static constexpr mem_space mem_space_c = mem_desc_c_t::space;
// NOTE(review): the enclosing function header is not visible in this listing;
// these statements appear to be the body of update_sg_tile_tdesc - confirm.
// Split the id returned by g.get_id() into an x index (remainder by wg_size_x)
// and a y index (quotient by wg_size_x).
93 int32_t sg_idx = g.
get_id() % wg_size_x;
94 int32_t sg_idy = g.
get_id() / wg_size_x;
// Scale each index by the matching sg tile extent to obtain the tile offsets.
95 int32_t tile_offset_n = sg_idx * sg_tile_n;
96 int32_t tile_offset_m = sg_idy * sg_tile_m;
// Hand both offsets to the C memory descriptor (n offset first, then m).
97 mem_desc_c.update_coord(tile_offset_n, tile_offset_m);
115 template <
typename matAcc_t>
118 uint32_t slm_base = 0, uint32_t nbarrier_base = 0) {
119 using mat_tile_desc =
typename matAcc_t::tile_desc;
122 mat_tile_desc, msg_type_c, arch_tag>;
123 update_sg_tile_tdesc(g, mem_desc_c);
125 tile_op(matAcc, mem_desc_c.coord, args.tile_op_args, slm_base,
128 matC_payload_t matC_payload(mem_desc_c);
130 subgroup::tile_store<cache_hint::streaming, cache_hint::write_back>(
__XETLA_API KERNEL_FUNC void operator()(work_group_t &g, matAcc_t &matAcc, mem_desc_c_t mem_desc_c, arguments_t args={}, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Default epilogue.
Definition tile_op_xe.hpp:116
tile_shape_ tile_shape
Definition tile_op_xe.hpp:39
mem_desc_c_t_ mem_desc_c_t
Definition tile_op_xe.hpp:40
typename epilogue_policy::tile_op_t tile_op_t
Definition tile_op_xe.hpp:38
Is the epilogue functor.
Definition api.hpp:35
#define __XETLA_API
Definition common.hpp:43
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
__XETLA_API std::enable_if_t<(T_src::register_layout !=reg_layout::linear) &&(T_dst::register_layout !=reg_layout::linear) &&is_same_layout< T_dst, T_src >::value &&(!is_floating_to_integer< T_dst, T_src >::value)> elemwise_cvt(T_dst &dst, T_src &src)
Performs element-wise data conversion; the src and dst tiles should have the same layout.
Definition op_function.hpp:40
mem_space
Definition common.hpp:77
gpu_arch
Definition common.hpp:73
msg_type
Definition common.hpp:78
mem_layout
Definition common.hpp:76
Epilogue policy for tile_op + store C fusion.
Definition epilogue_policy.hpp:40
tile_op_t_ tile_op_t
Definition epilogue_policy.hpp:41
tile_op_t::arguments_t tile_op_args
The tile_op arguments; this can be a single tile_op argument or chained_tile_op_args.
Definition tile_op_xe.hpp:51
arguments_t()=default
Constructs a new arguments_t object.
void init(typename tile_op_t::arguments_t tile_op_args_)
Explicit initialization function.
Definition tile_op_xe.hpp:75
arguments_t(const arguments_t &args)
Definition tile_op_xe.hpp:64
arguments_t(typename tile_op_t::arguments_t tile_op_args_)
Constructs a new arguments_t object.
Definition tile_op_xe.hpp:59
arguments_t & operator=(const arguments_t &args)
Definition tile_op_xe.hpp:67
Illustrates the memory information.
Definition api.hpp:44
Is a struct that contains some register file.
Definition api.hpp:99
Define a workgroup scope for a specific problem shape.
Definition work_group.hpp:34
__XETLA_API uint32_t get_id()
Definition work_group.hpp:41