XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
reduction_xe.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
22#include "group/reduction/reduction_api.hpp"
23
24namespace gpu::xetla::group {
25
26template <typename T, uint32_t SZ, uint32_t N, reduce_op Op, uint32_t N_SG,
27 bool is_all_reduce>
28struct group_reduce_t<T, SZ, N, Op, N_SG, is_all_reduce, gpu_arch::Xe> {
31 uint32_t slm_base;
32 uint32_t sg_id;
35 using local_ld_tile_desc = subgroup::tile_desc_t<N_SG * N, 1, N_SG * N, 1,
42 subgroup::msg_type_v<local_ld_tile_desc, mem_space::local>,
47 inline group_reduce_t() = default;
49 uint32_t sg_id_, uint32_t nbarrier_id, uint32_t slm_base_) {
51 sg_id = sg_id_;
52 slm_base = slm_base_;
53 }
54 inline void init(uint32_t sg_id_ = 0, uint32_t nbarrier_id = 0,
55 uint32_t slm_base_ = 0) {
57 sg_id = sg_id_;
58 slm_base = slm_base_;
59 }
60 inline void set_slm_base(uint32_t slm_base_ = 0) { slm_base = slm_base_; }
61
64 local_st_t local_st;
65 local_st_payload_t local_st_payload;
66 xetla_vector<T, N> ret = sg_reduce(buffer);
67 local_st.reg = ret;
68 local_st_payload.init(slm_base, N_SG * N, 1, N_SG * N, sg_id * N, 0);
69 subgroup::tile_store(local_st, local_st_payload);
70 xetla_fence<memory_kind::shared_local>();
71 nbarrier.arrive();
72 nbarrier.wait();
73 if constexpr (is_all_reduce) {
74 local_ld_t local_ld;
75 local_ld_payload_t local_ld_payload(
76 slm_base, N_SG * N, 1, N_SG * N, 0, 0);
77 subgroup::tile_load(local_ld, local_ld_payload);
78 ret = recur_row_reduce<Op, T, N, N_SG>(local_ld.reg);
79 } else {
80 if (sg_id == 0) {
81 local_ld_t local_ld;
82 local_ld_payload_t local_ld_payload;
83 local_ld_payload.init(slm_base, N_SG * N, 1, N_SG * N, 0, 0);
84 subgroup::tile_load(local_ld, local_ld_payload);
85 ret = recur_row_reduce<Op, T, N, N_SG>(local_ld.reg);
86 }
87 }
88 return ret;
89 }
90};
91
92template <typename T, uint32_t SZ, uint32_t N, reduce_op Op, bool is_all_reduce>
93struct group_reduce_t<T, SZ, N, Op, 1, is_all_reduce, gpu_arch::Xe> {
94 inline group_reduce_t() = default;
95 inline group_reduce_t([[maybe_unused]] uint32_t sg_id_,
96 [[maybe_unused]] uint32_t nbarrier_id,
97 [[maybe_unused]] uint32_t slm_base_) {}
98 inline void init([[maybe_unused]] uint32_t sg_id_ = 0,
99 [[maybe_unused]] uint32_t nbarrier_id = 0,
100 [[maybe_unused]] uint32_t slm_base_ = 0) {}
101 inline void set_slm_base([[maybe_unused]] uint32_t slm_base_ = 0) {}
104 auto buffer_2d = buffer.xetla_format<T, N, SZ>();
106#pragma unroll
107 for (uint32_t i = 0; i < N; i++) {
108 ret[i] = xetla_reduce<T, T, SZ, Op>(buffer_2d.row(i));
109 }
110 return ret;
111 }
112};
113
114} // namespace gpu::xetla::group
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
Wrapper around the ESIMD simd vector type.
Definition base_types.hpp:149
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
__XETLA_API std::enable_if_t< detail::check_store_type< tile_t, payload_t >::is_global_2d_xe > tile_store(tile_t &tile, payload_t &payload)
Stores data from the register file to global memory.
Definition store_xe.hpp:91
__XETLA_API std::enable_if_t< detail::check_load_type< tile_t, payload_t >::is_global_2d_xe > tile_load(tile_t &tile, payload_t &payload)
Loads data from a 2D memory surface.
Definition load_xe.hpp:76
reduce_op
XeTLA reduction operation.
Definition common.hpp:217
gpu_arch
Definition common.hpp:73
void init(uint32_t sg_id_=0, uint32_t nbarrier_id=0, uint32_t slm_base_=0)
Definition reduction_xe.hpp:98
group_reduce_t(uint32_t sg_id_, uint32_t nbarrier_id, uint32_t slm_base_)
Definition reduction_xe.hpp:95
void set_slm_base(uint32_t slm_base_=0)
Definition reduction_xe.hpp:101
KERNEL_FUNC xetla_vector< T, N > operator()(xetla_vector< T, N *SZ > buffer)
Definition reduction_xe.hpp:102
KERNEL_FUNC xetla_vector< T, N > operator()(xetla_vector< T, N *SZ > buffer)
Definition reduction_xe.hpp:62
void set_slm_base(uint32_t slm_base_=0)
Definition reduction_xe.hpp:60
void init(uint32_t sg_id_=0, uint32_t nbarrier_id=0, uint32_t slm_base_=0)
Definition reduction_xe.hpp:54
xetla_nbarrier_t< N_SG, N_SG, gpu_arch::Xe > nbarrier
Definition reduction_xe.hpp:30
group_reduce_t(uint32_t sg_id_, uint32_t nbarrier_id, uint32_t slm_base_)
Definition reduction_xe.hpp:48
Reduction across the subgroups of a workgroup.
Definition reduction_api.hpp:36
Definition memory_descriptor.hpp:139
Describes the memory information.
Definition api.hpp:44
Describes the tile information of a sub-matrix.
Definition api.hpp:64
A struct that holds tile data in the register file.
Definition api.hpp:99
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
XeTLA named-barrier (nbarrier) definition API.
Definition raw_send_nbarrier.hpp:43
__XETLA_API void arrive()
Signals the named barrier from a subgroup.
Definition raw_send_nbarrier.hpp:65
__XETLA_API void init_nbarrier(uint8_t nbarrier_id, nbarrier_role role=nbarrier_role::producer_consumer)
Definition raw_send_nbarrier.hpp:55
__XETLA_API void wait()
Waits on the named barrier within a subgroup.
Definition raw_send_nbarrier.hpp:76