XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
cooperative_reduction.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
22#include "group/tile_shape.hpp"
23#include "subgroup/subgroup.hpp"
24
25namespace gpu::xetla::group {
26
// Cooperative workgroup reduction: num_cooperative_wg workgroups each stage
// their accumulator tile into shared-local memory (SLM), then the surviving
// workgroups each load and reduce one (tile_size_x x tile_size_y) sub-slice.
//
// NOTE(review): this is a Doxygen text rendering of the header; the
// extraction dropped several original source lines (35, 83, 90-93, 98-101,
// 106, 146, 149).  In particular the declarations of local_st_tile_t,
// local_st_payload_t, local_ld_tile_t, local_ld_payload_t, mat_slice_t and
// the nbarrier object used below are missing from this view — consult the
// real cooperative_reduction.hpp before editing.
//
// Primary template (forward declaration); its trailing
// `class cooperative_reduce_t;` is on dropped line 35.
33template <reduce_op reduce_kind, typename tile_shape, typename matAcc_t,
34 uint32_t num_cooperative_wg, gpu_arch arch_tag, class enable = void>
36
// Specialization enabled only for the Xe architecture (SFINAE on arch_tag_).
38template <reduce_op reduce_kind, typename tile_shape_, typename matAcc_t,
39 uint32_t num_cooperative_wg, gpu_arch arch_tag_>
40class cooperative_reduce_t<reduce_kind, tile_shape_, matAcc_t,
41 num_cooperative_wg, arch_tag_,
42 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
43public:
44 static constexpr gpu_arch arch_tag = arch_tag_;
45 using tile_shape = tile_shape_;
46 using dtype = typename matAcc_t::dtype;
47
48private:
 // Per-subgroup and whole-workgroup tile geometry, taken from tile_shape.
49 static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
50 static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x;
51 static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
52 static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
53 static constexpr uint32_t real_wg_tile_m = sg_tile_m * wg_size_y;
54 static constexpr uint32_t real_wg_tile_n = sg_tile_n * wg_size_x;
55
 // Bytes of SLM needed to stage ONE workgroup's accumulator tile.
56 static constexpr uint32_t wg_tile_size
57 = real_wg_tile_m * real_wg_tile_n * sizeof(dtype);
58 using work_group_t = typename tile_shape::work_group_t;
59 static constexpr uint32_t work_group_size = work_group_t::size;
60 // cooperative split, y dir first
 // Power-of-2 check via the x & (x-1) trick.
61 static_assert((num_cooperative_wg & (num_cooperative_wg - 1)) == 0,
62 "num_cooperative_wg should be power of 2");
63
64public:
 // Split the cooperating workgroups along y first: as many as divide
 // sg_tile_m evenly (gcd), the remainder along x.
65 static constexpr uint32_t coop_num_y
66 = gpu::xetla::subgroup::detail::gcd<num_cooperative_wg,
67 sg_tile_m>::value;
68 static constexpr uint32_t coop_remain_num_x
69 = num_cooperative_wg / coop_num_y;
 // If the x split would make slices narrower than 16 elements, clamp the
 // slice width to 16 and let the excess workgroups idle.
 // NOTE(review): 16 is presumably a minimum/efficient tile width for this
 // arch — confirm against the real header.
70 static constexpr bool has_redundant_wg
71 = (coop_remain_num_x * 16) > sg_tile_n;
72 static constexpr uint32_t tile_size_y = sg_tile_m / coop_num_y;
73 static constexpr uint32_t tile_size_x
74 = has_redundant_wg ? 16 : sg_tile_n / coop_remain_num_x;
75 static constexpr uint32_t coop_num_x = sg_tile_n / tile_size_x;
 // Number of workgroups that actually perform the reduction (may be fewer
 // than num_cooperative_wg when has_redundant_wg is true).
76 static constexpr uint32_t num_reduce_wg = coop_num_x * coop_num_y;
77
78private:
79 static constexpr uint32_t src_block_size_x = matAcc_t::block_size_x;
80 static constexpr uint32_t src_block_size_y = matAcc_t::block_size_y;
81
 // NOTE(review): the initializer of block_size_x straddles dropped line 83
 // (it appears to be a gcd of tile_size_x and src_block_size_x — verify).
82 static constexpr uint32_t block_size_x
84 src_block_size_x>::value;
85 static constexpr uint32_t block_size_y
86 = (tile_size_y > src_block_size_y) ? src_block_size_y : tile_size_y;
87
 // Tile descriptor for storing the full sg tile to SLM; the matching
 // tile/payload aliases are on dropped lines 90-93.
88 using local_st_tile_desc_t = subgroup::tile_desc_t<sg_tile_n, sg_tile_m,
89 src_block_size_x, src_block_size_y, reg_layout::tiled>;
94 subgroup::msg_type_v<local_st_tile_desc_t, mem_space::local>,
95 arch_tag>;
 // Tile descriptor for loading one reduction slice back from SLM; the
 // matching tile/payload aliases are on dropped lines 98-101.
96 using local_ld_tile_desc_t = subgroup::tile_desc_t<tile_size_x, tile_size_y,
97 block_size_x, block_size_y, reg_layout::tiled>;
102 subgroup::msg_type_v<local_ld_tile_desc_t, mem_space::local>,
103 arch_tag>;
104
105public:
 // `using mat_slice_t = ...;` is on dropped line 106.
107
 // Resources the caller must budget: one named barrier per subgroup, and
 // SLM for every cooperating workgroup's full tile.
108 static constexpr uint32_t barrier_count = work_group_size;
109 static constexpr uint32_t slm_size = wg_tile_size * num_cooperative_wg;
110
 // This workgroup's linear cooperative id and its (x, y) slot in the split.
111 uint32_t coop_id;
112 uint32_t coop_id_x;
113 uint32_t coop_id_y;
114 inline cooperative_reduce_t(uint32_t coop_id_) : coop_id(coop_id_) {
115 coop_id_x = coop_id % coop_remain_num_x;
116 coop_id_y = coop_id / coop_remain_num_x;
117 }
 // Redundant workgroups (coop_id_x >= coop_num_x) skip the reduce phase.
118 inline bool is_valid_post_process_wg() { return coop_id_x < coop_num_x; }
119
 // Stage matAcc into SLM, synchronize all cooperating subgroups, then (for
 // valid workgroups) load this workgroup's slice and fold in the other
 // num_cooperative_wg - 1 slices with reduce_helper<reduce_kind>.
130 inline KERNEL_FUNC void operator()(work_group_t &g, mat_slice_t &mat_slice,
131 matAcc_t &matAcc, uint32_t slm_base = 0,
132 uint32_t nbarrier_base = 0) {
 // Subgroup's (x, y) position inside the workgroup.
133 uint32_t sg_idx = g.get_id() % wg_size_x;
134 uint32_t sg_idy = g.get_id() / wg_size_x;
135
 // Each workgroup writes into its own real_wg_tile_m-row band of SLM,
 // so the staging buffer is num_cooperative_wg bands tall.
136 int32_t slm_store_offset_x = sg_idx * sg_tile_n;
137 int32_t slm_store_offset_y
138 = coop_id * real_wg_tile_m + sg_idy * sg_tile_m;
139 local_st_tile_t local_st;
140 local_st_payload_t local_st_payload(slm_base, real_wg_tile_n,
141 real_wg_tile_m * num_cooperative_wg, real_wg_tile_n,
142 slm_store_offset_x, slm_store_offset_y);
143 local_st.reg = matAcc.reg;
144 tile_store(local_st, local_st_payload);
145
 // NOTE(review): the nbarrier declaration (line 146) and its
 // init_nbarrier(nbar_id) call (line 149) were dropped by extraction;
 // nbar_id below feeds that missing init.
147 nbarrier;
148 uint32_t nbar_id = nbarrier_base + g.get_id();
 // Make the SLM stores visible before the barrier handshake.
150 xetla_fence<memory_kind::shared_local>();
151 nbarrier.arrive();
152 nbarrier.wait();
153
154 if (is_valid_post_process_wg()) {
155 // nbarrier.init_nbarrier(nbar_id, nbarrier_role::consumer);
156 // nbarrier.arrive();
 // Offset of this workgroup's slice inside its own band.
157 int32_t slm_load_offset_x
158 = sg_idx * sg_tile_n + coop_id_x * tile_size_x;
159 int32_t slm_load_offset_y
160 = sg_idy * sg_tile_m + coop_id_y * tile_size_y;
161
162 local_ld_tile_t local_ld;
163 local_ld_payload_t local_ld_payload(slm_base, real_wg_tile_n,
164 real_wg_tile_m * num_cooperative_wg, real_wg_tile_n,
165 slm_load_offset_x, slm_load_offset_y);
166
 // First band initializes the slice ...
167 tile_load(local_ld, local_ld_payload);
168 mat_slice.reg = local_ld.reg;
169#pragma unroll
 // ... then walk down the remaining bands, one per cooperating
 // workgroup, and fold each into the running result.
170 for (uint32_t i = 1; i < num_cooperative_wg; i++) {
171 local_ld_payload.template update_tdesc<tdesc_update_dir::y_dir>(
172 real_wg_tile_m);
173 tile_load(local_ld, local_ld_payload);
174 mat_slice.reg = reduce_helper<reduce_kind, dtype>(
175 mat_slice.reg, local_ld.reg);
176 }
177 }
178 }
179};
180
183template <reduce_op reduce_kind, typename tile_shape_, typename matAcc_t,
184 gpu_arch arch_tag_>
185class cooperative_reduce_t<reduce_kind, tile_shape_, matAcc_t, 1, arch_tag_,
186 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
187public:
188 static constexpr gpu_arch arch_tag = arch_tag_;
189 using tile_shape = tile_shape_;
190 using dtype = typename matAcc_t::dtype;
191
192private:
193 using work_group_t = typename tile_shape::work_group_t;
194
195public:
196 using mat_slice_t = matAcc_t;
197 static constexpr uint32_t barrier_count = 0;
198 static constexpr uint32_t slm_size = 0;
199 static constexpr uint32_t coop_num_x = 1;
200 static constexpr uint32_t coop_num_y = 1;
201 uint32_t coop_id;
202 uint32_t coop_id_x;
203 uint32_t coop_id_y;
204 inline cooperative_reduce_t([[maybe_unused]] uint32_t coop_id_) {
205 coop_id = 0;
206 coop_id_x = 0;
207 coop_id_y = 0;
208 }
209 inline bool is_valid_post_process_wg() { return true; }
210
211 inline KERNEL_FUNC void operator()([[maybe_unused]] work_group_t &g,
212 mat_slice_t &mat_slice, matAcc_t &matAcc,
213 [[maybe_unused]] uint32_t slm_base = 0,
214 [[maybe_unused]] uint32_t nbarrier_base = 0) {
215 mat_slice.reg = matAcc.reg;
216 }
217};
218
219} // namespace gpu::xetla::group
KERNEL_FUNC void operator()(work_group_t &g, mat_slice_t &mat_slice, matAcc_t &matAcc, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Cooperative workgroup reduction.
Definition cooperative_reduction.hpp:130
KERNEL_FUNC void operator()(work_group_t &g, mat_slice_t &mat_slice, matAcc_t &matAcc, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Definition cooperative_reduction.hpp:211
Workgroups to do the cooperative reduction.
Definition cooperative_reduction.hpp:35
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
reduce_op
xetla reduce op
Definition common.hpp:217
gpu_arch
Definition common.hpp:73
Definition memory_descriptor.hpp:139
Definition common.hpp:80
Is to illustrate the memory information.
Definition api.hpp:44
Is to illustrate the tile information about a sub matrix.
Definition api.hpp:64
Is a struct contains some register file.
Definition api.hpp:99
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
static constexpr uint32_t size
Definition work_group.hpp:39
xetla nbarrier definition API.
Definition raw_send_nbarrier.hpp:43
__XETLA_API void arrive()
named barrier signal from subgroup.
Definition raw_send_nbarrier.hpp:65
__XETLA_API void init_nbarrier(uint8_t nbarrier_id, nbarrier_role role=nbarrier_role::producer_consumer)
Definition raw_send_nbarrier.hpp:55
__XETLA_API void wait()
named barrier wait within subgroup.
Definition raw_send_nbarrier.hpp:76