40template <
typename dtype_in_,
typename dtype_out_,
typename dtype_acc_,
41 typename reduction_attr_,
typename fused_op_t_>
51 static constexpr uint32_t wg_tile_m = reduction_attr::wg_tile_m;
52 static constexpr uint32_t wg_tile_n = reduction_attr::wg_tile_n;
53 static constexpr uint32_t sg_tile_m = reduction_attr::sg_tile_m;
54 static constexpr uint32_t sg_tile_n = reduction_attr::sg_tile_n;
55 static constexpr bool is_dynamic_job = reduction_attr::is_dynamic_job;
56 static constexpr uint32_t wg_size_x
57 = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
58 static constexpr uint32_t wg_size_y
59 = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
61 static constexpr bool use_dynamic_job = is_dynamic_job && (wg_size_y > 1);
64 static constexpr uint32_t max_load_height_in_elem
65 = load_store_attr::max_load_height_in_elem;
66 static constexpr uint32_t max_load_width_in_bytes
67 = load_store_attr::max_load_width_in_bytes;
68 static constexpr uint32_t max_store_width_in_bytes
69 = load_store_attr::max_store_width_in_bytes;
70 static constexpr uint32_t max_load_width_in_elem
71 = max_load_width_in_bytes /
sizeof(
dtype_in);
72 static constexpr uint32_t max_store_width_in_elem
73 = max_store_width_in_bytes /
sizeof(
dtype_out);
75 static constexpr uint32_t tile_size_x = sg_tile_n;
76 static constexpr uint32_t tile_size_y = sg_tile_m;
78 static constexpr uint32_t max_simd_len = max_store_width_in_elem;
81 static constexpr uint32_t block_size_x
82 = max_load_width_in_elem > tile_size_x
85 max_load_width_in_elem>::value;
86 static_assert(block_size_x >= 8,
87 "if block_size_x less than 8, the efficiency will be low. Please "
88 "choose another tile_size_x");
89 static constexpr uint32_t block_size_y
90 = max_load_height_in_elem > tile_size_y ? tile_size_y
91 : max_load_height_in_elem;
93 static constexpr uint32_t
SIMD = 16;
101 subgroup::msg_type_v<global_ld_tile_desc_t, mem_space::global>,
108 dtype_out, sg_tile_n, wg_size_x, wg_size_y, max_simd_len>;
122 struct get_barrier_count {
123 static constexpr uint32_t count = (wg_size_y > 1) ? wg_size_x : 0;
126 static constexpr uint32_t counter_size
127 = use_dynamic_job ?
SIMD *
sizeof(int) * wg_size_x : 0;
128 static constexpr uint32_t row_buffer_size = (wg_size_y > 1)
129 ? tile_size_x * wg_size_x * wg_size_y *
sizeof(
dtype_acc)
134 struct get_slm_size {
135 static constexpr uint32_t size = row_buffer_size + counter_size;
149 uint32_t slm_base = 0, uint32_t nbarrier_base = 0) {
151 g.init(item.get_local_linear_id());
152 int sg_idx = g.get_id() % wg_size_x;
153 int sg_idy = g.get_id() / wg_size_x;
155 int global_start_x_in
156 = item.get_group(2) * wg_tile_n + sg_idx * sg_tile_n;
157 int global_start_y_in = sg_idy * sg_tile_m;
161 if constexpr (use_dynamic_job) {
164 slm_base + row_buffer_size + sg_idx *
SIMD *
sizeof(
int));
169 xetla_store_local<int, 1, data_size::default_size, SIMD>(
170 offsets, init, pred);
171 xetla_fence<memory_kind::shared_local>();
178 fused_op_args, global_start_x_in, global_start_y_in);
180 args->matrix_n, args->matrix_m, args->mat_in_ld,
181 global_start_x_in, global_start_y_in);
183 if constexpr (use_dynamic_job) {
187 slm_base + row_buffer_size + sg_idx *
SIMD *
sizeof(
int));
190 while (job_id * tile_size_y < args->matrix_m) {
192 = xetla_atomic_local<atomic_op::iinc, int, SIMD>(
196 subgroup::elemwise_cvt<matAcc_t, global_ld_t>(
197 matAcc, mat_global_ld);
201 mat_global_ld_payload
202 .template update_tdesc<tdesc_update_dir::y_dir>(
203 (next_job[0] - job_id) * tile_size_y);
204 fused_op.update_tdesc(0, (next_job[0] - job_id) * tile_size_y);
205 job_id = next_job[0];
208 for (
int job_id = sg_idy; job_id * tile_size_y < args->matrix_m;
209 job_id += wg_size_y) {
212 subgroup::elemwise_cvt<matAcc_t, global_ld_t>(
213 matAcc, mat_global_ld);
217 fused_op.update_tdesc(0, wg_size_y * tile_size_y);
218 mat_global_ld_payload
219 .template update_tdesc<tdesc_update_dir::y_dir>(
220 wg_size_y * tile_size_y);
225 uint32_t slm_row_reduce_base = slm_base;
226 uint32_t nbarrier_row_reduce_base = nbarrier_base;
227 row_reduce_store.init(
228 sg_idx, sg_idy, slm_row_reduce_base, nbarrier_row_reduce_base);
229 row_reduce_store(args->mat_out_ptr, args->matrix_n, 1, args->matrix_n,
230 global_start_x_in, 0, mat_buffer.
reg);
#define __XETLA_API
Definition common.hpp:43
#define SIMD
Definition gemm_softmax.cpp:23
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
wrapper for xetla_mask.
Definition base_types.hpp:165
Definition limitation.hpp:734
__XETLA_API std::enable_if_t< detail::check_load_type< tile_t, payload_t >::is_global_2d_xe > tile_load(tile_t &tile, payload_t &payload)
This function loads data from 2D memory surface.
Definition load_xe.hpp:76
__XETLA_API std::enable_if_t<(dim==1), xetla_vector< dtype_out, mat_t::tile_size_y > > tile_reduce(mat_t &src)
Definition reduction.hpp:33
gpu_arch
Definition common.hpp:73
Definition arch_config.hpp:72
This is the group row reduction(reduce_sum) + cooperative write out.
Definition reduction_api.hpp:39
typename arch_attr_t< gpu_arch::Xe >::template load_store_attr< msg_type::block_2d > load_store_attr
Definition row_reduction_xe.hpp:63
dtype_out_ dtype_out
Definition row_reduction_xe.hpp:45
static __XETLA_API void call(sycl::nd_item< 3 > &item, arguments_t *args, fused_op_arguments_t *fused_op_args=nullptr, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Main execution function for row reduction.
Definition row_reduction_xe.hpp:147
fused_op_t_ fused_op_t
Definition row_reduction_xe.hpp:48
reduction_attr_ reduction_attr
Definition row_reduction_xe.hpp:47
work_group_t< wg_size_x *wg_size_y > work_group_t
Definition row_reduction_xe.hpp:60
typename fused_op_t::arguments_t fused_op_arguments_t
Definition row_reduction_xe.hpp:49
dtype_acc_ dtype_acc
Definition row_reduction_xe.hpp:46
dtype_in_ dtype_in
Definition row_reduction_xe.hpp:44
uint32_t matrix_n
Definition row_reduction_xe.hpp:116
dtype_out * mat_out_ptr
Definition row_reduction_xe.hpp:114
dtype_in * mat_in_ptr
Definition row_reduction_xe.hpp:113
uint32_t matrix_m
Definition row_reduction_xe.hpp:115
uint32_t mat_in_ld
Definition row_reduction_xe.hpp:117
Is the row_reduction functor.
Definition api.hpp:39
Definition memory_descriptor.hpp:139
Is to illustrate the memory information.
Definition api.hpp:44
Is to illustrate the tile information about a sub matrix.
Definition api.hpp:64
Is a struct contains some register file.
Definition api.hpp:99
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
xetla nbarrier definition API.
Definition raw_send_nbarrier.hpp:43
__XETLA_API void arrive()
named barrier signal from subgroup.
Definition raw_send_nbarrier.hpp:65
__XETLA_API void init_nbarrier(uint8_t nbarrier_id, nbarrier_role role=nbarrier_role::producer_consumer)
Definition raw_send_nbarrier.hpp:55
__XETLA_API void wait()
named barrier wait within subgroup.
Definition raw_send_nbarrier.hpp:76