25template <
typename dtype_in_,
typename dtype_out_,
typename tile_shape_,
27 uint32_t thread_num_, uint32_t softmax_size_>
36 static constexpr uint32_t
sg_tile_m = tile_shape::sg_tile_size_y;
37 static constexpr uint32_t
sg_tile_n = tile_shape::sg_tile_size_x;
38 static constexpr uint32_t
wg_size_x = tile_shape::wg_size_x;
39 static constexpr uint32_t
wg_size_y = tile_shape::wg_size_y;
43 static constexpr uint32_t
SIMD = SIMD_;
59 subgroup::msg_type_v<softmax_tile_desc_t, mem_space_in>,
67 subgroup::msg_type_v<softmax_tile_desc_t, mem_space_out>,
89 uint32_t local_offset_y =
block_height * item.get_local_linear_id();
107 for (uint32_t row = 0; row < inner_loop_count; ++row) {
108 subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
109 softmax_load, softmax_load_payload);
110 softmax_load_payload.template update_tdesc<tdesc_update_dir::y_dir>(
116 float xmax = hmax<float, float, softmax_size>(row_data_32);
120 row_data_32 = exp(row_data_32);
121 float exp_sum = sum<float, float, softmax_size>(row_data_32);
124 row_data_32 /= exp_sum;
127 = xetla_cvt<dtype_out, float, softmax_size>(row_data_32);
129 tile_store(softmax_store, softmax_store_payload);
130 softmax_store_payload
131 .template update_tdesc<tdesc_update_dir::y_dir>(
#define __XETLA_API
Definition common.hpp:43
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
Wrapper for xetla_vector.
Definition base_types.hpp:149
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
Definition limitation.hpp:457
__XETLA_API std::enable_if_t< detail::check_store_type< tile_t, payload_t >::is_global_2d_xe > tile_store(tile_t &tile, payload_t &payload)
Stores data from the register file to global memory.
Definition store_xe.hpp:91
Definition arch_config.hpp:24
mem_space
Definition common.hpp:77
Definition memory_descriptor.hpp:139
Describes the memory information.
Definition api.hpp:44
Describes the tile information of a sub-matrix.
Definition api.hpp:64
A struct that contains some register file data.
Definition api.hpp:99
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
Definition softmax.hpp:70
uint32_t data_in_base
Definition softmax.hpp:72
dtype_in * data_in_ptr
Definition softmax.hpp:76
dtype_out * data_out_ptr
Definition softmax.hpp:78
uint32_t data_out_base
Definition softmax.hpp:74
Definition softmax.hpp:28
dtype_out_ dtype_out
Definition softmax.hpp:30
static constexpr uint32_t wg_size_x
Definition softmax.hpp:38
dtype_in_ dtype_in
Definition softmax.hpp:29
tile_shape_ tile_shape
Definition softmax.hpp:31
static constexpr uint32_t block_height
Definition softmax.hpp:46
static constexpr mem_space mem_space_out
Definition softmax.hpp:34
subgroup::tile_desc_t< SIMD, block_height, SIMD, block_height, reg_layout::tiled > softmax_tile_desc_t
Definition softmax.hpp:54
static constexpr uint32_t sg_tile_n
Definition softmax.hpp:37
static constexpr uint32_t wg_tile_m
Definition softmax.hpp:40
static constexpr uint32_t SIMD
Definition softmax.hpp:43
static constexpr uint32_t thread_num
Definition softmax.hpp:44
__XETLA_API KERNEL_FUNC void operator()(sycl::nd_item< 3 > &item, arguments_t *args)
Definition softmax.hpp:81
static constexpr uint32_t sg_tile_m
Definition softmax.hpp:36
static constexpr uint32_t wg_size_y
Definition softmax.hpp:39
static constexpr uint32_t softmax_size
Definition softmax.hpp:45
static constexpr mem_space mem_space_in
Definition softmax.hpp:33
static constexpr uint32_t wg_tile_n
Definition softmax.hpp:41