XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
rand.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
23
24namespace gpu::xetla {
25
28
29template <uint32_t SIMD = 16, uint32_t round = 7>
33 static constexpr uint32_t kPhilox10A = 0x9E3779B9;
34 static constexpr uint32_t kPhilox10B = 0xBB67AE85;
35 static constexpr uint32_t kPhiloxSA = 0xD2511F53;
36 static constexpr uint32_t kPhiloxSB = 0xCD9E8D57;
37
38 __XETLA_API void init(uint64_t seed, uint64_t subseq, uint64_t offset) {
39 xetla_vector<uint64_t, 1> seed_v = seed;
40 xetla_vector<uint64_t, 1> offset_v = offset;
41 xetla_vector<uint64_t, 1> subseq_v = subseq * SIMD;
42
43 auto key_2d = key.xetla_format<uint32_t, 2, SIMD>();
44 key_2d.row(0) = uint32_t(seed_v.xetla_format<uint32_t>()[0]);
45 key_2d.row(1) = uint32_t(seed_v.xetla_format<uint32_t>()[1]);
46
48 = xetla_vector_gen<uint32_t, SIMD>(0, 1);
49 auto counter_2d = counter.xetla_format<uint32_t, 4, SIMD>();
50 counter_2d.row(0) = uint32_t(offset_v.xetla_format<uint32_t>()[0]);
51 counter_2d.row(1) = uint32_t(offset_v.xetla_format<uint32_t>()[1]);
52 counter_2d.row(2) = uint32_t(subseq_v.xetla_format<uint32_t>()[0]);
53 counter_2d.row(3) = uint32_t(subseq_v.xetla_format<uint32_t>()[1]);
54 counter_2d.row(2) += channel_id;
55 }
56
60 auto key_2d_ = key_.xetla_format<uint32_t, 2, SIMD>();
61
62#pragma unroll
63 for (uint32_t i = 0; i < round; i++) {
64 counter_ = single_round(counter_, key_);
65 key_2d_.row(0) += kPhilox10A;
66 key_2d_.row(1) += kPhilox10B;
67 }
68 xetla_vector<uint32_t, 4 *SIMD> output = single_round(counter_, key_);
69 incr();
70 return output;
71 }
72
73private:
78 auto ret_2d = ret.xetla_format<uint32_t, 4, SIMD>();
79 auto key_2d_ = key_.xetla_format<uint32_t, 2, SIMD>();
80 auto counter_2d_ = counter_.xetla_format<uint32_t, 4, SIMD>();
81
85 = xetla_imul<uint32_t, uint32_t, uint32_t, SIMD>(
86 res0_lo.xetla_format<uint32_t>(), counter_2d_.row(0),
87 kPhiloxSA);
89 = xetla_imul<uint32_t, uint32_t, uint32_t, SIMD>(
90 res1_lo.xetla_format<uint32_t>(), counter_2d_.row(2),
91 kPhiloxSB);
92
93 ret_2d.row(0) = res1_hi ^ counter_2d_.row(1) ^ key_2d_.row(0);
94 ret_2d.row(1) = res1_lo;
95 ret_2d.row(2) = res0_hi ^ counter_2d_.row(3) ^ key_2d_.row(1);
96 ret_2d.row(3) = res0_lo;
97
98 return ret;
99 }
100
// Advances the per-lane counter state by one, propagating a carry through
// the four uint32_t rows of `counter` (row 0 first, row 3 last).
// NOTE(review): presumably xetla_add_c returns the 32-bit sum and writes the
// per-lane carry-out into its third argument; confirm against its definition.
101 __XETLA_API void incr() {
// View `counter` as 4 rows of SIMD uint32_t lanes.
102 auto counter_2d = counter.xetla_format<uint32_t, 4, SIMD>();
// Per-lane carry passed from one row to the next.
103 xetla_vector<uint32_t, SIMD> carry;
104
// Row 0: add the constant 1.
105 counter_2d.row(0) = xetla_add_c<uint32_t, SIMD>(
106 counter_2d.row(0), 1, carry.xetla_format<uint32_t>());
// Rows 1 and 2: add the carry from the previous row; `carry` is both the
// addend and the destination for the new carry.
107 counter_2d.row(1) = xetla_add_c<uint32_t, SIMD>(
108 counter_2d.row(1), carry, carry.xetla_format<uint32_t>());
109 counter_2d.row(2) = xetla_add_c<uint32_t, SIMD>(
110 counter_2d.row(2), carry, carry.xetla_format<uint32_t>());
// Row 3: plain add; no carry is captured from this last row.
111 counter_2d.row(3) += carry;
112 }
113};
114
115template <uint32_t SZ, typename dtype_mask = uint8_t, uint32_t random_simd = 16>
117 static constexpr uint32_t random_len = 4 * random_simd;
120 uint32_t threshold;
121 float scale;
122
// Initializes the dropout state: forwards seed/subseq/offset to the embedded
// random generator and stores the threshold and scale used by process().
// @param seed       forwarded unchanged to rand_gen.init
// @param subseq     forwarded unchanged to rand_gen.init
// @param offset     forwarded unchanged to rand_gen.init
// @param threshold_ stored in `threshold`; process() compares random values against it
// @param scale_     stored in `scale`; process() multiplies the output by it
123 __XETLA_API void init(uint64_t seed, uint64_t subseq, uint64_t offset,
124 uint32_t threshold_, float scale_) {
125 rand_gen.init(seed, subseq, offset);
126 this->threshold = threshold_;
127 this->scale = scale_;
128 }
129
130 template <typename dtype>
132 xetla_vector<dtype, SZ> output = input;
133#pragma unroll
134 for (uint32_t i = 0; i < SZ / random_len; i++) {
135 auto out_sub = output.xetla_select<random_len, 1>(i * random_len);
136 auto mask_sub = mask.xetla_select<random_len, 1>(i * random_len);
138 xetla_mask<random_len> mask_flag = rand_val < threshold;
139 out_sub.xetla_merge(0, mask_flag);
140 mask_sub.xetla_merge(1, 0, mask_flag);
141 out_sub = out_sub * scale;
142 }
143 if constexpr (SZ % random_len != 0) {
144 constexpr uint32_t remain_len = SZ % random_len;
145 constexpr uint32_t remain_start = SZ / random_len * random_len;
146 auto out_sub = output.xetla_select<remain_len, 1>(remain_start);
147 auto mask_sub = mask.xetla_select<remain_len, 1>(remain_start);
148 // dropout on the tail: a full random_len-wide mask is still generated, but only its first remain_len lanes are applied
150 xetla_mask<random_len> mask_flag = rand_val < threshold;
151 out_sub.xetla_merge(0, mask_flag.xetla_select<remain_len, 1>(0));
152 mask_sub.xetla_merge(
153 1, 0, mask_flag.xetla_select<remain_len, 1>(0));
154 out_sub = out_sub * scale;
155 }
156 return output;
157 }
158
160 return mask;
161 }
162};
163
165
166} // namespace gpu::xetla
#define __XETLA_API
Definition common.hpp:43
C++ API.
#define SIMD
Definition gemm_softmax.cpp:23
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
wrapper for xetla_mask.
Definition base_types.hpp:165
Definition arch_config.hpp:24
Definition rand.hpp:116
__XETLA_API void init(uint64_t seed, uint64_t subseq, uint64_t offset, uint32_t threshold_, float scale_)
Definition rand.hpp:123
xetla_rand_t< random_simd > rand_gen
Definition rand.hpp:118
static constexpr uint32_t random_len
Definition rand.hpp:117
float scale
Definition rand.hpp:121
__XETLA_API xetla_vector< dtype_mask, SZ > get_mask()
Definition rand.hpp:159
__XETLA_API xetla_vector< dtype, SZ > process(xetla_vector< dtype, SZ > input)
Definition rand.hpp:131
uint32_t threshold
Definition rand.hpp:120
xetla_vector< dtype_mask, SZ > mask
Definition rand.hpp:119
Definition rand.hpp:30
static constexpr uint32_t kPhiloxSA
Definition rand.hpp:35
static constexpr uint32_t kPhilox10A
Definition rand.hpp:33
__XETLA_API xetla_vector< uint32_t, 4 *SIMD > rand()
Definition rand.hpp:57
xetla_vector< uint32_t, 2 *SIMD > key
Definition rand.hpp:31
__XETLA_API void init(uint64_t seed, uint64_t subseq, uint64_t offset)
Definition rand.hpp:38
static constexpr uint32_t kPhiloxSB
Definition rand.hpp:36
xetla_vector< uint32_t, 4 *SIMD > counter
Definition rand.hpp:32
static constexpr uint32_t kPhilox10B
Definition rand.hpp:34