XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
row_reduce_store_xe.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
23
24namespace gpu::xetla::group {
25
// Xe specialization of the group row-reduction + cooperative store helper.
// Per the API index below: "group row reduction (reduce_sum) + cooperative
// write out" — each subgroup contributes a row-sized accumulator vector;
// cooperating subgroups along wg_size_y reduce them and store the result.
//
// NOTE(review): this text was extracted from a Doxygen page. The leading
// digits fused onto each line are Doxygen source-line numbers, and several
// original lines are missing from this dump (e.g. the initializer of
// block_size_x after line 30, and the `using` alias declarations for the
// local store/load tiles, their payloads, and the nbarrier/payload members
// around lines 45-76). Consult the original header before editing.
26template <typename dtype_acc, typename dtype_out, uint32_t row_size,
27 uint32_t wg_size_x, uint32_t wg_size_y, uint32_t max_simd_len>
28struct group_row_reduce_store_t<dtype_acc, dtype_out, row_size, wg_size_x,
29 wg_size_y, max_simd_len, gpu_arch::Xe> {
// NOTE(review): the "= ..." initializer for block_size_x (original line 31)
// is elided in this extraction — presumably a function of row_size and
// max_simd_len; confirm against the original header.
30 static constexpr uint32_t block_size_x
// Small blocks underutilize the SIMD lanes, so reject them at compile time.
32 static_assert(block_size_x >= 8,
33 "if block_size_x is less than 8, the efficiency will be low. "
34 "Please choose another tile_size_x");
// Number of block_size_x-wide blocks that tile one row.
35 static constexpr uint32_t num_block_x = row_size / block_size_x;
// The blocks must divide evenly across the cooperating threads (or be
// fewer than them) so every cooperating thread gets the same share.
36 static_assert((num_block_x < wg_size_y) || (num_block_x % wg_size_y == 0),
37 "num_block_x should be less than wg_size_y or num_block_x should "
38 "be a multiple of wg_size_y");
// Blocks handled by each cooperating thread (at least 1).
39 static constexpr uint32_t num_block_per_thd
40 = (num_block_x < wg_size_y) ? 1 : num_block_x / wg_size_y;
// How many threads along wg_size_y actually take part in the final
// reduce + store; the rest only contribute their SLM tile and wait.
41 static constexpr uint32_t cooperative_thd_num
42 = (num_block_x < wg_size_y) ? num_block_x : wg_size_y;
// Width of the slice each cooperating thread reduces and writes out.
43 static constexpr uint32_t local_tile_size_x
44 = num_block_per_thd * block_size_x;
// NOTE(review): the `using local_st_tile_desc_t = ...` / payload / load-tile
// alias declarations (original lines 45-60) are partially elided below —
// only trailing fragments of the template argument lists remain.
46 block_size_x, 1, reg_layout::tiled>;
51 subgroup::msg_type_v<local_st_tile_desc_t, mem_space::local>,
54 wg_size_y, block_size_x, wg_size_y, reg_layout::tiled>;
59 subgroup::msg_type_v<local_ld_tile_desc_t, mem_space::local>,
61
62 //If the local tile size is small, we still can use 2D block store
63 using global_st_tile_desc_t = subgroup::tile_desc_t<local_tile_size_x, 1,
64 block_size_x, 1, reg_layout::tiled>;
// Message type for the global store: fall back to 1D block messages when
// the per-thread slice exceeds 64 bytes (fragment; rest of the
// conditional and the payload alias are elided in this dump).
69 (local_tile_size_x * sizeof(dtype_out) > 64) ? msg_type::block_1d
// Subgroup coordinates within the workgroup, set by init().
77 uint32_t sg_idx;
78 uint32_t sg_idy;
// Records the subgroup coordinates, initializes the named barrier (one
// barrier id per sg_idx, producer_consumer role), and points the SLM
// store/load payloads at this subgroup's region of the shared buffer.
79 inline void init(uint32_t sg_idx_ = 0, uint32_t sg_idy_ = 0,
80 uint32_t slm_base = 0, uint32_t nbarrier_base = 0) {
81 sg_idx = sg_idx_;
82 sg_idy = sg_idy_;
83 nbarrier.init_nbarrier(
84 sg_idx + nbarrier_base, nbarrier_role::producer_consumer);
// SLM layout: wg_size_y rows of (row_size * wg_size_x) accumulators;
// each subgroup stores its row at (row_size * sg_idx, sg_idy).
85 local_st_payload.init(slm_base, row_size * wg_size_x, wg_size_y,
86 row_size * wg_size_x, row_size * sg_idx, sg_idy);
// The load payload reads a column slice of all wg_size_y rows, offset
// by this thread's local_tile_size_x share.
87 local_ld_payload.init(slm_base, row_size * wg_size_x, wg_size_y,
88 row_size * wg_size_x,
89 row_size * sg_idx + local_tile_size_x * sg_idy, 0);
90 }
91
// Performs the cooperative reduce + store:
// 1. every subgroup writes its accumulator vector into SLM;
// 2. fence + named-barrier arrive/wait so all contributions are visible;
// 3. the first cooperative_thd_num threads each load their slice of all
//    wg_size_y rows, reduce_sum them (converting dtype_acc -> dtype_out),
//    and store the result uncached to global memory at
//    (start_n_base + slice offset, start_m_base);
// 4. a second barrier keeps the SLM buffer safe for reuse by the caller.
92 inline KERNEL_FUNC void operator()(dtype_out *ptr, uint32_t st_width,
93 uint32_t st_height, uint32_t st_pitch, int start_n_base,
94 int start_m_base, xetla_vector<dtype_acc, row_size> buffer) {
95 local_st.reg = buffer;
96 subgroup::tile_store(local_st, local_st_payload);
// Make the SLM writes visible before signalling the barrier.
97 xetla_fence<memory_kind::shared_local>();
98 nbarrier.arrive();
99 nbarrier.wait();
100 if (sg_idy < cooperative_thd_num) {
101 subgroup::tile_load(local_ld, local_ld_payload);
102 global_st_t global_st;
103 global_st_payload_t global_st_payload(ptr, st_width, st_height,
104 st_pitch, start_n_base + local_tile_size_x * sg_idy,
105 start_m_base);
// Column-wise sum across the wg_size_y partial rows (dim 0).
106 global_st.reg = subgroup::tile_reduce<reduce_op::sum, dtype_out,
107 dtype_acc, 0>(local_ld);
// Uncached: the reduced result is not expected to be re-read soon.
108 subgroup::tile_store<cache_hint::uncached>(
109 global_st, global_st_payload);
110 }
111 nbarrier.arrive();
112 nbarrier.wait();
113 }
114};
115
// Specialization for wg_size_y == 1: with a single row of subgroups there
// is nothing to reduce across threads, so no SLM staging and no barrier is
// needed — the accumulator is converted to dtype_out and stored directly.
//
// NOTE(review): extracted from a Doxygen page; leading digits are Doxygen
// line numbers and some original lines are elided (e.g. the initializer of
// block_size_x after line 120 and the tile/payload alias declarations
// around lines 123-131). Consult the original header before editing.
116template <typename dtype_acc, typename dtype_out, uint32_t row_size,
117 uint32_t wg_size_x, uint32_t max_simd_len>
118struct group_row_reduce_store_t<dtype_acc, dtype_out, row_size, wg_size_x, 1,
119 max_simd_len, gpu_arch::Xe> {
// NOTE(review): initializer (original line 121) elided in this dump.
120 static constexpr uint32_t block_size_x
122
// Fragment of the global-store tile descriptor alias (leading part elided).
124 block_size_x, 1, reg_layout::tiled>;
// Message-type selection fragment: 1D block messages once the row exceeds
// 64 bytes (rest of the conditional elided).
129 (row_size * sizeof(dtype_out) > 64) ? msg_type::block_1d
// No cooperative state to set up in this specialization — init is a no-op
// kept only for interface parity with the general template.
132 inline void init([[maybe_unused]] uint32_t sg_idx_ = 0,
133 [[maybe_unused]] uint32_t sg_idy_ = 0,
134 [[maybe_unused]] uint32_t slm_base = 0,
135 [[maybe_unused]] uint32_t nbarrier_base = 0) {}
136
// Converts the dtype_acc accumulator row to dtype_out and stores it
// uncached to global memory at (start_n_base, start_m_base).
137 inline KERNEL_FUNC void operator()(dtype_out *ptr, uint32_t st_width,
138 uint32_t st_height, uint32_t st_pitch, int start_n_base,
139 int start_m_base, xetla_vector<dtype_acc, row_size> buffer) {
140 global_st_t global_st;
141 global_st_payload_t global_st_payload;
142 global_st.reg = xetla_cvt<dtype_out, dtype_acc, row_size>(buffer);
143 global_st_payload.init(
144 ptr, st_width, st_height, st_pitch, start_n_base, start_m_base);
145 subgroup::tile_store<cache_hint::uncached>(
146 global_st, global_st_payload);
147 }
148};
149
150} // namespace gpu::xetla::group
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
__XETLA_API std::enable_if_t< detail::check_store_type< tile_t, payload_t >::is_global_2d_xe > tile_store(tile_t &tile, payload_t &payload)
Stores data from the register file to global memory.
Definition store_xe.hpp:91
__XETLA_API std::enable_if_t< detail::check_load_type< tile_t, payload_t >::is_global_2d_xe > tile_load(tile_t &tile, payload_t &payload)
This function loads data from a 2D memory surface.
Definition load_xe.hpp:76
__XETLA_API std::enable_if_t<(dim==1), xetla_vector< dtype_out, mat_t::tile_size_y > > tile_reduce(mat_t &src)
Definition reduction.hpp:33
gpu_arch
Definition common.hpp:73
void init(uint32_t sg_idx_=0, uint32_t sg_idy_=0, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Definition row_reduce_store_xe.hpp:132
KERNEL_FUNC void operator()(dtype_out *ptr, uint32_t st_width, uint32_t st_height, uint32_t st_pitch, int start_n_base, int start_m_base, xetla_vector< dtype_acc, row_size > buffer)
Definition row_reduce_store_xe.hpp:137
xetla_nbarrier_t< wg_size_y, wg_size_y, gpu_arch::Xe > nbarrier
Definition row_reduce_store_xe.hpp:72
void init(uint32_t sg_idx_=0, uint32_t sg_idy_=0, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Definition row_reduce_store_xe.hpp:79
KERNEL_FUNC void operator()(dtype_out *ptr, uint32_t st_width, uint32_t st_height, uint32_t st_pitch, int start_n_base, int start_m_base, xetla_vector< dtype_acc, row_size > buffer)
Definition row_reduce_store_xe.hpp:92
This is the group row reduction(reduce_sum) + cooperative write out.
Definition reduction_api.hpp:39
Definition memory_descriptor.hpp:139
Definition common.hpp:80
Is to illustrate the tile information about a sub matrix.
Definition api.hpp:64
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
xetla nbarrier definition API.
Definition raw_send_nbarrier.hpp:43
__XETLA_API void arrive()
named barrier signal from subgroup.
Definition raw_send_nbarrier.hpp:65
__XETLA_API void init_nbarrier(uint8_t nbarrier_id, nbarrier_role role=nbarrier_role::producer_consumer)
Definition raw_send_nbarrier.hpp:55
__XETLA_API void wait()
named barrier wait within subgroup.
Definition raw_send_nbarrier.hpp:76