XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
stream_k_op_xe.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
25
26namespace gpu::xetla::group {
27
30
32template <typename tile_shape_, typename epilogue_t_, typename mem_desc_d_t_,
33 typename mem_desc_atomic_sync_t_>
35
36 static constexpr gpu_arch arch_tag = gpu_arch::Xe;
37 using epilogue_t = epilogue_t_;
38 using mem_desc_d_t = mem_desc_d_t_;
39 using mem_desc_c_t = typename epilogue_t::mem_desc_c_t;
40 using mem_desc_atomic_sync_t = mem_desc_atomic_sync_t_;
41 using tile_shape = tile_shape_;
42 using epilogue_args_t = typename epilogue_t::arguments_t;
43
44 using work_group_t = typename tile_shape::work_group_t;
45 static constexpr uint32_t wg_tile_m = tile_shape::wg_tile_size_y;
46 static constexpr uint32_t wg_tile_n = tile_shape::wg_tile_size_x;
47 static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
48 static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x;
49 static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
50 static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
51
52 //Barrier required to synchronize all threads in workgroup for atomic sync across xecores
53 static constexpr uint32_t barrier_count = 1;
54 static constexpr uint32_t slm_size
55 = mem_desc_c_t::is_local ? wg_tile_m * wg_tile_n : 0;
56 static constexpr uint32_t N_SG = wg_size_x * wg_size_y;
57
59
60 using dtype_d = typename mem_desc_d_t::dtype;
61 using dtype_flag = typename mem_desc_atomic_sync_t::dtype;
62
63 //Use special residual op for finishing SK groups to read from scratchspace buffer and reduce in GRF; They also store zeros in scratchspace buffer
66 using residual_op_args_t = typename residual_op_t::arguments_t;
67
68 static constexpr mem_layout mem_layout_d = mem_desc_d_t::layout;
69 static constexpr mem_space mem_space_d = mem_desc_d_t::space;
72
75 work_group_t &g, mem_desc_d_t &mem_desc_d) {
76 int32_t sg_idx = g.get_id() % wg_size_x;
77 int32_t sg_idy = g.get_id() / wg_size_x;
78 int32_t tile_offset_n = sg_idx * sg_tile_n;
79 int32_t tile_offset_m = sg_idy * sg_tile_m;
80 mem_desc_d.update_coord(tile_offset_n, tile_offset_m);
81 }
82
93 template <typename matAcc_t>
95 mem_desc_c_t mem_desc_c, mem_desc_d_t mem_desc_d,
96 mem_desc_atomic_sync_t mem_desc_atomic_sync, int group_idx,
97 int first_group_idx, bool tile_finished, bool tile_started,
98 epilogue_args_t epilogue_args, uint32_t slm_base = 0,
99 uint32_t nbarrier_base = 0) {
100
101 static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
102 static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
103 static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
104 static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
105
106 using matD_tile_desc_t = subgroup::tile_desc_t<tile_size_x, tile_size_y,
107 block_size_x, block_size_y, reg_layout::tiled>;
108
109 using matD_atomic_payload_t = subgroup::mem_payload_t<mem_desc_d_t,
110 matD_tile_desc_t, msg_type_d_atomic, arch_tag>;
111
112 uint32_t nbarrier_id = nbarrier_base;
114
115 update_sg_tile_tdesc(g, mem_desc_d);
116
117 //Addressing for atomic signal
118 xetla_mask<16> pred(0);
119 pred[0] = 1;
120 xetla_vector<uint32_t, 16> flag_offsets
121 = xetla_vector_gen<uint32_t, 16>(0, 1);
122 flag_offsets
123 += first_group_idx; // first_group_idx indicates the first peer of the sliced tile
124 flag_offsets = flag_offsets * sizeof(dtype_flag);
125 int32_t sg_id = g.get_id();
126 dtype_flag *flag_pointer = mem_desc_atomic_sync.base.base;
127
128 //SK group , Sliced Tile - SK group handles starting slice or middle slice
129 if (!tile_finished) {
130
131 //Perform atomic writes and signal to atomic counter
132 matD_atomic_payload_t matD_atomic_payload(mem_desc_d);
133 //Atomic store with OOB check
134 subgroup::tile_store(matAcc, matD_atomic_payload);
135
136 //Fence to guarantee write completion
139
140 //Group sync to make sure fence is sent
142 nbarrier.wait();
143
144 //Signal to other peers
145 if (sg_id == 0) {
146 xetla_vector<dtype_flag, 16> signal_val(1);
149 (uint64_t)flag_pointer, flag_offsets, signal_val, pred);
150 }
151
152 } else {
153
154 //last SK group of corresponding sliced tile
155 if (!tile_started) {
156
157 //Number of previous peers that have contributed to this sliced tile
158 uint32_t num_peers = group_idx - first_group_idx;
159
160 //Group sync
162 nbarrier.wait();
163
164 if (sg_id == 0) {
165
167 xetla_vector<dtype_flag, 16> old_val = num_peers;
169
170 //Use atomic cmpxchg to test if previous peers have finished writing
171 //Exchange with value zero to clear the flag
172 while (ret_val[0] != num_peers) {
173
177 flag_pointer, flag_offsets, old_val, zero_val,
178 pred);
179 }
180 }
181 //Group sync
183 nbarrier.wait();
184
185 //Invoke stream_k residual op
186 residual_op_t residual_op;
187 residual_op_args_t residual_args(
188 mem_desc_d.base, mem_desc_d.shape);
189
190 residual_op(matAcc, mem_desc_d.coord, residual_args);
191 }
192
193 //Finishing SK groups and DP Groups perform normal epilogue operations - post_op fusion + output conversion and write to output buffer
194 epilogue_t epilogue;
195 epilogue(g, matAcc, mem_desc_c, epilogue_args, slm_base,
196 nbarrier_base);
197 }
198 }
199};
200
202
203} // namespace gpu::xetla::group
#define __XETLA_API
Definition common.hpp:43
C++ API.
C++ API.
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
wrapper for xetla_mask.
Definition base_types.hpp:165
__XETLA_API void xetla_fence(xetla_mask< N > pred=1)
Memory fence.
Definition memory.hpp:638
__XETLA_API xetla_vector< T, N > xetla_atomic_global(T *p, xetla_vector< uint32_t, N > offsets, xetla_mask< N > pred)
Stateless scattered atomic (0 src).
Definition memory.hpp:371
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
__XETLA_API std::enable_if_t< arch_tag==gpu_arch::Xe, void > xetla_tatomic_store_global(uint64_t base_address, xetla_vector< Toffset, N > offset, xetla_vector< Ty, N > data, xetla_mask< N > pred=1)
Tensor atomic store API.
Definition raw_send_load_store.hpp:294
Definition limitation.hpp:607
__XETLA_API std::enable_if_t< detail::check_store_type< tile_t, payload_t >::is_global_2d_xe > tile_store(tile_t &tile, payload_t &payload)
Is the func storing data from register file to global memory.
Definition store_xe.hpp:91
@ evict
no operation
@ tile
flush out to the local scope
mem_space
Definition common.hpp:77
@ iadd
Atomic signed int add of src1 to the memory data; returns the old value.
@ cmpxchg
Atomic bit-compare src1_X and memory data and replace if equal with src1_Y. Returns the old value....
gpu_arch
Definition common.hpp:73
msg_type
Definition common.hpp:78
mem_layout
Definition common.hpp:76
Is the epilogue functor specialized for stream_k.
Definition stream_k_op_xe.hpp:34
xetla_nbarrier_t< N_SG, N_SG, arch_tag > nbarrier
Definition stream_k_op_xe.hpp:58
static constexpr uint32_t slm_size
Definition stream_k_op_xe.hpp:55
static __XETLA_API void update_sg_tile_tdesc(work_group_t &g, mem_desc_d_t &mem_desc_d)
Updates tile base descriptor based on the tid.
Definition stream_k_op_xe.hpp:74
static constexpr uint32_t wg_tile_m
Definition stream_k_op_xe.hpp:45
static constexpr uint32_t wg_size_x
Definition stream_k_op_xe.hpp:49
typename mem_desc_atomic_sync_t::dtype dtype_flag
Definition stream_k_op_xe.hpp:61
static constexpr uint32_t sg_tile_m
Definition stream_k_op_xe.hpp:47
static constexpr mem_layout mem_layout_d
Definition stream_k_op_xe.hpp:68
typename epilogue_t::arguments_t epilogue_args_t
Definition stream_k_op_xe.hpp:42
static constexpr msg_type msg_type_d_block2d
Definition stream_k_op_xe.hpp:70
mem_desc_atomic_sync_t_ mem_desc_atomic_sync_t
Definition stream_k_op_xe.hpp:40
__XETLA_API KERNEL_FUNC void operator()(work_group_t &g, matAcc_t &matAcc, mem_desc_c_t mem_desc_c, mem_desc_d_t mem_desc_d, mem_desc_atomic_sync_t mem_desc_atomic_sync, int group_idx, int first_group_idx, bool tile_finished, bool tile_started, epilogue_args_t epilogue_args, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Epilogue for stream_k.
Definition stream_k_op_xe.hpp:94
static constexpr uint32_t N_SG
Definition stream_k_op_xe.hpp:56
static constexpr mem_space mem_space_d
Definition stream_k_op_xe.hpp:69
mem_desc_d_t_ mem_desc_d_t
Definition stream_k_op_xe.hpp:38
static constexpr uint32_t sg_tile_n
Definition stream_k_op_xe.hpp:48
tile_shape_ tile_shape
Definition stream_k_op_xe.hpp:41
static constexpr uint32_t wg_tile_n
Definition stream_k_op_xe.hpp:46
typename tile_shape::work_group_t work_group_t
Definition stream_k_op_xe.hpp:44
static constexpr gpu_arch arch_tag
Definition stream_k_op_xe.hpp:36
typename epilogue_t::mem_desc_c_t mem_desc_c_t
Definition stream_k_op_xe.hpp:39
static constexpr msg_type msg_type_d_atomic
Definition stream_k_op_xe.hpp:71
typename mem_desc_d_t::dtype dtype_d
Definition stream_k_op_xe.hpp:60
static constexpr uint32_t barrier_count
Definition stream_k_op_xe.hpp:53
static constexpr uint32_t wg_size_y
Definition stream_k_op_xe.hpp:50
typename residual_op_t::arguments_t residual_op_args_t
Definition stream_k_op_xe.hpp:66
epilogue_t_ epilogue_t
Definition stream_k_op_xe.hpp:37
Is the element-wise reduce op functor, specialized for stream_k dispatch Load partial sum from scratc...
Definition tile_op_functor.hpp:826
Is to illustrate the memory information.
Definition api.hpp:44
Is to illustrate the tile information about a sub matrix.
Definition api.hpp:64
xetla nbarrier definition API.
Definition raw_send_nbarrier.hpp:43
__XETLA_API void arrive()
named barrier signal from subgroup.
Definition raw_send_nbarrier.hpp:65
__XETLA_API void init_nbarrier(uint8_t nbarrier_id, nbarrier_role role=nbarrier_role::producer_consumer)
Definition raw_send_nbarrier.hpp:55
__XETLA_API void wait()
named barrier wait within subgroup.
Definition raw_send_nbarrier.hpp:76