38template <
typename dtype_in_,
typename dtype_out_,
typename dtype_compute_,
39 typename data_transformer_attr_,
mem_layout mem_layout_in_,
42 data_transformer_attr_, mem_layout_in_, need_fp8_op,
gpu_arch::
Xe> {
48 static constexpr mem_layout mem_layout_in = mem_layout_in_;
50 static constexpr bool is_col_major_in
53 static constexpr uint32_t wg_tile_m = data_transformer_attr::wg_tile_m;
54 static constexpr uint32_t wg_tile_n = data_transformer_attr::wg_tile_n;
55 static constexpr uint32_t sg_tile_m = data_transformer_attr::sg_tile_m;
56 static constexpr uint32_t sg_tile_n = data_transformer_attr::sg_tile_n;
58 static constexpr uint32_t tile_size_x = sg_tile_n;
59 static constexpr uint32_t tile_size_y = sg_tile_m;
61 static constexpr uint32_t wg_size_x
62 = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
63 static constexpr uint32_t wg_size_y
64 = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
68 static constexpr uint32_t max_load_height_in_elem
69 = load_store_attr::max_load_height_in_elem;
70 static constexpr uint32_t max_load_width_in_bytes
71 = load_store_attr::max_load_width_in_bytes;
72 static constexpr uint32_t max_store_width_in_bytes
73 = load_store_attr::max_store_width_in_bytes;
74 static constexpr uint32_t max_trans_block_width
75 = load_store_attr::max_trans_load_width_in_bytes /
sizeof(
dtype_in);
76 static constexpr uint32_t max_load_width_in_elem
77 = max_load_width_in_bytes /
sizeof(
dtype_in);
78 static constexpr uint32_t max_store_width_in_elem
79 = max_store_width_in_bytes /
sizeof(
dtype_out);
82 static constexpr uint32_t load_size_x
84 max_load_width_in_elem>::value;
85 static_assert(load_size_x >= 8,
86 "if block_size_x less than 8, the efficiency will be low. Please "
87 "choose another tile_size_x");
88 static constexpr uint32_t st_size_x = max_store_width_in_elem > tile_size_x
91 max_store_width_in_elem>::value;
92 static_assert(st_size_x >= 8,
93 "if st_block_size_x less than 8, the efficiency will be "
95 static constexpr uint32_t block_size_x
98 static constexpr uint32_t block_size_y_limit
99 = is_col_major_in ? max_trans_block_width : max_load_height_in_elem;
101 static constexpr uint32_t block_size_y = block_size_y_limit > tile_size_y
103 : block_size_y_limit;
108 tile_size_y, block_size_x, block_size_y, in_reg_layout>;
113 subgroup::msg_type_v<global_ld_tile_desc_t, mem_space::global>,
152 struct get_barrier_count {
153 static constexpr uint32_t count
154 = (wg_size_x * wg_size_y > 1) ? wg_size_x * wg_size_y : 0;
159 struct get_slm_size {
160 static constexpr uint32_t size = (wg_size_x * wg_size_y > 1)
171 int tid_x = item.get_local_id(2);
172 int tid_y = item.get_local_id(1);
173 uint32_t sg_id = item.get_local_linear_id();
182 int global_ld_start_x;
183 int global_ld_start_y;
186 global_ld_start_x = args->wg_ld_start_x + tid_x * sg_tile_n;
187 global_ld_start_y = args->wg_ld_start_y + tid_y * sg_tile_m;
189 global_ld_start_x = args->wg_ld_start_x + tid_y * sg_tile_m;
190 global_ld_start_y = args->wg_ld_start_y + tid_x * sg_tile_n;
193 int global_st_start_x = args->wg_st_start_x + tid_x * sg_tile_n;
194 int global_st_start_y = args->wg_st_start_y + tid_y * sg_tile_m;
197 global_ld_payload.init(args->mat_in_ptr, args->matrix_n,
198 args->matrix_m, args->matrix_in_ld, global_ld_start_x,
201 global_ld_payload.init(args->mat_in_ptr, args->matrix_m,
202 args->matrix_n, args->matrix_in_ld, global_ld_start_x,
206 global_st_payload.init(args->mat_out_ptr, args->matrix_n,
207 args->matrix_m, args->matrix_out_ld, global_st_start_x,
212 if constexpr (need_fp8_op) {
215 static constexpr uint32_t simd = 16;
223 mat_global_compute.
reg
229 wg_reduce.init(sg_id, 0, 0);
232 global_compute_t::tile_desc::tile_elems>(
233 mat_global_compute.
reg);
236 = wg_reduce(mat_global_compute.
reg);
243 = xetla_vector_gen<uint32_t, simd>(0, 1);
253 subgroup::tile_store<cache_hint::uncached>(
254 mat_global_st, global_st_payload);
#define __XETLA_API
Definition common.hpp:43
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
Wrapper for xetla_vector.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
Wrapper for xetla_mask.
Definition base_types.hpp:165
__XETLA_API xetla_vector< T0, SZ > xetla_abs(xetla_vector< T1, SZ > src0)
Get the absolute value (vector version).
Definition math_general.hpp:39
__XETLA_API xetla_vector< Ty, N *NElts > xetla_load_global(Ty *p, xetla_vector< Toffset, N > offsets, xetla_mask< N > pred=1)
Stateless scattered load.
Definition memory.hpp:245
__XETLA_API std::enable_if_t< arch_tag==gpu_arch::Xe, void > xetla_tatomic_store_global(uint64_t base_address, xetla_vector< Toffset, N > offset, xetla_vector< Ty, N > data, xetla_mask< N > pred=1)
Tensor atomic store API.
Definition raw_send_load_store.hpp:294
Definition limitation.hpp:734
__XETLA_API std::enable_if_t<(T_src::register_layout !=reg_layout::linear) &&(T_dst::register_layout !=reg_layout::linear) &&is_same_layout< T_dst, T_src >::value &&(!is_floating_to_integer< T_dst, T_src >::value)> elemwise_cvt(T_dst &dst, T_src &src)
Performs element-wise data conversion; the src and dst tiles should have the same layout.
Definition op_function.hpp:40
__XETLA_API std::enable_if_t< detail::check_load_type< tile_t, payload_t >::is_global_2d_xe > tile_load(tile_t &tile, payload_t &payload)
This function loads data from a 2D memory surface.
Definition load_xe.hpp:76
reg_layout
Tile layout in registers — linear: a linear layout with one tile; tiled: 2D blocks stacked in raster order; v…
Definition common.hpp:209
@ fmax
Atomically store the float max of src1 and the memory data, and return the old value. See
gpu_arch
Definition common.hpp:73
mem_layout
Definition common.hpp:76
Definition arch_config.hpp:72
This is the group reduction.
Definition reduction_api.hpp:36
Definition memory_descriptor.hpp:139
Describes the memory information.
Definition api.hpp:44
Describes the tile information for a sub-matrix.
Definition api.hpp:64
A struct that contains the register file.
Definition api.hpp:99
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102