32template <
typename tile_shape_,
typename epilogue_t_,
typename mem_desc_d_t_,
33 typename mem_desc_atomic_sync_t_>
45 static constexpr uint32_t
wg_tile_m = tile_shape::wg_tile_size_y;
46 static constexpr uint32_t
wg_tile_n = tile_shape::wg_tile_size_x;
47 static constexpr uint32_t
sg_tile_m = tile_shape::sg_tile_size_y;
48 static constexpr uint32_t
sg_tile_n = tile_shape::sg_tile_size_x;
49 static constexpr uint32_t
wg_size_x = tile_shape::wg_size_x;
50 static constexpr uint32_t
wg_size_y = tile_shape::wg_size_y;
60 using dtype_d =
typename mem_desc_d_t::dtype;
61 using dtype_flag =
typename mem_desc_atomic_sync_t::dtype;
78 int32_t tile_offset_n = sg_idx *
sg_tile_n;
79 int32_t tile_offset_m = sg_idy *
sg_tile_m;
80 mem_desc_d.update_coord(tile_offset_n, tile_offset_m);
93 template <
typename matAcc_t>
97 int first_group_idx,
bool tile_finished,
bool tile_started,
99 uint32_t nbarrier_base = 0) {
101 static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
102 static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
103 static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
104 static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
112 uint32_t nbarrier_id = nbarrier_base;
121 = xetla_vector_gen<uint32_t, 16>(0, 1);
124 flag_offsets = flag_offsets *
sizeof(
dtype_flag);
125 int32_t sg_id = g.get_id();
126 dtype_flag *flag_pointer = mem_desc_atomic_sync.base.base;
129 if (!tile_finished) {
132 matD_atomic_payload_t matD_atomic_payload(mem_desc_d);
149 (uint64_t)flag_pointer, flag_offsets, signal_val, pred);
158 uint32_t num_peers = group_idx - first_group_idx;
172 while (ret_val[0] != num_peers) {
177 flag_pointer, flag_offsets, old_val, zero_val,
188 mem_desc_d.base, mem_desc_d.shape);
190 residual_op(matAcc, mem_desc_d.coord, residual_args);
195 epilogue(g, matAcc, mem_desc_c, epilogue_args, slm_base,
#define __XETLA_API
Definition common.hpp:43
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
wrapper for xetla_mask.
Definition base_types.hpp:165
__XETLA_API void xetla_fence(xetla_mask< N > pred=1)
Memory fence.
Definition memory.hpp:638
__XETLA_API xetla_vector< T, N > xetla_atomic_global(T *p, xetla_vector< uint32_t, N > offsets, xetla_mask< N > pred)
Stateless scattered atomic (0 src).
Definition memory.hpp:371
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
__XETLA_API std::enable_if_t< arch_tag==gpu_arch::Xe, void > xetla_tatomic_store_global(uint64_t base_address, xetla_vector< Toffset, N > offset, xetla_vector< Ty, N > data, xetla_mask< N > pred=1)
Tensor atomic store API.
Definition raw_send_load_store.hpp:294
Definition limitation.hpp:607
__XETLA_API std::enable_if_t< detail::check_store_type< tile_t, payload_t >::is_global_2d_xe > tile_store(tile_t &tile, payload_t &payload)
Stores data from the register file to global memory.
Definition store_xe.hpp:91
@ tile
flush out to the local scope
mem_space
Definition common.hpp:77
@ iadd
Atomic signed integer add of src1 to the memory data; returns the old value.
@ cmpxchg
Atomic bit-compare src1_X and memory data and replace if equal with src1_Y. Returns the old value....
gpu_arch
Definition common.hpp:73
msg_type
Definition common.hpp:78
mem_layout
Definition common.hpp:76
Epilogue functor specialized for stream_k.
Definition stream_k_op_xe.hpp:34
xetla_nbarrier_t< N_SG, N_SG, arch_tag > nbarrier
Definition stream_k_op_xe.hpp:58
static constexpr uint32_t slm_size
Definition stream_k_op_xe.hpp:55
static __XETLA_API void update_sg_tile_tdesc(work_group_t &g, mem_desc_d_t &mem_desc_d)
Updates tile base descriptor based on the tid.
Definition stream_k_op_xe.hpp:74
static constexpr uint32_t wg_tile_m
Definition stream_k_op_xe.hpp:45
static constexpr uint32_t wg_size_x
Definition stream_k_op_xe.hpp:49
typename mem_desc_atomic_sync_t::dtype dtype_flag
Definition stream_k_op_xe.hpp:61
static constexpr uint32_t sg_tile_m
Definition stream_k_op_xe.hpp:47
static constexpr mem_layout mem_layout_d
Definition stream_k_op_xe.hpp:68
typename epilogue_t::arguments_t epilogue_args_t
Definition stream_k_op_xe.hpp:42
static constexpr msg_type msg_type_d_block2d
Definition stream_k_op_xe.hpp:70
mem_desc_atomic_sync_t_ mem_desc_atomic_sync_t
Definition stream_k_op_xe.hpp:40
__XETLA_API KERNEL_FUNC void operator()(work_group_t &g, matAcc_t &matAcc, mem_desc_c_t mem_desc_c, mem_desc_d_t mem_desc_d, mem_desc_atomic_sync_t mem_desc_atomic_sync, int group_idx, int first_group_idx, bool tile_finished, bool tile_started, epilogue_args_t epilogue_args, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Epilogue for stream_k.
Definition stream_k_op_xe.hpp:94
static constexpr uint32_t N_SG
Definition stream_k_op_xe.hpp:56
static constexpr mem_space mem_space_d
Definition stream_k_op_xe.hpp:69
mem_desc_d_t_ mem_desc_d_t
Definition stream_k_op_xe.hpp:38
static constexpr uint32_t sg_tile_n
Definition stream_k_op_xe.hpp:48
tile_shape_ tile_shape
Definition stream_k_op_xe.hpp:41
static constexpr uint32_t wg_tile_n
Definition stream_k_op_xe.hpp:46
typename tile_shape::work_group_t work_group_t
Definition stream_k_op_xe.hpp:44
static constexpr gpu_arch arch_tag
Definition stream_k_op_xe.hpp:36
typename epilogue_t::mem_desc_c_t mem_desc_c_t
Definition stream_k_op_xe.hpp:39
static constexpr msg_type msg_type_d_atomic
Definition stream_k_op_xe.hpp:71
typename mem_desc_d_t::dtype dtype_d
Definition stream_k_op_xe.hpp:60
static constexpr uint32_t barrier_count
Definition stream_k_op_xe.hpp:53
static constexpr uint32_t wg_size_y
Definition stream_k_op_xe.hpp:50
typename residual_op_t::arguments_t residual_op_args_t
Definition stream_k_op_xe.hpp:66
epilogue_t_ epilogue_t
Definition stream_k_op_xe.hpp:37
Element-wise reduce op functor, specialized for stream_k dispatch. Load partial sum from scratc...
Definition tile_op_functor.hpp:826
Describes the memory information.
Definition api.hpp:44
Describes the tile information of a sub-matrix.
Definition api.hpp:64
xetla nbarrier definition API.
Definition raw_send_nbarrier.hpp:43
__XETLA_API void arrive()
named barrier signal from subgroup.
Definition raw_send_nbarrier.hpp:65
__XETLA_API void init_nbarrier(uint8_t nbarrier_id, nbarrier_role role=nbarrier_role::producer_consumer)
Definition raw_send_nbarrier.hpp:55
__XETLA_API void wait()
named barrier wait within subgroup.
Definition raw_send_nbarrier.hpp:76