XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
row_reduction_fused_op_xe.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
23
24namespace gpu::xetla::group {
25
// Argument pack consumed by the row-reduction fused ops below: raw pointers
// to the gelu-bwd / dropout-bwd buffers plus matrix extents and leading
// dimensions used to build the 2D memory descriptors.
// NOTE(review): Doxygen extraction elided the struct declaration (orig. line 32)
// and the dropout fields `dropout_prob` / `dropout_scale_inv` (orig. lines
// 37-38, visible in the member index at the bottom of this page).
31template <typename dtype_in, typename dtype_out, typename dtype_acc>
33 dtype_in *gelu_bwd_w_ptr; // loaded and multiplied into matAcc (gelu-bwd path)
34 dtype_out *gelu_bwd_x_ptr; // destination of the converted matAcc (gelu-bwd path)
35 dtype_out *dropout_bwd_ptr; // destination of the converted matAcc (dropout-bwd path)
36 uint8_t *mask_ptr; // dropout mask, one byte per element
39 uint32_t matrix_m; // surface height passed to mem_desc init
40 uint32_t matrix_n; // surface width passed to mem_desc init
41 uint32_t mat_in_ld; // leading dimension of the input surfaces
42 uint32_t mat_out_ld; // leading dimension of the output surfaces
43};
44
// Pass-through specialization: constructor, operator() and update_tdesc all
// have empty bodies, so no extra work is fused into the row reduction.
// Parameters are marked [[maybe_unused]] to keep the interface identical to
// the active specializations below.
// NOTE(review): the constructor's first line (orig. line 62, taking
// `arguments_t *args`) was elided by the doc extractor.
52template <reduction_fused_kind fused_op_kind_, typename dtype_in_,
53 typename dtype_out_, typename dtype_acc_, typename reduction_attr_>
54struct row_reduction_fused_op_t<fused_op_kind_, dtype_in_, dtype_out_,
55 dtype_acc_, reduction_attr_, gpu_arch::Xe> {
56 static constexpr reduction_fused_kind fused_op_kind = fused_op_kind_;
57 using dtype_in = dtype_in_;
58 using dtype_out = dtype_out_;
59 using dtype_acc = dtype_acc_;
63 [[maybe_unused]] int start_n = 0,
64 [[maybe_unused]] int start_m = 0) {}
65 template <typename matAcc_t>
66 __XETLA_API KERNEL_FUNC void operator()([[maybe_unused]] matAcc_t &matAcc) {
67 }
68 __XETLA_API void update_tdesc([[maybe_unused]] int offset_n = 0,
69 [[maybe_unused]] int offset_m = 0) {}
70};
71
// Specialization that fuses a gelu-backward elementwise multiply into the
// row reduction: matAcc *= w (loaded from gelu_bwd_w_ptr), and the scaled
// accumulator is also converted and stored out to gelu_bwd_x_ptr.
// NOTE(review): the doc extractor elided the struct header with the
// reduction_fused_kind tag (orig. lines 74, 77), the mem_desc_t member
// declarations `w_load_base_desc` / `x_store_base_desc` (orig. lines 84, 86 —
// see the member index below), and several tile typedef lines (orig. 113-122).
72template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
73 typename reduction_attr_>
75 dtype_in_, dtype_out_, dtype_acc_, reduction_attr_, gpu_arch::Xe> {
76 static constexpr reduction_fused_kind fused_op_kind
78 using dtype_in = dtype_in_;
79 using dtype_out = dtype_out_;
80 using dtype_acc = dtype_acc_;
// Constructor: build the load descriptor for w and the store descriptor for
// x from the matrix extents/strides in args, anchored at (start_n, start_m).
88 arguments_t *args, int start_n = 0, int start_m = 0) {
89 w_load_base_desc.init({args->gelu_bwd_w_ptr},
90 {args->matrix_n, args->matrix_m, args->mat_in_ld},
91 {start_n, start_m});
92 x_store_base_desc.init({args->gelu_bwd_x_ptr},
93 {args->matrix_n, args->matrix_m, args->mat_out_ld},
94 {start_n, start_m});
95 }
96
// Apply the fused op to one accumulator tile: load the w tile, convert it to
// dtype_acc, multiply elementwise into matAcc, then convert the result to
// dtype_out and store it (uncached) through x_store_base_desc.
102 template <typename matAcc_t>
103 __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc) {
// Tile dtype must match dtype_acc so the elementwise multiply is well-typed.
104 static_assert(std::is_same<remove_const_t<dtype_acc>,
105 typename matAcc_t::dtype>::value,
106 "dtype_acc should match with matAcc");
107 static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
108 static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
109 static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
110 static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
111 static constexpr uint32_t num_elems = matAcc_t::tile_elems;
// The load/store tiles mirror the accumulator's tiling exactly.
112 using dgelu_tile_desc_t = subgroup::tile_desc_t<tile_size_x,
113 tile_size_y, block_size_x, block_size_y, reg_layout::tiled>;
115 using dgelu_w_in_payload_t = subgroup::mem_payload_t<
117 dgelu_tile_desc_t,
118 subgroup::msg_type_v<dgelu_tile_desc_t, mem_space::global>,
121 using dgelu_x_out_payload_t = subgroup::mem_payload_t<
123 dgelu_tile_desc_t, msg_type::block_2d, gpu_arch::Xe>;
124 dgelu_w_in_t dgelu_w_in;
125 dgelu_w_in_payload_t dgelu_w_in_payload(w_load_base_desc);
126 subgroup::tile_load(dgelu_w_in, dgelu_w_in_payload);
// Convert the loaded w registers to dtype_acc before the multiply.
128 = xetla_cvt<dtype_acc, dtype_in, num_elems>(dgelu_w_in.reg);
129 matAcc.reg = matAcc.reg * w;
130 dgelu_x_out_t dgelu_x_out;
131 dgelu_x_out_payload_t dgelu_x_out_payload(x_store_base_desc);
132 subgroup::elemwise_cvt(dgelu_x_out, matAcc);
// Uncached store: the result is not expected to be re-read by this kernel.
133 subgroup::tile_store<cache_hint::uncached>(
134 dgelu_x_out, dgelu_x_out_payload);
135 }
136
// Shift both memory descriptors by (offset_n, offset_m) for the next tile.
142 __XETLA_API void update_tdesc(int offset_n = 0, int offset_m = 0) {
143 w_load_base_desc.update_coord(offset_n, offset_m);
144 x_store_base_desc.update_coord(offset_n, offset_m);
145 }
146};
147
// Specialization that fuses dropout-backward into the row reduction: when
// dropout_prob != 0 the accumulator is masked/rescaled via drop_out() using
// the uint8_t mask tile, and the result is converted and stored to
// dropout_bwd_ptr.
// NOTE(review): the doc extractor elided the struct header with the
// reduction_fused_kind tag (orig. lines 150-151, 153), the mem_desc_t members
// `mask_load_base_desc` / `dropout_bwd_store_base_desc` and the float members
// `dropout_prob` / `dropout_scale_inv` (orig. lines 158-165 — see the member
// index below), and several tile typedef lines (orig. 191-204).
148template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
149 typename reduction_attr_>
151 dtype_in_, dtype_out_, dtype_acc_, reduction_attr_, gpu_arch::Xe> {
152 static constexpr reduction_fused_kind fused_op_kind
154 using dtype_in = dtype_in_;
155 using dtype_out = dtype_out_;
156 using dtype_acc = dtype_acc_;
157 using dtype_mask = uint8_t; // one mask byte per element
166
// Constructor: build the mask load / result store descriptors from args and
// capture the dropout scaling parameters.
168 arguments_t *args, int start_n = 0, int start_m = 0) {
169
170 mask_load_base_desc.init({args->mask_ptr},
171 {args->matrix_n, args->matrix_m, args->mat_in_ld},
172 {start_n, start_m});
173 dropout_bwd_store_base_desc.init({args->dropout_bwd_ptr},
174 {args->matrix_n, args->matrix_m, args->mat_out_ld},
175 {start_n, start_m});
176 dropout_scale_inv = args->dropout_scale_inv;
177 dropout_prob = args->dropout_prob;
178 }
179
// Apply dropout-backward to one accumulator tile, then convert and store the
// (possibly rescaled) accumulator to global memory, uncached.
180 template <typename matAcc_t>
181 __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc) {
// Tile dtype must match dtype_acc so drop_out operates on the right type.
182 static_assert(std::is_same<remove_const_t<dtype_acc>,
183 typename matAcc_t::dtype>::value,
184 "dtype_acc should match with matAcc");
185 static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
186 static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
187 static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
188 static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
189 using reduction_tile_desc_t = subgroup::tile_desc_t<tile_size_x,
190 tile_size_y, block_size_x, block_size_y, reg_layout::tiled>;
192 using mask_in_payload_t = subgroup::mem_payload_t<
195 reduction_tile_desc_t,
196 subgroup::msg_type_v<reduction_tile_desc_t, mem_space::global>,
198 using dropout_bwd_out_t
200 using dropout_bwd_out_payload_t = subgroup::mem_payload_t<
202 reduction_tile_desc_t,
203 subgroup::msg_type_v<reduction_tile_desc_t, mem_space::global>,
// Skip the mask load and rescale entirely when dropout is disabled.
205 if (dropout_prob != 0) {
206 mask_in_t mask_in;
207 mask_in_payload_t mask_in_payload(mask_load_base_desc);
208 subgroup::tile_load(mask_in, mask_in_payload);
// Keep the load ahead of drop_out in the instruction schedule.
209 SW_BARRIER();
210 matAcc.reg = drop_out<dtype_acc, tile_size_x * tile_size_y>(
211 matAcc.reg, mask_in.reg, dropout_scale_inv);
212 }
213 dropout_bwd_out_t dropout_bwd_out;
214 dropout_bwd_out_payload_t dropout_bwd_out_payload(
215 dropout_bwd_store_base_desc);
216 subgroup::elemwise_cvt(dropout_bwd_out, matAcc);
// Uncached store: the result is not expected to be re-read by this kernel.
217 subgroup::tile_store<cache_hint::uncached>(
218 dropout_bwd_out, dropout_bwd_out_payload);
219 }
220
// Shift both memory descriptors by (offset_n, offset_m) for the next tile.
221 __XETLA_API void update_tdesc(int offset_n = 0, int offset_m = 0) {
222 mask_load_base_desc.update_coord(offset_n, offset_m);
223 dropout_bwd_store_base_desc.update_coord(offset_n, offset_m);
224 }
225};
226
227} // namespace gpu::xetla::group
#define SW_BARRIER()
SW_BARRIER inserts a software scheduling barrier, giving finer control over instruction ordering in the generated code.
Definition common.hpp:227
#define __XETLA_API
Definition common.hpp:43
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
__XETLA_API std::enable_if_t<(T_src::register_layout !=reg_layout::linear) &&(T_dst::register_layout !=reg_layout::linear) &&is_same_layout< T_dst, T_src >::value &&(!is_floating_to_integer< T_dst, T_src >::value)> elemwise_cvt(T_dst &dst, T_src &src)
Is the element wise data conversion, the src and dst tile should have the same layout.
Definition op_function.hpp:40
__XETLA_API std::enable_if_t< detail::check_load_type< tile_t, payload_t >::is_global_2d_xe > tile_load(tile_t &tile, payload_t &payload)
This function loads data from 2D memory surface.
Definition load_xe.hpp:76
gpu_arch
Definition common.hpp:73
reduction_fused_kind
Definition row_reduction_fused_op_api.hpp:28
__XETLA_API row_reduction_fused_op_t(arguments_t *args, int start_n=0, int start_m=0)
Definition row_reduction_fused_op_xe.hpp:62
__XETLA_API void update_tdesc(int offset_n=0, int offset_m=0)
Definition row_reduction_fused_op_xe.hpp:68
__XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc)
Definition row_reduction_fused_op_xe.hpp:66
mem_desc_t< dtype_out, mem_layout::row_major, mem_space::global > dropout_bwd_store_base_desc
Definition row_reduction_fused_op_xe.hpp:163
mem_desc_t< dtype_mask, mem_layout::row_major, mem_space::global > mask_load_base_desc
Definition row_reduction_fused_op_xe.hpp:161
__XETLA_API row_reduction_fused_op_t(arguments_t *args, int start_n=0, int start_m=0)
Definition row_reduction_fused_op_xe.hpp:167
mem_desc_t< dtype_in, mem_layout::row_major, mem_space::global > w_load_base_desc
Definition row_reduction_fused_op_xe.hpp:84
mem_desc_t< dtype_out, mem_layout::row_major, mem_space::global > x_store_base_desc
Definition row_reduction_fused_op_xe.hpp:86
__XETLA_API row_reduction_fused_op_t(arguments_t *args, int start_n=0, int start_m=0)
Definition row_reduction_fused_op_xe.hpp:87
Additional Ops that can be fused with row reduction processing flow.
Definition row_reduction_fused_op_api.hpp:47
Definition row_reduction_fused_op_xe.hpp:32
uint8_t * mask_ptr
Definition row_reduction_fused_op_xe.hpp:36
uint32_t mat_in_ld
Definition row_reduction_fused_op_xe.hpp:41
dtype_in * gelu_bwd_w_ptr
Definition row_reduction_fused_op_xe.hpp:33
float dropout_scale_inv
Definition row_reduction_fused_op_xe.hpp:38
uint32_t matrix_n
Definition row_reduction_fused_op_xe.hpp:40
dtype_out * dropout_bwd_ptr
Definition row_reduction_fused_op_xe.hpp:35
float dropout_prob
Definition row_reduction_fused_op_xe.hpp:37
uint32_t matrix_m
Definition row_reduction_fused_op_xe.hpp:39
uint32_t mat_out_ld
Definition row_reduction_fused_op_xe.hpp:42
dtype_out * gelu_bwd_x_ptr
Definition row_reduction_fused_op_xe.hpp:34
Definition memory_descriptor.hpp:139
Is to illustrate the memory information.
Definition api.hpp:44
Is to illustrate the tile information about a sub matrix.
Definition api.hpp:64
Is a struct contains some register file.
Definition api.hpp:99