29 template <uint32_t SIMD = 16, uint32_t round = 7>
43 auto key_2d = key.xetla_format<uint32_t, 2, SIMD>();
44 key_2d.row(0) = uint32_t(seed_v.xetla_format<uint32_t>()[0]);
45 key_2d.row(1) = uint32_t(seed_v.xetla_format<uint32_t>()[1]);
48 = xetla_vector_gen<uint32_t, SIMD>(0, 1);
49 auto counter_2d = counter.xetla_format<uint32_t, 4, SIMD>();
50 counter_2d.row(0) = uint32_t(offset_v.xetla_format<uint32_t>()[0]);
51 counter_2d.row(1) = uint32_t(offset_v.xetla_format<uint32_t>()[1]);
52 counter_2d.row(2) = uint32_t(subseq_v.xetla_format<uint32_t>()[0]);
53 counter_2d.row(3) = uint32_t(subseq_v.xetla_format<uint32_t>()[1]);
54 counter_2d.row(2) += channel_id;
60 auto key_2d_ = key_.xetla_format<uint32_t, 2, SIMD>();
63 for (uint32_t i = 0; i < round; i++) {
64 counter_ = single_round(counter_, key_);
78 auto ret_2d = ret.xetla_format<uint32_t, 4, SIMD>();
79 auto key_2d_ = key_.xetla_format<uint32_t, 2, SIMD>();
80 auto counter_2d_ = counter_.xetla_format<uint32_t, 4, SIMD>();
85 = xetla_imul<uint32_t, uint32_t, uint32_t, SIMD>(
86 res0_lo.xetla_format<uint32_t>(), counter_2d_.row(0),
89 = xetla_imul<uint32_t, uint32_t, uint32_t, SIMD>(
90 res1_lo.xetla_format<uint32_t>(), counter_2d_.row(2),
93 ret_2d.row(0) = res1_hi ^ counter_2d_.row(1) ^ key_2d_.row(0);
94 ret_2d.row(1) = res1_lo;
95 ret_2d.row(2) = res0_hi ^ counter_2d_.row(3) ^ key_2d_.row(1);
96 ret_2d.row(3) = res0_lo;
102 auto counter_2d = counter.xetla_format<uint32_t, 4, SIMD>();
103 xetla_vector<uint32_t, SIMD> carry;
105 counter_2d.row(0) = xetla_add_c<uint32_t, SIMD>(
106 counter_2d.row(0), 1, carry.xetla_format<uint32_t>());
107 counter_2d.row(1) = xetla_add_c<uint32_t, SIMD>(
108 counter_2d.row(1), carry, carry.xetla_format<uint32_t>());
109 counter_2d.row(2) = xetla_add_c<uint32_t, SIMD>(
110 counter_2d.row(2), carry, carry.xetla_format<uint32_t>());
111 counter_2d.row(3) += carry;
115 template <uint32_t SZ, typename dtype_mask = uint8_t, uint32_t random_simd = 16>
124 uint32_t threshold_, float scale_) {
126 this->threshold = threshold_;
127 this->scale = scale_;
130 template <typename dtype>
134 for (uint32_t i = 0; i < SZ / random_len; i++) {
139 out_sub.xetla_merge(0, mask_flag);
140 mask_sub.xetla_merge(1, 0, mask_flag);
141 out_sub = out_sub * scale;
144 constexpr uint32_t remain_len = SZ % random_len;
146 auto out_sub = output.xetla_select<remain_len, 1>(remain_start);
147 auto mask_sub = mask.xetla_select<remain_len, 1>(remain_start);
151 out_sub.xetla_merge(0, mask_flag.xetla_select<remain_len, 1>(0));
152 mask_sub.xetla_merge(
153 1, 0, mask_flag.xetla_select<remain_len, 1>(0));
154 out_sub = out_sub * scale;
#define __XETLA_API
Definition common.hpp:43
#define SIMD
Definition gemm_softmax.cpp:23
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
wrapper for xetla_mask.
Definition base_types.hpp:165
Definition arch_config.hpp:24
__XETLA_API void init(uint64_t seed, uint64_t subseq, uint64_t offset, uint32_t threshold_, float scale_)
Definition rand.hpp:123
xetla_rand_t< random_simd > rand_gen
Definition rand.hpp:118
static constexpr uint32_t random_len
Definition rand.hpp:117
float scale
Definition rand.hpp:121
__XETLA_API xetla_vector< dtype_mask, SZ > get_mask()
Definition rand.hpp:159
__XETLA_API xetla_vector< dtype, SZ > process(xetla_vector< dtype, SZ > input)
Definition rand.hpp:131
uint32_t threshold
Definition rand.hpp:120
xetla_vector< dtype_mask, SZ > mask
Definition rand.hpp:119
static constexpr uint32_t kPhiloxSA
Definition rand.hpp:35
static constexpr uint32_t kPhilox10A
Definition rand.hpp:33
__XETLA_API xetla_vector< uint32_t, 4 *SIMD > rand()
Definition rand.hpp:57
xetla_vector< uint32_t, 2 *SIMD > key
Definition rand.hpp:31
__XETLA_API void init(uint64_t seed, uint64_t subseq, uint64_t offset)
Definition rand.hpp:38
static constexpr uint32_t kPhiloxSB
Definition rand.hpp:36
xetla_vector< uint32_t, 4 *SIMD > counter
Definition rand.hpp:32
static constexpr uint32_t kPhilox10B
Definition rand.hpp:34