26template <
typename dtype_acc,
typename dtype_out, uint32_t row_size,
27 uint32_t wg_size_x, uint32_t wg_size_y, uint32_t max_simd_len>
30 static constexpr uint32_t block_size_x
32 static_assert(block_size_x >= 8,
33 "if block_size_x is less than 8, the efficiency will be low. "
34 "Please choose another tile_size_x");
35 static constexpr uint32_t num_block_x = row_size / block_size_x;
36 static_assert((num_block_x < wg_size_y) || (num_block_x % wg_size_y == 0),
37 "num_block_x should be less than wg_size_y or num_block_x should "
38 "be a multiple of wg_size_y");
39 static constexpr uint32_t num_block_per_thd
40 = (num_block_x < wg_size_y) ? 1 : num_block_x / wg_size_y;
41 static constexpr uint32_t cooperative_thd_num
42 = (num_block_x < wg_size_y) ? num_block_x : wg_size_y;
43 static constexpr uint32_t local_tile_size_x
44 = num_block_per_thd * block_size_x;
51 subgroup::msg_type_v<local_st_tile_desc_t, mem_space::local>,
59 subgroup::msg_type_v<local_ld_tile_desc_t, mem_space::local>,
79 inline void init(uint32_t sg_idx_ = 0, uint32_t sg_idy_ = 0,
80 uint32_t slm_base = 0, uint32_t nbarrier_base = 0) {
85 local_st_payload.init(slm_base, row_size * wg_size_x, wg_size_y,
86 row_size * wg_size_x, row_size * sg_idx, sg_idy);
87 local_ld_payload.init(slm_base, row_size * wg_size_x, wg_size_y,
89 row_size * sg_idx + local_tile_size_x * sg_idy, 0);
93 uint32_t st_height, uint32_t st_pitch,
int start_n_base,
95 local_st.
reg = buffer;
97 xetla_fence<memory_kind::shared_local>();
100 if (sg_idy < cooperative_thd_num) {
104 st_pitch, start_n_base + local_tile_size_x * sg_idy,
107 dtype_acc, 0>(local_ld);
108 subgroup::tile_store<cache_hint::uncached>(
109 global_st, global_st_payload);
116template <
typename dtype_acc,
typename dtype_out, uint32_t row_size,
117 uint32_t wg_size_x, uint32_t max_simd_len>
120 static constexpr uint32_t block_size_x
132 inline void init([[maybe_unused]] uint32_t sg_idx_ = 0,
133 [[maybe_unused]] uint32_t sg_idy_ = 0,
134 [[maybe_unused]] uint32_t slm_base = 0,
135 [[maybe_unused]] uint32_t nbarrier_base = 0) {}
138 uint32_t st_height, uint32_t st_pitch,
int start_n_base,
142 global_st.
reg = xetla_cvt<dtype_out, dtype_acc, row_size>(buffer);
143 global_st_payload.init(
144 ptr, st_width, st_height, st_pitch, start_n_base, start_m_base);
145 subgroup::tile_store<cache_hint::uncached>(
146 global_st, global_st_payload);
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
__XETLA_API std::enable_if_t< detail::check_store_type< tile_t, payload_t >::is_global_2d_xe > tile_store(tile_t &tile, payload_t &payload)
Stores data from the register file to global memory.
Definition store_xe.hpp:91
__XETLA_API std::enable_if_t< detail::check_load_type< tile_t, payload_t >::is_global_2d_xe > tile_load(tile_t &tile, payload_t &payload)
This function loads data from a 2D memory surface.
Definition load_xe.hpp:76
__XETLA_API std::enable_if_t<(dim==1), xetla_vector< dtype_out, mat_t::tile_size_y > > tile_reduce(mat_t &src)
Definition reduction.hpp:33
gpu_arch
Definition common.hpp:73
void init(uint32_t sg_idx_=0, uint32_t sg_idy_=0, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Definition row_reduce_store_xe.hpp:132
KERNEL_FUNC void operator()(dtype_out *ptr, uint32_t st_width, uint32_t st_height, uint32_t st_pitch, int start_n_base, int start_m_base, xetla_vector< dtype_acc, row_size > buffer)
Definition row_reduce_store_xe.hpp:137
uint32_t sg_idx
Definition row_reduce_store_xe.hpp:77
local_st_t local_st
Definition row_reduce_store_xe.hpp:73
xetla_nbarrier_t< wg_size_y, wg_size_y, gpu_arch::Xe > nbarrier
Definition row_reduce_store_xe.hpp:72
uint32_t sg_idy
Definition row_reduce_store_xe.hpp:78
void init(uint32_t sg_idx_=0, uint32_t sg_idy_=0, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Definition row_reduce_store_xe.hpp:79
local_ld_payload_t local_ld_payload
Definition row_reduce_store_xe.hpp:76
local_ld_t local_ld
Definition row_reduce_store_xe.hpp:75
KERNEL_FUNC void operator()(dtype_out *ptr, uint32_t st_width, uint32_t st_height, uint32_t st_pitch, int start_n_base, int start_m_base, xetla_vector< dtype_acc, row_size > buffer)
Definition row_reduce_store_xe.hpp:92
local_st_payload_t local_st_payload
Definition row_reduce_store_xe.hpp:74
This is the group row reduction (reduce_sum) + cooperative write-out.
Definition reduction_api.hpp:39
Definition memory_descriptor.hpp:139
Describes the tile information of a sub-matrix.
Definition api.hpp:64
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
xetla nbarrier definition API.
Definition raw_send_nbarrier.hpp:43
__XETLA_API void arrive()
named barrier signal from subgroup.
Definition raw_send_nbarrier.hpp:65
__XETLA_API void init_nbarrier(uint8_t nbarrier_id, nbarrier_role role=nbarrier_role::producer_consumer)
Definition raw_send_nbarrier.hpp:55
__XETLA_API void wait()
named barrier wait within subgroup.
Definition raw_send_nbarrier.hpp:76