XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
global_reduction.hpp
/*******************************************************************************
* Copyright (c) 2022-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include "group/tile_shape.hpp"
#include "subgroup/subgroup.hpp"

namespace gpu::xetla::group {

/// @brief Cross group global reduction.
template <reduce_op reduce_kind, typename tile_shape_acc,
        typename tile_shape_cnt, typename mem_desc_acc_t,
        typename mem_desc_cnt_t, uint32_t num_group_reduction,
        uint32_t counter_size, gpu_arch arch_tag, class enable = void>
class global_reduce_t {};

/// @brief Cross group global reduction, specialized for reduce_op::sum on gpu_arch::Xe.
template <typename tile_shape_acc_, typename tile_shape_cnt_,
        typename mem_desc_acc_t_, typename mem_desc_cnt_t_,
        uint32_t num_group_reduction, uint32_t counter_size, gpu_arch arch_tag_>
class global_reduce_t<reduce_op::sum, tile_shape_acc_, tile_shape_cnt_,
        mem_desc_acc_t_, mem_desc_cnt_t_, num_group_reduction, counter_size,
        arch_tag_, std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
public:
    static constexpr gpu_arch arch_tag = arch_tag_;
    using tile_shape_acc = tile_shape_acc_;
    using tile_shape_cnt = tile_shape_cnt_;
    using mem_desc_acc_t = mem_desc_acc_t_;
    using mem_desc_cnt_t = mem_desc_cnt_t_;
    using dtype_acc = typename mem_desc_acc_t::dtype;
    using dtype_cnt = typename mem_desc_cnt_t::dtype;

private:
    static constexpr uint32_t acc_sg_tile_y = tile_shape_acc::sg_tile_size_y;
    static constexpr uint32_t acc_sg_tile_x = tile_shape_acc::sg_tile_size_x;
    static constexpr uint32_t cnt_sg_tile_y = tile_shape_cnt::sg_tile_size_y;
    static constexpr uint32_t cnt_sg_tile_x = tile_shape_cnt::sg_tile_size_x;
    static constexpr uint32_t wg_size_x = tile_shape_acc::wg_size_x;
    static constexpr uint32_t wg_size_y = tile_shape_acc::wg_size_y;
    static_assert((tile_shape_acc::wg_size_x == tile_shape_cnt::wg_size_x)
                    && (tile_shape_acc::wg_size_y == tile_shape_cnt::wg_size_y),
            "acc and cnt workgroup shapes must match");
    using work_group_t = typename tile_shape_acc::work_group_t;

    /// Updates each subgroup's coordinate offsets into the accumulation and counter surfaces.
    inline void update_sg_tile_tdesc(work_group_t &g,
            mem_desc_acc_t &mem_desc_acc, mem_desc_cnt_t &mem_desc_cnt) {
        int32_t sg_idx = g.get_id() % wg_size_x;
        int32_t sg_idy = g.get_id() / wg_size_x;
        int32_t acc_tile_offset_x = sg_idx * acc_sg_tile_x;
        int32_t acc_tile_offset_y = sg_idy * acc_sg_tile_y;
        mem_desc_acc.update_coord(acc_tile_offset_x, acc_tile_offset_y);
        int32_t cnt_tile_offset_x = sg_idx * cnt_sg_tile_x;
        int32_t cnt_tile_offset_y = sg_idy * cnt_sg_tile_y;
        mem_desc_cnt.update_coord(cnt_tile_offset_x, cnt_tile_offset_y);
    }

    /// Atomically increments this tile's reduce counter and returns its previous value.
    inline uint32_t update_reduce_counter(mem_desc_cnt_t &mem_desc_cnt) {
        constexpr uint32_t SIMD = 16;
        uint32_t pitch_in_bytes
                = mem_desc_cnt.shape.stride * sizeof(dtype_cnt) * counter_size;
        uint32_t offset_x = mem_desc_cnt.coord.x;
        uint32_t offset_y = mem_desc_cnt.coord.y;
        uint64_t address = (uint64_t)mem_desc_cnt.base.base
                + offset_y * pitch_in_bytes
                + offset_x * sizeof(dtype_cnt) * counter_size;
        xetla_vector<uint32_t, SIMD> offsets
                = xetla_vector_gen<uint32_t, SIMD>(0, 1);
        offsets *= sizeof(dtype_cnt);
        xetla_mask<SIMD> pred(0);
        pred[0] = 1;
        xetla_vector<dtype_cnt, SIMD> ret
                = xetla_atomic_global<atomic_op::iinc, dtype_cnt, SIMD,
                        data_size::default_size, cache_hint::uncached,
                        cache_hint::write_back>((dtype_cnt *)address, offsets, pred);
        return ret[0];
    }

    /// Resets this tile's reduce counter back to zero.
    inline void clean_reduce_counter(mem_desc_cnt_t &mem_desc_cnt) {
        uint32_t pitch_in_bytes
                = mem_desc_cnt.shape.stride * sizeof(dtype_cnt) * counter_size;
        uint32_t offset_x = mem_desc_cnt.coord.x;
        uint32_t offset_y = mem_desc_cnt.coord.y;
        uint64_t address = (uint64_t)mem_desc_cnt.base.base
                + offset_y * pitch_in_bytes
                + offset_x * sizeof(dtype_cnt) * counter_size;
        xetla_vector<dtype_cnt, 1> zeros(0);

        xetla_store_global<dtype_cnt, 1, data_size::default_size,
                cache_hint::uncached, cache_hint::write_back>(
                (dtype_cnt *)address, 0, zeros);
    }

public:
    static constexpr uint32_t barrier_count = 0;
    static constexpr uint32_t slm_size = 0;
    uint32_t reduce_id = 0;

    inline bool is_last_group() {
        return reduce_id == (num_group_reduction - 1);
    }

    /// @brief Global reduction.
    /// Each group atomically accumulates its partial tile into the global
    /// accumulation buffer and then increments the tile's reduce counter; the
    /// last group to arrive reloads the fully reduced tile into matAcc, resets
    /// the counter, and zeroes the accumulation buffer for reuse.
    template <typename matAcc_t>
    __XETLA_API KERNEL_FUNC void operator()(work_group_t &g, matAcc_t &matAcc,
            mem_desc_acc_t mem_desc_acc, mem_desc_cnt_t mem_desc_cnt,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {
        static_assert(std::is_same<typename matAcc_t::dtype, dtype_acc>::value,
                "matAcc_t::dtype should match dtype_acc");
        update_sg_tile_tdesc(g, mem_desc_acc, mem_desc_cnt);
        using matAcc_tile_desc_t = typename matAcc_t::tile_desc;
        using matAcc_store_payload_t = subgroup::mem_payload_t<mem_desc_acc_t,
                matAcc_tile_desc_t, msg_type::atomic_add, arch_tag>;
        matAcc_store_payload_t matAcc_store_payload(mem_desc_acc);
        subgroup::tile_store<cache_hint::uncached, cache_hint::write_back>(
                matAcc, matAcc_store_payload);
        xetla_fence<memory_kind::untyped_global>();
        SW_BARRIER();
        reduce_id = update_reduce_counter(mem_desc_cnt);
        if (reduce_id == (num_group_reduction - 1)) {
            using matAcc_payload_t = subgroup::mem_payload_t<mem_desc_acc_t,
                    matAcc_tile_desc_t, msg_type::block_2d, arch_tag>;
            matAcc_payload_t matAcc_payload(mem_desc_acc);
            subgroup::tile_load(matAcc, matAcc_payload);
            clean_reduce_counter(mem_desc_cnt);
            using mat_zero_t = subgroup::tile_t<dtype_acc, matAcc_tile_desc_t>;
            mat_zero_t mat_zero;
            mat_zero.reg = 0;
            subgroup::tile_store<cache_hint::uncached, cache_hint::write_back>(
                    mat_zero, matAcc_payload);
            SW_BARRIER();
        }
    }
};

/// @brief Cross group global reduction, specialized for num_group_reduction == 1,
/// where no cross-group accumulation is needed and the functor is a no-op.
template <typename tile_shape_acc_, typename tile_shape_cnt_,
        typename mem_desc_acc_t_, typename mem_desc_cnt_t_,
        uint32_t counter_size_, gpu_arch arch_tag_>
class global_reduce_t<reduce_op::sum, tile_shape_acc_, tile_shape_cnt_,
        mem_desc_acc_t_, mem_desc_cnt_t_, 1, counter_size_, arch_tag_,
        std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
public:
    static constexpr gpu_arch arch_tag = arch_tag_;
    using tile_shape_acc = tile_shape_acc_;
    using tile_shape_cnt = tile_shape_cnt_;
    using mem_desc_acc_t = mem_desc_acc_t_;
    using mem_desc_cnt_t = mem_desc_cnt_t_;
    using dtype_acc = typename mem_desc_acc_t::dtype;

private:
    using work_group_t = typename tile_shape_acc::work_group_t;

public:
    static constexpr uint32_t barrier_count = 0;
    static constexpr uint32_t slm_size = 0;
    inline bool is_last_group() { return true; }

    template <typename matAcc_t>
    inline KERNEL_FUNC void operator()([[maybe_unused]] work_group_t &g,
            [[maybe_unused]] matAcc_t &matAcc,
            [[maybe_unused]] mem_desc_acc_t mem_desc_acc,
            [[maybe_unused]] mem_desc_cnt_t mem_desc_cnt,
            [[maybe_unused]] uint32_t slm_base = 0,
            [[maybe_unused]] uint32_t nbarrier_base = 0) {}
};

} // namespace gpu::xetla::group
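
The reduce_op::sum specialization implements a "last group finalizes" protocol: every workgroup atomically adds its partial tile into the accumulation buffer, atomically increments the tile's counter, and the workgroup that observes the counter value num_group_reduction - 1 reloads the complete sum, resets the counter, and zeroes the buffer. The standalone C++ sketch below mirrors that control flow on the host with std::atomic and std::thread purely to illustrate the ordering; num_slices, partial, and worker are hypothetical names, and the real implementation uses the GPU atomics shown above instead.

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    constexpr int num_slices = 4;       // plays the role of num_group_reduction
    std::atomic<float> acc{0.0f};       // global accumulation buffer (one element here)
    std::atomic<unsigned> counter{0};   // reduce counter, initially zero

    auto worker = [&](int slice) {
        float partial = static_cast<float>(slice + 1); // this "group's" partial result
        // Atomic accumulate, like the msg_type::atomic_add tile_store.
        float cur = acc.load();
        while (!acc.compare_exchange_weak(cur, cur + partial)) {}
        // Atomic increment that returns the old value, like atomic_op::iinc.
        unsigned reduce_id = counter.fetch_add(1);
        if (reduce_id == num_slices - 1) {
            // Last group: read back the full sum, then reset counter and buffer.
            float total = acc.load();
            counter.store(0);
            acc.store(0.0f);
            std::printf("total = %.1f\n", total); // 10.0 for partials 1..4
        }
    };

    std::vector<std::thread> pool;
    for (int s = 0; s < num_slices; ++s) pool.emplace_back(worker, s);
    for (auto &t : pool) t.join();
    return 0;
}

The design point the header relies on is that atomic_op::iinc returns the pre-increment value, so exactly one group observes num_group_reduction - 1 and takes the finalize path.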
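
For orientation, here is a minimal, hypothetical type-level sketch of how a split-k style kernel might name this functor; it is not taken from the header. The concrete tile sizes, the parameter order of group::tile_shape_t, and the mem_desc_t arguments are assumptions chosen only so the acc and cnt workgroup shapes satisfy the static_assert above, while the global_reduce_t arguments follow the primary template declared in this file.

// Hypothetical instantiation sketch; type names other than global_reduce_t's
// own parameters are assumptions about the surrounding XeTLA headers.
#include "group/global_reduction.hpp" // include path assumed

using namespace gpu::xetla;

// Accumulation tile: 256x128 per workgroup, 64x32 per subgroup -> 4x4 subgroups.
using tile_shape_acc = group::tile_shape_t<256, 128, 64, 32>;
// Counter tile chosen to yield the same 4x4 workgroup shape (see the static_assert).
using tile_shape_cnt = group::tile_shape_t<4, 4, 1, 1>;

using mem_desc_acc_t = mem_desc_t<float, mem_layout::row_major, mem_space::global>;
using mem_desc_cnt_t = mem_desc_t<uint32_t, mem_layout::row_major, mem_space::global>;

// Four k-slices accumulate into each output tile; one counter element per tile.
using global_reduce = group::global_reduce_t<reduce_op::sum, tile_shape_acc,
        tile_shape_cnt, mem_desc_acc_t, mem_desc_cnt_t,
        /*num_group_reduction*/ 4, /*counter_size*/ 1, gpu_arch::Xe>;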