31template <
typename dtype_in,
typename dtype_out,
typename dtype_acc>
53 typename dtype_out_,
typename dtype_acc_,
typename reduction_attr_>
63 [[maybe_unused]]
int start_n = 0,
64 [[maybe_unused]]
int start_m = 0) {}
65 template <
typename matAcc_t>
69 [[maybe_unused]]
int offset_m = 0) {}
72template <
typename dtype_in_,
typename dtype_out_,
typename dtype_acc_,
73 typename reduction_attr_>
75 dtype_in_, dtype_out_, dtype_acc_, reduction_attr_,
gpu_arch::Xe> {
88 arguments_t *args,
int start_n = 0,
int start_m = 0) {
102 template <
typename matAcc_t>
104 static_assert(std::is_same<remove_const_t<dtype_acc>,
105 typename matAcc_t::dtype>::value,
106 "dtype_acc should match with matAcc");
107 static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
108 static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
109 static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
110 static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
111 static constexpr uint32_t num_elems = matAcc_t::tile_elems;
118 subgroup::msg_type_v<dgelu_tile_desc_t, mem_space::global>,
124 dgelu_w_in_t dgelu_w_in;
125 dgelu_w_in_payload_t dgelu_w_in_payload(w_load_base_desc);
128 = xetla_cvt<dtype_acc, dtype_in, num_elems>(dgelu_w_in.reg);
129 matAcc.reg = matAcc.reg * w;
130 dgelu_x_out_t dgelu_x_out;
131 dgelu_x_out_payload_t dgelu_x_out_payload(x_store_base_desc);
133 subgroup::tile_store<cache_hint::uncached>(
134 dgelu_x_out, dgelu_x_out_payload);
143 w_load_base_desc.update_coord(offset_n, offset_m);
144 x_store_base_desc.update_coord(offset_n, offset_m);
148template <
typename dtype_in_,
typename dtype_out_,
typename dtype_acc_,
149 typename reduction_attr_>
151 dtype_in_, dtype_out_, dtype_acc_, reduction_attr_,
gpu_arch::Xe> {
168 arguments_t *args,
int start_n = 0,
int start_m = 0) {
170 mask_load_base_desc.init({args->
mask_ptr},
180 template <
typename matAcc_t>
182 static_assert(std::is_same<remove_const_t<dtype_acc>,
183 typename matAcc_t::dtype>::value,
184 "dtype_acc should match with matAcc");
185 static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
186 static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
187 static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
188 static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
195 reduction_tile_desc_t,
196 subgroup::msg_type_v<reduction_tile_desc_t, mem_space::global>,
198 using dropout_bwd_out_t
202 reduction_tile_desc_t,
203 subgroup::msg_type_v<reduction_tile_desc_t, mem_space::global>,
205 if (dropout_prob != 0) {
207 mask_in_payload_t mask_in_payload(mask_load_base_desc);
210 matAcc.reg = drop_out<dtype_acc, tile_size_x * tile_size_y>(
211 matAcc.reg, mask_in.reg, dropout_scale_inv);
213 dropout_bwd_out_t dropout_bwd_out;
214 dropout_bwd_out_payload_t dropout_bwd_out_payload(
215 dropout_bwd_store_base_desc);
217 subgroup::tile_store<cache_hint::uncached>(
218 dropout_bwd_out, dropout_bwd_out_payload);
222 mask_load_base_desc.update_coord(offset_n, offset_m);
223 dropout_bwd_store_base_desc.update_coord(offset_n, offset_m);
#define SW_BARRIER()
SW_BARRIER, insert software scheduling barrier, for better code control.
Definition common.hpp:227
#define __XETLA_API
Definition common.hpp:43
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
__XETLA_API std::enable_if_t<(T_src::register_layout !=reg_layout::linear) &&(T_dst::register_layout !=reg_layout::linear) &&is_same_layout< T_dst, T_src >::value &&(!is_floating_to_integer< T_dst, T_src >::value)> elemwise_cvt(T_dst &dst, T_src &src)
Is the element wise data conversion, the src and dst tile should have the same layout.
Definition op_function.hpp:40
__XETLA_API std::enable_if_t< detail::check_load_type< tile_t, payload_t >::is_global_2d_xe > tile_load(tile_t &tile, payload_t &payload)
This function loads data from 2D memory surface.
Definition load_xe.hpp:76
gpu_arch
Definition common.hpp:73
reduction_fused_kind
Definition row_reduction_fused_op_api.hpp:28
dtype_in_ dtype_in
Definition row_reduction_fused_op_xe.hpp:57
dtype_out_ dtype_out
Definition row_reduction_fused_op_xe.hpp:58
dtype_acc_ dtype_acc
Definition row_reduction_fused_op_xe.hpp:59
__XETLA_API row_reduction_fused_op_t(arguments_t *args, int start_n=0, int start_m=0)
Definition row_reduction_fused_op_xe.hpp:62
__XETLA_API void update_tdesc(int offset_n=0, int offset_m=0)
Definition row_reduction_fused_op_xe.hpp:68
__XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc)
Definition row_reduction_fused_op_xe.hpp:66
float dropout_prob
Definition row_reduction_fused_op_xe.hpp:164
dtype_out_ dtype_out
Definition row_reduction_fused_op_xe.hpp:155
mem_desc_t< dtype_out, mem_layout::row_major, mem_space::global > dropout_bwd_store_base_desc
Definition row_reduction_fused_op_xe.hpp:163
dtype_in_ dtype_in
Definition row_reduction_fused_op_xe.hpp:154
__XETLA_API void update_tdesc(int offset_n=0, int offset_m=0)
Definition row_reduction_fused_op_xe.hpp:221
float dropout_scale_inv
Definition row_reduction_fused_op_xe.hpp:165
__XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc)
Definition row_reduction_fused_op_xe.hpp:181
mem_desc_t< dtype_mask, mem_layout::row_major, mem_space::global > mask_load_base_desc
Definition row_reduction_fused_op_xe.hpp:161
__XETLA_API row_reduction_fused_op_t(arguments_t *args, int start_n=0, int start_m=0)
Definition row_reduction_fused_op_xe.hpp:167
uint8_t dtype_mask
Definition row_reduction_fused_op_xe.hpp:157
dtype_acc_ dtype_acc
Definition row_reduction_fused_op_xe.hpp:156
mem_desc_t< dtype_in, mem_layout::row_major, mem_space::global > w_load_base_desc
Definition row_reduction_fused_op_xe.hpp:84
dtype_in_ dtype_in
Definition row_reduction_fused_op_xe.hpp:78
mem_desc_t< dtype_out, mem_layout::row_major, mem_space::global > x_store_base_desc
Definition row_reduction_fused_op_xe.hpp:86
__XETLA_API void update_tdesc(int offset_n=0, int offset_m=0)
Definition row_reduction_fused_op_xe.hpp:142
__XETLA_API row_reduction_fused_op_t(arguments_t *args, int start_n=0, int start_m=0)
Definition row_reduction_fused_op_xe.hpp:87
dtype_out_ dtype_out
Definition row_reduction_fused_op_xe.hpp:79
__XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc)
Definition row_reduction_fused_op_xe.hpp:103
dtype_acc_ dtype_acc
Definition row_reduction_fused_op_xe.hpp:80
Additional Ops that can be fused with row reduction processing flow.
Definition row_reduction_fused_op_api.hpp:47
Definition row_reduction_fused_op_xe.hpp:32
uint8_t * mask_ptr
Definition row_reduction_fused_op_xe.hpp:36
uint32_t mat_in_ld
Definition row_reduction_fused_op_xe.hpp:41
dtype_in * gelu_bwd_w_ptr
Definition row_reduction_fused_op_xe.hpp:33
float dropout_scale_inv
Definition row_reduction_fused_op_xe.hpp:38
uint32_t matrix_n
Definition row_reduction_fused_op_xe.hpp:40
dtype_out * dropout_bwd_ptr
Definition row_reduction_fused_op_xe.hpp:35
float dropout_prob
Definition row_reduction_fused_op_xe.hpp:37
uint32_t matrix_m
Definition row_reduction_fused_op_xe.hpp:39
uint32_t mat_out_ld
Definition row_reduction_fused_op_xe.hpp:42
dtype_out * gelu_bwd_x_ptr
Definition row_reduction_fused_op_xe.hpp:34
Definition memory_descriptor.hpp:139
Is to illustrate the memory information.
Definition api.hpp:44
Is to illustrate the tile information about a sub matrix.
Definition api.hpp:64
Is a struct contains some register file.
Definition api.hpp:99