XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
row_reduction_xe.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
27
28namespace gpu::xetla::kernel {
29
// Partial specialization of xetla_row_reduction_t for gpu_arch::Xe.
//
// NOTE(review): this is a rendered source listing. Every line is prefixed with
// its line number in the original header, and some original lines are absent
// (e.g. 60, 63, 84, 95, 97-100, 113-114, 158, 162-163, 186, 191, 199, 215).
// Names such as global_ld_t, global_ld_payload_t, matAcc_t, mat_buffer_t,
// row_reduce_store_t, nbarrier, offsets and next_job are therefore declared on
// lines that are not visible here; comments about them are hedged.
40template <typename dtype_in_, typename dtype_out_, typename dtype_acc_,
41 typename reduction_attr_, typename fused_op_t_>
42struct xetla_row_reduction_t<dtype_in_, dtype_out_, dtype_acc_, reduction_attr_,
43 gpu_arch::Xe, fused_op_t_> {
    // Re-export the template parameters under un-suffixed names.
44 using dtype_in = dtype_in_;
45 using dtype_out = dtype_out_;
46 using dtype_acc = dtype_acc_;
47 using reduction_attr = reduction_attr_;
48 using fused_op_t = fused_op_t_;
49 using fused_op_arguments_t = typename fused_op_t::arguments_t;
50
    // Workgroup-level and subgroup-level tile extents, taken from reduction_attr.
51 static constexpr uint32_t wg_tile_m = reduction_attr::wg_tile_m;
52 static constexpr uint32_t wg_tile_n = reduction_attr::wg_tile_n;
53 static constexpr uint32_t sg_tile_m = reduction_attr::sg_tile_m;
54 static constexpr uint32_t sg_tile_n = reduction_attr::sg_tile_n;
55 static constexpr bool is_dynamic_job = reduction_attr::is_dynamic_job;
    // Subgroups per workgroup along x and y: ceiling division of the workgroup
    // tile by the subgroup tile.
56 static constexpr uint32_t wg_size_x
57 = (wg_tile_n + sg_tile_n - 1) / sg_tile_n;
58 static constexpr uint32_t wg_size_y
59 = (wg_tile_m + sg_tile_m - 1) / sg_tile_m;
    // The dynamic-job path is taken only when reduction_attr requests it AND
    // there is more than one subgroup along y.
61 static constexpr bool use_dynamic_job = is_dynamic_job && (wg_size_y > 1);
    // The alias below is cut off in this listing; the rest of the declaration
    // is on an original line that is not shown.
62 using load_store_attr = typename arch_attr_t<
    // Load/store limits taken from load_store_attr; the *_in_elem variants
    // divide the byte limits by the element size of the input/output type.
64 static constexpr uint32_t max_load_height_in_elem
65 = load_store_attr::max_load_height_in_elem;
66 static constexpr uint32_t max_load_width_in_bytes
67 = load_store_attr::max_load_width_in_bytes;
68 static constexpr uint32_t max_store_width_in_bytes
69 = load_store_attr::max_store_width_in_bytes;
70 static constexpr uint32_t max_load_width_in_elem
71 = max_load_width_in_bytes / sizeof(dtype_in);
72 static constexpr uint32_t max_store_width_in_elem
73 = max_store_width_in_bytes / sizeof(dtype_out);
74
    // Per-subgroup tile extents are the subgroup tile extents.
75 static constexpr uint32_t tile_size_x = sg_tile_n;
76 static constexpr uint32_t tile_size_y = sg_tile_m;
77
78 static constexpr uint32_t max_simd_len = max_store_width_in_elem;
79
    // block_size_x equals tile_size_x when the tile is narrower than the
    // maximum load width. The alternative branch is partly on a line that is
    // not shown; only its tail (`max_load_width_in_elem>::value`) is visible.
81 static constexpr uint32_t block_size_x
82 = max_load_width_in_elem > tile_size_x
83 ? tile_size_x
85 max_load_width_in_elem>::value;
    // Compile-time guard: block widths below 8 are rejected.
86 static_assert(block_size_x >= 8,
87 "if block_size_x less than 8, the efficiency will be low. Please "
88 "choose another tile_size_x");
    // block_size_y is the smaller of max_load_height_in_elem and tile_size_y.
89 static constexpr uint32_t block_size_y
90 = max_load_height_in_elem > tile_size_y ? tile_size_y
91 : max_load_height_in_elem;
92
    // Vector width used for the per-column job counter and its mask in call().
93 static constexpr uint32_t SIMD = 16;
94
    // The next four lines are tails of type-alias declarations whose first
    // lines are not shown (presumably the tile descriptor, load payload, row
    // buffer descriptor and row_reduce_store_t aliases -- TODO confirm).
96 tile_size_y, block_size_x, block_size_y, reg_layout::tiled>;
101 subgroup::msg_type_v<global_ld_tile_desc_t, mem_space::global>,
104 subgroup::tile_desc_t<tile_size_x, 1, block_size_x, 1,
108 dtype_out, sg_tile_n, wg_size_x, wg_size_y, max_simd_len>;
109
    // Arguments consumed by call(). call() also reads args->mat_in_ptr and
    // args->mat_out_ptr, which are declared on original lines not shown here.
112 struct arguments_t {
    // matrix_m bounds the row loop in call(); matrix_n is passed as the output
    // width; mat_in_ld is passed to the input load payload (presumably rows,
    // columns and input leading dimension -- TODO confirm).
115 uint32_t matrix_m;
116 uint32_t matrix_n;
117 uint32_t mat_in_ld;
118 };
119
    // Number of named barriers: wg_size_x when wg_size_y > 1, otherwise 0.
    // call() initialises barrier id nbarrier_base + sg_idx.
122 struct get_barrier_count {
123 static constexpr uint32_t count = (wg_size_y > 1) ? wg_size_x : 0;
124 };
125
    // Bytes reserved for job counters: SIMD ints per sg_idx when the dynamic
    // job path is enabled, otherwise 0.
126 static constexpr uint32_t counter_size
127 = use_dynamic_job ? SIMD * sizeof(int) * wg_size_x : 0;
    // Bytes reserved for the row buffer when wg_size_y > 1, otherwise 0.
    // call() places the counters directly after this region
    // (slm_base + row_buffer_size + ...).
128 static constexpr uint32_t row_buffer_size = (wg_size_y > 1)
129 ? tile_size_x * wg_size_x * wg_size_y * sizeof(dtype_acc)
130 : 0;
131
    // Total size in bytes: row buffer plus job counters.
134 struct get_slm_size {
135 static constexpr uint32_t size = row_buffer_size + counter_size;
136 };
137
    // Main execution function.
    //   item          - SYCL nd_item; supplies the local linear id and group(2).
    //   args          - matrix pointers and sizes (see arguments_t).
    //   fused_op_args - forwarded to the fused_op_t constructor; defaults to
    //                   nullptr.
    //   slm_base      - base offset used for the row-reduce store and, in the
    //                   dynamic-job path, for locating the job counters.
    //   nbarrier_base - first named-barrier id used by this function.
147 __XETLA_API static void call(sycl::nd_item<3> &item, arguments_t *args,
148 fused_op_arguments_t *fused_op_args = nullptr,
149 uint32_t slm_base = 0, uint32_t nbarrier_base = 0) {
150 work_group_t g;
151 g.init(item.get_local_linear_id());
        // Split the linear subgroup id into x (column) and y (row) indices.
152 int sg_idx = g.get_id() % wg_size_x;
153 int sg_idy = g.get_id() / wg_size_x;
154
        // Starting coordinates of this subgroup's first input tile. Only
        // group(2) contributes to x; y starts from the subgroup's own row.
155 int global_start_x_in
156 = item.get_group(2) * wg_tile_n + sg_idx * sg_tile_n;
157 int global_start_y_in = sg_idy * sg_tile_m;
        // `nbarrier` is declared on an original line that is not shown.
        // One barrier id is used per sg_idx.
159 nbarrier.init_nbarrier(
160 nbarrier_base + sg_idx, nbarrier_role::producer_consumer);
161 if constexpr (use_dynamic_job) {
        // Tail of the `offsets` declaration (its head is not shown): the
        // counter for column sg_idx sits after the row buffer region.
164 slm_base + row_buffer_size + sg_idx * SIMD * sizeof(int));
        // Enable lane 0 only.
165 xetla_mask<SIMD> pred(0);
166 pred[0] = 1;
        // The sg_idy == 0 subgroup seeds the counter with wg_size_y and then
        // issues a shared-local fence.
167 if (sg_idy == 0) {
168 xetla_vector<int, SIMD> init(wg_size_y);
169 xetla_store_local<int, 1, data_size::default_size, SIMD>(
170 offsets, init, pred);
171 xetla_fence<memory_kind::shared_local>();
172 }
        // Every subgroup signals the barrier; the matching wait() is below,
        // just before the counter is first read.
173 nbarrier.arrive();
174 }
175
176 global_ld_t mat_global_ld;
177 fused_op_t fused_op(
178 fused_op_args, global_start_x_in, global_start_y_in);
179 global_ld_payload_t mat_global_ld_payload(args->mat_in_ptr,
180 args->matrix_n, args->matrix_m, args->mat_in_ld,
181 global_start_x_in, global_start_y_in);
        // Per-subgroup buffer, constructed with 0; its .reg member is handed
        // to the final row_reduce_store call.
182 mat_buffer_t mat_buffer(0);
183 if constexpr (use_dynamic_job) {
        // Wait for the barrier signalled after the counter was seeded.
184 nbarrier.wait();
185 int job_id = sg_idy;
        // Tail of a second `offsets` declaration (head not shown); same
        // counter address as in the seeding block above.
187 slm_base + row_buffer_size + sg_idx * SIMD * sizeof(int));
188 xetla_mask<SIMD> pred(0);
189 pred[0] = 1;
        // Process row tiles until the tile's first row reaches matrix_m.
190 while (job_id * tile_size_y < args->matrix_m) {
        // Atomic increment of the counter; the result initialises `next_job`,
        // whose declaration head is not shown. NOTE(review): presumably the
        // pre-increment value is returned, giving the next unclaimed job
        // index -- confirm against xetla_atomic_local.
192 = xetla_atomic_local<atomic_op::iinc, int, SIMD>(
193 offsets, pred);
194 subgroup::tile_load(mat_global_ld, mat_global_ld_payload);
195 matAcc_t matAcc;
        // Convert the loaded tile into matAcc, then apply the fused op.
196 subgroup::elemwise_cvt<matAcc_t, global_ld_t>(
197 matAcc, mat_global_ld);
198 fused_op(matAcc);
        // Tail of a statement whose head is not shown. NOTE(review):
        // presumably accumulates a tile_reduce of matAcc into mat_buffer --
        // confirm against the original line 199.
200 dtype_acc, dtype_acc, 0>(matAcc);
        // Advance the load payload and the fused op in y by the distance
        // between the current and the next job, then switch to the next job.
201 mat_global_ld_payload
202 .template update_tdesc<tdesc_update_dir::y_dir>(
203 (next_job[0] - job_id) * tile_size_y);
204 fused_op.update_tdesc(0, (next_job[0] - job_id) * tile_size_y);
205 job_id = next_job[0];
206 }
207 } else {
        // Static assignment: this subgroup handles jobs sg_idy,
        // sg_idy + wg_size_y, ... while the tile's first row is below matrix_m.
208 for (int job_id = sg_idy; job_id * tile_size_y < args->matrix_m;
209 job_id += wg_size_y) {
210 subgroup::tile_load(mat_global_ld, mat_global_ld_payload);
211 matAcc_t matAcc;
212 subgroup::elemwise_cvt<matAcc_t, global_ld_t>(
213 matAcc, mat_global_ld);
214 fused_op(matAcc);
        // Tail of a statement whose head is not shown (same shape as the
        // corresponding line in the dynamic-job loop).
216 dtype_acc, dtype_acc, 0>(matAcc);
        // Step the fused op and the load payload by wg_size_y tiles in y.
217 fused_op.update_tdesc(0, wg_size_y * tile_size_y);
218 mat_global_ld_payload
219 .template update_tdesc<tdesc_update_dir::y_dir>(
220 wg_size_y * tile_size_y);
221 }
222 }
223
        // Final stage: hand this subgroup's buffer to row_reduce_store, which
        // is initialised with the subgroup indices, slm_base and nbarrier_base.
224 row_reduce_store_t row_reduce_store;
225 uint32_t slm_row_reduce_base = slm_base;
226 uint32_t nbarrier_row_reduce_base = nbarrier_base;
227 row_reduce_store.init(
228 sg_idx, sg_idy, slm_row_reduce_base, nbarrier_row_reduce_base);
        // NOTE(review): the arguments look like (ptr, width = matrix_n,
        // height = 1, pitch = matrix_n, x, y = 0, data) -- confirm against
        // row_reduce_store_t's operator().
229 row_reduce_store(args->mat_out_ptr, args->matrix_n, 1, args->matrix_n,
230 global_start_x_in, 0, mat_buffer.reg);
231 }
232};
233
234} // namespace gpu::xetla::kernel
#define __XETLA_API
Definition common.hpp:43
C++ API.
C++ API.
#define SIMD
Definition gemm_softmax.cpp:23
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
wrapper for xetla_mask.
Definition base_types.hpp:165
Definition limitation.hpp:734
__XETLA_API std::enable_if_t< detail::check_load_type< tile_t, payload_t >::is_global_2d_xe > tile_load(tile_t &tile, payload_t &payload)
This function loads data from 2D memory surface.
Definition load_xe.hpp:76
__XETLA_API std::enable_if_t<(dim==1), xetla_vector< dtype_out, mat_t::tile_size_y > > tile_reduce(mat_t &src)
Definition reduction.hpp:33
gpu_arch
Definition common.hpp:73
C++ API.
Definition arch_config.hpp:72
This is the group row reduction (reduce_sum) followed by a cooperative write-out.
Definition reduction_api.hpp:39
typename arch_attr_t< gpu_arch::Xe >::template load_store_attr< msg_type::block_2d > load_store_attr
Definition row_reduction_xe.hpp:63
static __XETLA_API void call(sycl::nd_item< 3 > &item, arguments_t *args, fused_op_arguments_t *fused_op_args=nullptr, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Main execution function for row reduction.
Definition row_reduction_xe.hpp:147
Is the row_reduction functor.
Definition api.hpp:39
Definition memory_descriptor.hpp:139
Definition common.hpp:80
Describes the memory information.
Definition api.hpp:44
Describes the tile information of a sub-matrix.
Definition api.hpp:64
Is a struct that contains a register file.
Definition api.hpp:99
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
xetla nbarrier definition API.
Definition raw_send_nbarrier.hpp:43
__XETLA_API void arrive()
named barrier signal from subgroup.
Definition raw_send_nbarrier.hpp:65
__XETLA_API void init_nbarrier(uint8_t nbarrier_id, nbarrier_role role=nbarrier_role::producer_consumer)
Definition raw_send_nbarrier.hpp:55
__XETLA_API void wait()
named barrier wait within subgroup.
Definition raw_send_nbarrier.hpp:76