XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
dropout_mask_gen.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
22#include "common/common.hpp"
24
25namespace gpu::xetla::group {
26
// Functor that fills one subgroup tile of a mask with 0/1 values derived from
// a random-number generator and stores that tile through a tile payload built
// on args->mask_ptr.
//
// Template parameters:
//   dtype_mask_   - exposed as dtype_mask.
//   wg_tile_n_/wg_tile_m_ - per-workgroup step in the n / m direction used to
//                   compute the tile origin (multiplied by wg_idx / wg_idy).
//   sg_tile_n_/sg_tile_m_ - per-subgroup step in the n / m direction; also the
//                   x / y size of the tile handled by one call.
//   random_simd_  - one generator call is consumed as 4 * random_simd values.
//   arch_         - target architecture tag, defaults to gpu_arch::Xe.
//
// NOTE(review): this listing omits several source lines (e.g. the right-hand
// sides of load_store_attr, block_size_x, block_size_y, and the declarations
// of `mask` and `rand_gen`); comments below only describe what is visible.
36template <typename dtype_mask_, uint32_t wg_tile_n_, uint32_t wg_tile_m_,
37 uint32_t sg_tile_n_, uint32_t sg_tile_m_, uint32_t random_simd_ = 16,
38 gpu_arch arch_ = gpu_arch::Xe>
39struct mask_gen_t {
40 using dtype_mask = dtype_mask_;
41 static constexpr uint32_t wg_tile_n = wg_tile_n_;
42 static constexpr uint32_t wg_tile_m = wg_tile_m_;
43 static constexpr uint32_t sg_tile_n = sg_tile_n_;
44 static constexpr uint32_t sg_tile_m = sg_tile_m_;
45 static constexpr uint32_t random_simd = random_simd_;
46
// Number of subgroup tiles needed to cover the workgroup tile in each
// direction (ceiling division, so a partial last tile still counts as one).
47 static constexpr uint32_t wg_size_x
48 = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
49 static constexpr uint32_t wg_size_y
50 = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
51
// Runtime arguments consumed by operator().
// operator() also reads args->mask_ptr and args->dropout_prob; their
// declarations are not shown in this listing.
54 struct arguments_t {
// matrix_m / matrix_n / mask_ld are forwarded unchanged to the output payload.
56 uint32_t matrix_m;
57 uint32_t matrix_n;
58 uint32_t mask_ld;
// Seed passed to rand_gen.init(); the default is an arbitrary fixed constant.
59 uint64_t rand_seed = 67280421310721;
// Location of a single 64-bit value that is read and passed as the third
// argument of rand_gen.init().
60 uint64_t *rand_offset_ptr;
62 };
63
// Store limits taken from load_store_attr (the full alias is truncated here).
64 using load_store_attr = typename arch_attr_t<
66 static constexpr uint32_t max_store_width_in_bytes
67 = load_store_attr::max_store_width_in_bytes;
68 static constexpr uint32_t max_store_width_in_elem
70 static constexpr uint32_t max_store_height_in_elem
71 = load_store_attr::max_store_height_in_elem;
// One call handles exactly one subgroup tile.
72 static constexpr uint32_t tile_size_x = sg_tile_n;
73 static constexpr uint32_t tile_size_y = sg_tile_m;
74 // block_size_x should be power of 2 and tile_size_x should be divided by block_size_x
75 static constexpr uint32_t block_size_x
// Compile-time guard: rejects configurations whose block width is below 8.
80 static_assert(block_size_x >= 8,
81 "if block_size_x less than 8, the efficiency will be low. Please "
82 "choose another tile_size_x");
83 static constexpr uint32_t block_size_y
86
94 gpu_arch::Xe>;
// Total number of mask elements produced per call.
95 static constexpr uint32_t tile_size = tile_size_x * tile_size_y;
96
// Generates and stores the mask tile for one subgroup.
//   args       - runtime arguments (see arguments_t).
//   wg_idx/wg_idy - workgroup index in the n / m direction.
//   sg_idx/sg_idy - subgroup index in the n / m direction.
//   linear_idx - passed as the second argument of rand_gen.init().
105 __XETLA_API KERNEL_FUNC void operator()(arguments_t *args, uint32_t wg_idx,
106 uint32_t wg_idy, uint32_t sg_idx, uint32_t sg_idy,
107 uint32_t linear_idx) {
// Read the single 64-bit offset value at args->rand_offset_ptr (the load
// call itself is on a line missing from this listing).
108 xetla_vector<uint64_t, 1> rand_offset_ptr_v
111 args->rand_offset_ptr, 0);
// Scale the probability to the 32-bit range so it can be compared directly
// with the uint32_t random values.
// NOTE(review): with dropout_prob == 1.0f the product is 4294967296.0f, which
// is outside the uint32_t range - confirm callers never pass 1.0.
112 uint32_t threshold = uint32_t(args->dropout_prob * float(4294967296));
113 mask_out_tile_t mask_out;
// Origin of this subgroup's tile: workgroup offset plus subgroup offset.
114 int start_m = wg_idy * wg_tile_m + sg_idy * sg_tile_m;
115 int start_n = wg_idx * wg_tile_n + sg_idx * sg_tile_n;
116 mask_out_payload_t mask_out_payload(args->mask_ptr, args->matrix_n,
117 args->matrix_m, args->mask_ld, start_n, start_m);
// Number of random values consumed per rand_gen.rand() call.
118 static constexpr uint32_t random_len = 4 * random_simd;
120 rand_gen.init(args->rand_seed, linear_idx, rand_offset_ptr_v[0]);
121
// Full chunks: fill random_len mask elements per generator call.
123#pragma unroll
124 for (int i = 0; i < tile_size / random_len; i++) {
125 auto mask_sub = mask.xetla_select<random_len, 1>(i * random_len);
126 xetla_vector<uint32_t, random_len> rand_val = rand_gen.rand();
127 xetla_mask<random_len> mask_flag = rand_val < threshold;
// Presumably writes 1 where the flag is set and 0 elsewhere - verify against
// xetla_merge semantics.
128 mask_sub.xetla_merge(1, 0, mask_flag);
129 }
// Tail: handles the elements left over when tile_size is not a multiple of
// random_len; compiled in only for such configurations.
130 if constexpr (tile_size % random_len != 0) {
131 constexpr uint32_t remain_len = tile_size % random_len;
132 constexpr uint32_t remain_start
133 = tile_size / random_len * random_len;
134 auto mask_sub = mask.xetla_select<remain_len, 1>(remain_start);
135 // drop, still generate random_len
136 xetla_vector<uint32_t, random_len> rand_val = rand_gen.rand();
137 xetla_mask<random_len> mask_flag = rand_val < threshold;
// Only the first remain_len flags are used; the rest are discarded.
138 mask_sub.xetla_merge(
139 1, 0, mask_flag.xetla_select<remain_len, 1>(0));
140 }
// Hand the finished tile to the store, issued with the uncached cache hint.
141 mask_out.reg = mask;
142 subgroup::tile_store<cache_hint::uncached>(mask_out, mask_out_payload);
143 }
144};
145} // namespace gpu::xetla::group
#define __XETLA_API
Definition common.hpp:43
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
wrapper for xetla_mask.
Definition base_types.hpp:165
__XETLA_API xetla_vector< Ty, N *NElts > xetla_load_global(Ty *p, xetla_vector< Toffset, N > offsets, xetla_mask< N > pred=1)
Stateless scattered load.
Definition memory.hpp:245
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
gpu_arch
Definition common.hpp:73
msg_type
Definition common.hpp:78
Definition arch_config.hpp:72
Definition dropout_mask_gen.hpp:54
uint32_t mask_ld
Definition dropout_mask_gen.hpp:58
uint32_t matrix_m
Definition dropout_mask_gen.hpp:56
uint64_t * rand_offset_ptr
Definition dropout_mask_gen.hpp:60
uint64_t rand_seed
Definition dropout_mask_gen.hpp:59
uint32_t matrix_n
Definition dropout_mask_gen.hpp:57
dtype_mask * mask_ptr
Definition dropout_mask_gen.hpp:55
float dropout_prob
Definition dropout_mask_gen.hpp:61
Definition dropout_mask_gen.hpp:39
static constexpr uint32_t random_simd
Definition dropout_mask_gen.hpp:45
static constexpr uint32_t block_size_x
Definition dropout_mask_gen.hpp:76
static constexpr uint32_t tile_size_x
Definition dropout_mask_gen.hpp:72
static constexpr uint32_t sg_tile_m
Definition dropout_mask_gen.hpp:44
static constexpr uint32_t tile_size_y
Definition dropout_mask_gen.hpp:73
typename arch_attr_t< arch_ >::template load_store_attr< msg_type::block_2d > load_store_attr
Definition dropout_mask_gen.hpp:65
dtype_mask_ dtype_mask
Definition dropout_mask_gen.hpp:40
static constexpr uint32_t max_store_width_in_bytes
Definition dropout_mask_gen.hpp:67
static constexpr uint32_t tile_size
Definition dropout_mask_gen.hpp:95
static constexpr uint32_t block_size_y
Definition dropout_mask_gen.hpp:84
static constexpr uint32_t max_store_height_in_elem
Definition dropout_mask_gen.hpp:71
subgroup::tile_desc_t< tile_size_x, tile_size_y, block_size_x, block_size_y, reg_layout::tiled > mask_out_tile_desc_t
Definition dropout_mask_gen.hpp:88
static constexpr uint32_t wg_tile_n
Definition dropout_mask_gen.hpp:41
static constexpr uint32_t max_store_width_in_elem
Definition dropout_mask_gen.hpp:69
static constexpr uint32_t sg_tile_n
Definition dropout_mask_gen.hpp:43
static constexpr uint32_t wg_size_x
Definition dropout_mask_gen.hpp:48
__XETLA_API KERNEL_FUNC void operator()(arguments_t *args, uint32_t wg_idx, uint32_t wg_idy, uint32_t sg_idx, uint32_t sg_idy, uint32_t linear_idx)
Definition dropout_mask_gen.hpp:105
static constexpr uint32_t wg_size_y
Definition dropout_mask_gen.hpp:50
static constexpr uint32_t wg_tile_m
Definition dropout_mask_gen.hpp:42
Definition memory_descriptor.hpp:139
Definition common.hpp:80
Describes the memory information.
Definition api.hpp:44
Describes the tile information of a sub-matrix.
Definition api.hpp:64
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
Definition rand.hpp:30
__XETLA_API xetla_vector< uint32_t, 4 *SIMD > rand()
Definition rand.hpp:57
__XETLA_API void init(uint64_t seed, uint64_t subseq, uint64_t offset)
Definition rand.hpp:38
C++ API.