22#include "group/reduction/reduction_api.hpp"
26template <
typename T, uint32_t SZ, uint32_t N,
reduce_op Op, uint32_t N_SG,
42 subgroup::msg_type_v<local_ld_tile_desc, mem_space::local>,
49 uint32_t sg_id_, uint32_t nbarrier_id, uint32_t slm_base_) {
54 inline void init(uint32_t sg_id_ = 0, uint32_t nbarrier_id = 0,
55 uint32_t slm_base_ = 0) {
60 inline void set_slm_base(uint32_t slm_base_ = 0) { slm_base = slm_base_; }
68 local_st_payload.init(slm_base, N_SG * N, 1, N_SG * N, sg_id * N, 0);
70 xetla_fence<memory_kind::shared_local>();
73 if constexpr (is_all_reduce) {
76 slm_base, N_SG * N, 1, N_SG * N, 0, 0);
78 ret = recur_row_reduce<Op, T, N, N_SG>(local_ld.
reg);
83 local_ld_payload.init(slm_base, N_SG * N, 1, N_SG * N, 0, 0);
85 ret = recur_row_reduce<Op, T, N, N_SG>(local_ld.
reg);
92template <
typename T, u
int32_t SZ, u
int32_t N, reduce_op Op,
bool is_all_reduce>
96 [[maybe_unused]] uint32_t nbarrier_id,
97 [[maybe_unused]] uint32_t slm_base_) {}
98 inline void init([[maybe_unused]] uint32_t sg_id_ = 0,
99 [[maybe_unused]] uint32_t nbarrier_id = 0,
100 [[maybe_unused]] uint32_t slm_base_ = 0) {}
104 auto buffer_2d = buffer.xetla_format<T, N, SZ>();
107 for (uint32_t i = 0; i < N; i++) {
108 ret[i] = xetla_reduce<T, T, SZ, Op>(buffer_2d.row(i));
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
Wrapper for xetla_vector.
Definition base_types.hpp:149
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
__XETLA_API std::enable_if_t< detail::check_store_type< tile_t, payload_t >::is_global_2d_xe > tile_store(tile_t &tile, payload_t &payload)
Stores data from the register file to global memory.
Definition store_xe.hpp:91
__XETLA_API std::enable_if_t< detail::check_load_type< tile_t, payload_t >::is_global_2d_xe > tile_load(tile_t &tile, payload_t &payload)
Loads data from a 2D memory surface.
Definition load_xe.hpp:76
reduce_op
XeTLA reduce operation.
Definition common.hpp:217
gpu_arch
Definition common.hpp:73
void init(uint32_t sg_id_=0, uint32_t nbarrier_id=0, uint32_t slm_base_=0)
Definition reduction_xe.hpp:98
group_reduce_t(uint32_t sg_id_, uint32_t nbarrier_id, uint32_t slm_base_)
Definition reduction_xe.hpp:95
void set_slm_base(uint32_t slm_base_=0)
Definition reduction_xe.hpp:101
KERNEL_FUNC xetla_vector< T, N > operator()(xetla_vector< T, N *SZ > buffer)
Definition reduction_xe.hpp:102
KERNEL_FUNC xetla_vector< T, N > operator()(xetla_vector< T, N *SZ > buffer)
Definition reduction_xe.hpp:62
void set_slm_base(uint32_t slm_base_=0)
Definition reduction_xe.hpp:60
uint32_t slm_base
Definition reduction_xe.hpp:31
void init(uint32_t sg_id_=0, uint32_t nbarrier_id=0, uint32_t slm_base_=0)
Definition reduction_xe.hpp:54
uint32_t sg_id
Definition reduction_xe.hpp:32
xetla_nbarrier_t< N_SG, N_SG, gpu_arch::Xe > nbarrier
Definition reduction_xe.hpp:30
group_reduce_t(uint32_t sg_id_, uint32_t nbarrier_id, uint32_t slm_base_)
Definition reduction_xe.hpp:48
This is the group reduction.
Definition reduction_api.hpp:36
Definition memory_descriptor.hpp:139
Describes the memory information.
Definition api.hpp:44
Describes the tile information of a sub-matrix.
Definition api.hpp:64
A struct that contains register-file storage.
Definition api.hpp:99
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
XeTLA named-barrier (nbarrier) definition API.
Definition raw_send_nbarrier.hpp:43
__XETLA_API void arrive()
Signals the named barrier from the subgroup.
Definition raw_send_nbarrier.hpp:65
__XETLA_API void init_nbarrier(uint8_t nbarrier_id, nbarrier_role role=nbarrier_role::producer_consumer)
Definition raw_send_nbarrier.hpp:55
__XETLA_API void wait()
Waits on the named barrier within the subgroup.
Definition raw_send_nbarrier.hpp:76