22#include "group/reduction/reduction.hpp"
29template <
typename dtype_in_,
typename dtype_acc_,
typename tile_shape_>
42 using shape_t =
typename mem_desc_in_t::shape_t;
43 using coord_t =
typename mem_desc_in_t::coord_t;
44 using base_t =
typename mem_desc_in_t::base_t;
45 using work_group_t =
typename tile_shape::work_group_t;
46 static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
47 static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x;
48 static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
49 static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
61 :
shape(shape_), base(base_), sqrt_dk_inv(sqrt_dk_inv_) {}
63 struct get_barrier_count {
64 static constexpr uint32_t count = (wg_size_x > 1) ? wg_size_y : 0;
68 static constexpr uint32_t size = (wg_size_x > 1)
69 ? wg_size_y * wg_size_x * sg_tile_m *
sizeof(
dtype_acc)
73 template <
typename matAcc_t>
75 coord_t coord,
const arguments_t &args, uint32_t slm_base = 0,
76 uint32_t nbarrier_base = 0) {
77 static_assert(std::is_same<typename matAcc_t::dtype, dtype_acc>::value,
78 "matAcc dtype should match with dtype_acc");
80 static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
81 static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
82 static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
83 static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
84 static_assert((sg_tile_m == tile_size_y) && (sg_tile_n == tile_size_x),
85 "tile size should match");
91 subgroup::msg_type_v<mat_in_tile_desc_t, mem_desc_in_t::space>,
94 int32_t sg_idx = g.get_id() % wg_size_x;
95 int32_t sg_idy = g.get_id() / wg_size_x;
96 int32_t tile_offset_n = sg_idx * sg_tile_n;
97 int32_t tile_offset_m = sg_idy * sg_tile_m;
98 coord.x += tile_offset_n;
99 coord.y += tile_offset_m;
100 uint32_t nbarrier_id = nbarrier_base + sg_idy;
101 uint32_t slm_base_addr
102 = slm_base + sg_idy * wg_size_x * sg_tile_m *
sizeof(
dtype_acc);
106 mat_in_payload_t mat_in_payload(mem_desc_in);
107 subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
108 mat_in, mat_in_payload);
111 matAcc.reg = matAcc.reg * mat_in_acc.reg;
117 subgroup::tile_broadcast_op<subgroup::tile_minus, matAcc_t>(
119 matAcc.reg = matAcc.reg * mat_in_acc.reg * args.sqrt_dk_inv;
tile_shape_ tile_shape
Definition softmax_bwd_xe.hpp:34
__XETLA_API KERNEL_FUNC void operator()(work_group_t &g, matAcc_t &matAcc, coord_t coord, const arguments_t &args, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Definition softmax_bwd_xe.hpp:74
dtype_in_ dtype_in
Definition softmax_bwd_xe.hpp:35
dtype_acc_ dtype_acc
Definition softmax_bwd_xe.hpp:36
#define __XETLA_API
Definition common.hpp:43
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
__XETLA_API std::enable_if_t<(T_src::register_layout !=reg_layout::linear) &&(T_dst::register_layout !=reg_layout::linear) &&is_same_layout< T_dst, T_src >::value &&(!is_floating_to_integer< T_dst, T_src >::value)> elemwise_cvt(T_dst &dst, T_src &src)
Is the element wise data conversion, the src and dst tile should have the same layout.
Definition op_function.hpp:40
__XETLA_API std::enable_if_t<(dim==1), xetla_vector< dtype_out, mat_t::tile_size_y > > tile_reduce(mat_t &src)
Definition reduction.hpp:33
gpu_arch
Definition common.hpp:73
This is the group reduction.
Definition reduction_api.hpp:36
Definition softmax_policy.hpp:31
base_t base
Definition softmax_bwd_xe.hpp:57
dtype_acc sqrt_dk_inv
Definition softmax_bwd_xe.hpp:58
shape_t shape
Definition softmax_bwd_xe.hpp:56
arguments_t(base_t base_, shape_t shape_, dtype_acc sqrt_dk_inv_)
Definition softmax_bwd_xe.hpp:60
Is to illustrate the memory information.
Definition api.hpp:44
Is to illustrate the tile information about a sub matrix.
Definition api.hpp:64
Is a struct contains some register file.
Definition api.hpp:99