22#include "common/common.hpp"
36template <
typename dtype_mask_, uint32_t wg_tile_n_, uint32_t wg_tile_m_,
37 uint32_t sg_tile_n_, uint32_t sg_tile_m_, uint32_t random_simd_ = 16,
67 = load_store_attr::max_store_width_in_bytes;
71 = load_store_attr::max_store_height_in_elem;
81 "if block_size_x less than 8, the efficiency will be low. Please "
82 "choose another tile_size_x");
106 uint32_t wg_idy, uint32_t sg_idx, uint32_t sg_idy,
107 uint32_t linear_idx) {
112 uint32_t threshold = uint32_t(args->
dropout_prob *
float(4294967296));
118 static constexpr uint32_t random_len = 4 *
random_simd;
120 rand_gen.
init(args->
rand_seed, linear_idx, rand_offset_ptr_v[0]);
124 for (
int i = 0; i <
tile_size / random_len; i++) {
125 auto mask_sub = mask.xetla_select<random_len, 1>(i * random_len);
128 mask_sub.xetla_merge(1, 0, mask_flag);
130 if constexpr (
tile_size % random_len != 0) {
131 constexpr uint32_t remain_len =
tile_size % random_len;
132 constexpr uint32_t remain_start
134 auto mask_sub = mask.xetla_select<remain_len, 1>(remain_start);
138 mask_sub.xetla_merge(
139 1, 0, mask_flag.xetla_select<remain_len, 1>(0));
142 subgroup::tile_store<cache_hint::uncached>(mask_out, mask_out_payload);
#define __XETLA_API
Definition common.hpp:43
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
Wrapper for xetla_vector.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
Wrapper for xetla_mask.
Definition base_types.hpp:165
__XETLA_API xetla_vector< Ty, N *NElts > xetla_load_global(Ty *p, xetla_vector< Toffset, N > offsets, xetla_mask< N > pred=1)
Stateless scattered load.
Definition memory.hpp:245
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
gpu_arch
Definition common.hpp:73
msg_type
Definition common.hpp:78
Definition arch_config.hpp:72
Definition dropout_mask_gen.hpp:54
uint32_t mask_ld
Definition dropout_mask_gen.hpp:58
uint32_t matrix_m
Definition dropout_mask_gen.hpp:56
uint64_t * rand_offset_ptr
Definition dropout_mask_gen.hpp:60
uint64_t rand_seed
Definition dropout_mask_gen.hpp:59
uint32_t matrix_n
Definition dropout_mask_gen.hpp:57
dtype_mask * mask_ptr
Definition dropout_mask_gen.hpp:55
float dropout_prob
Definition dropout_mask_gen.hpp:61
Definition dropout_mask_gen.hpp:39
static constexpr uint32_t random_simd
Definition dropout_mask_gen.hpp:45
static constexpr uint32_t block_size_x
Definition dropout_mask_gen.hpp:76
static constexpr uint32_t tile_size_x
Definition dropout_mask_gen.hpp:72
static constexpr uint32_t sg_tile_m
Definition dropout_mask_gen.hpp:44
static constexpr uint32_t tile_size_y
Definition dropout_mask_gen.hpp:73
typename arch_attr_t< arch_ >::template load_store_attr< msg_type::block_2d > load_store_attr
Definition dropout_mask_gen.hpp:65
dtype_mask_ dtype_mask
Definition dropout_mask_gen.hpp:40
static constexpr uint32_t max_store_width_in_bytes
Definition dropout_mask_gen.hpp:67
static constexpr uint32_t tile_size
Definition dropout_mask_gen.hpp:95
static constexpr uint32_t block_size_y
Definition dropout_mask_gen.hpp:84
static constexpr uint32_t max_store_height_in_elem
Definition dropout_mask_gen.hpp:71
subgroup::tile_desc_t< tile_size_x, tile_size_y, block_size_x, block_size_y, reg_layout::tiled > mask_out_tile_desc_t
Definition dropout_mask_gen.hpp:88
static constexpr uint32_t wg_tile_n
Definition dropout_mask_gen.hpp:41
static constexpr uint32_t max_store_width_in_elem
Definition dropout_mask_gen.hpp:69
static constexpr uint32_t sg_tile_n
Definition dropout_mask_gen.hpp:43
static constexpr uint32_t wg_size_x
Definition dropout_mask_gen.hpp:48
__XETLA_API KERNEL_FUNC void operator()(arguments_t *args, uint32_t wg_idx, uint32_t wg_idy, uint32_t sg_idx, uint32_t sg_idy, uint32_t linear_idx)
Definition dropout_mask_gen.hpp:105
static constexpr uint32_t wg_size_y
Definition dropout_mask_gen.hpp:50
static constexpr uint32_t wg_tile_m
Definition dropout_mask_gen.hpp:42
Definition memory_descriptor.hpp:139
Describes the memory information.
Definition api.hpp:44
Describes the tile information of a sub-matrix.
Definition api.hpp:64
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
__XETLA_API xetla_vector< uint32_t, 4 *SIMD > rand()
Definition rand.hpp:57
__XETLA_API void init(uint64_t seed, uint64_t subseq, uint64_t offset)
Definition rand.hpp:38