XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
cooperative_load_helper.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
23
24namespace gpu::xetla::subgroup {
25
/// Template header of the primary cooperative-load helper template.
/// NOTE(review): the declaration this header introduces (listing line 34) is
/// absent from this extract; the page index at the bottom of this listing places
/// the definition of the "helper to do the cooperative workgroups load" there.
/// @tparam matAcc_t_ Type parameter (the specializations in this file read
///         `dtype` and `tile_desc` from it).
/// @tparam mem_layout_ Non-type parameter of type mem_layout.
/// @tparam num_cooperative_wg Non-type uint32_t parameter; presumably the number
///         of cooperating participants -- confirm against callers.
/// @tparam arch_tag_ Non-type parameter of type gpu_arch.
/// @tparam enable Defaults to void; the specializations visible in this listing
///         pass a std::enable_if_t<...> expression in this position.
32template <typename matAcc_t_, mem_layout mem_layout_,
33 uint32_t num_cooperative_wg, gpu_arch arch_tag_, class enable = void>
35
/// Cooperative-load helper, partial specialization whose `layout` member is
/// mem_layout::row_major.
/// NOTE(review): listing line 38 -- the class-head that names this
/// specialization -- is absent from this extract.
/// Derives, at compile time, how the tile described by matAcc_t_::tile_desc is
/// divided among num_cooperative_wg participants (y direction first, then x),
/// and exposes the per-participant tile descriptor plus per-participant offsets.
/// The std::enable_if_t argument restricts it to arch_tag_ == gpu_arch::Xe.
/// @tparam matAcc_t_ Type providing `dtype` and `tile_desc` members.
/// @tparam num_cooperative_wg Number of participants; must satisfy the
///         static_assert below.
/// @tparam arch_tag_ Architecture tag, re-exported as `arch_tag`.
37template <typename matAcc_t_, uint32_t num_cooperative_wg, gpu_arch arch_tag_>
39 num_cooperative_wg, arch_tag_,
40 std::enable_if_t<gpu_arch::Xe == arch_tag_>> {
41public:
// Re-export the template arguments and the types read from matAcc_t_.
42 static constexpr gpu_arch arch_tag = arch_tag_;
43 using matAcc_t = matAcc_t_;
44 using dtype = typename matAcc_t::dtype;
45 using tile_desc_t = typename matAcc_t::tile_desc;
46 static constexpr mem_layout layout = mem_layout::row_major;
47
48private:
49 // cooperative split, y dir first
// Bit trick: a power of two has a single set bit, so n & (n - 1) is zero.
// NOTE(review): the expression is also zero for n == 0, so 0 is not rejected here.
50 static_assert((num_cooperative_wg & (num_cooperative_wg - 1)) == 0,
51 "num_cooperative_wg should be power of 2");
52 //TODO
53 // static_assert(sg_tile_size_y * sg_tile_size_x / 16 / num_cooperative_wg, "");
54
55public:
// Dimensions of the source tile and its blocks, copied from tile_desc_t.
56 static constexpr uint32_t src_block_size_x = tile_desc_t::block_size_x;
57 static constexpr uint32_t src_block_size_y = tile_desc_t::block_size_y;
58 static constexpr uint32_t src_tile_size_x = tile_desc_t::tile_size_x;
59 static constexpr uint32_t src_tile_size_y = tile_desc_t::tile_size_y;
60
// Number of pieces along y: detail::gcd of the participant count and the source
// tile height (presumably the greatest common divisor; gcd is defined elsewhere).
61 static constexpr uint32_t coop_num_y
62 = gpu::xetla::subgroup::detail::gcd<num_cooperative_wg,
63 src_tile_size_y>::value;
// Participants left over for the x direction once the y split is fixed.
64 static constexpr uint32_t coop_remain_num_x
65 = num_cooperative_wg / coop_num_y;
// True when giving each remaining x participant a width of 16 would cover more
// than the source tile width. The constant 16 is hard-coded; its origin is not
// shown in this file.
66 static constexpr bool has_redundant_wg
67 = (coop_remain_num_x * 16) > src_tile_size_x;
// Per-participant tile: height is an even share of the source height; width is
// clamped to 16 in the redundant case, otherwise an even share of the width.
68 static constexpr uint32_t tile_size_y = src_tile_size_y / coop_num_y;
69 static constexpr uint32_t tile_size_x
70 = has_redundant_wg ? 16 : src_tile_size_x / coop_remain_num_x;
// Number of per-participant tiles that fit along x in the source tile.
71 static constexpr uint32_t coop_num_x = src_tile_size_x / tile_size_x;
72
73public:
// NOTE(review): the first line of this initializer (listing line 75) is absent
// from this extract; only its tail, `src_block_size_x>::value`, is visible.
74 static constexpr uint32_t block_size_x
76 src_block_size_x>::value;
// Block height is the smaller of the per-participant tile height and the
// source block height.
77 static constexpr uint32_t block_size_y
78 = (tile_size_y > src_block_size_y) ? src_block_size_y : tile_size_y;
79
// Tile descriptor for one participant, built from the sizes above with
// reg_layout::tiled.
80 using co_tile_desc_t = subgroup::tile_desc_t<tile_size_x, tile_size_y,
81 block_size_x, block_size_y, reg_layout::tiled>;
82
83public:
84 inline cooperative_load_helper_t() = default;
85
/// @brief X offset for a participant: (coop_id % coop_remain_num_x) * tile_size_x.
/// @param coop_id Participant index; presumably in [0, num_cooperative_wg) --
///        verify against callers.
/// @return Offset in the same units as tile_size_x; the unsigned product is
///         implicitly converted to int32_t.
/// NOTE(review): the divisor is coop_remain_num_x, not coop_num_x. When
/// has_redundant_wg is true, coop_remain_num_x * tile_size_x exceeds
/// src_tile_size_x, so the last participant's tile reaches past the source
/// width; presumably callers skip such participants -- confirm.
86 inline static int32_t get_offset_x(uint32_t coop_id) {
87 return coop_id % coop_remain_num_x * tile_size_x;
88 }
89
/// @brief Y offset for a participant: (coop_id / coop_remain_num_x) * tile_size_y.
/// @param coop_id Participant index; presumably in [0, num_cooperative_wg) --
///        verify against callers.
/// @return Offset in the same units as tile_size_y; the unsigned product is
///         implicitly converted to int32_t.
90 inline static int32_t get_offset_y(uint32_t coop_id) {
91 return coop_id / coop_remain_num_x * tile_size_y;
92 }
93};
94
/// Cooperative-load helper, partial specialization whose `layout` member is
/// mem_layout::col_major.
/// NOTE(review): listing line 97 -- the class-head that names this
/// specialization -- is absent from this extract.
/// Derives, at compile time, how the tile described by matAcc_t_::tile_desc is
/// divided among num_cooperative_wg participants (x direction first, then y),
/// and exposes the per-participant tile descriptor plus per-participant offsets.
/// The std::enable_if_t argument restricts it to arch_tag_ == gpu_arch::Xe.
/// @tparam matAcc_t_ Type providing `dtype` and `tile_desc` members.
/// @tparam num_cooperative_wg Number of participants; must satisfy the
///         static_assert below.
/// @tparam arch_tag_ Architecture tag, re-exported as `arch_tag`.
96template <typename matAcc_t_, uint32_t num_cooperative_wg, gpu_arch arch_tag_>
98 num_cooperative_wg, arch_tag_,
99 std::enable_if_t<gpu_arch::Xe == arch_tag_>> {
100public:
// Re-export the template arguments and the types read from matAcc_t_.
101 static constexpr gpu_arch arch_tag = arch_tag_;
102 using matAcc_t = matAcc_t_;
103 using dtype = typename matAcc_t::dtype;
104 using tile_desc_t = typename matAcc_t::tile_desc;
105 static constexpr mem_layout layout = mem_layout::col_major;
106
107private:
108 // cooperative split, x dir first (coop_num_x is derived before the y split)
// Bit trick: a power of two has a single set bit, so n & (n - 1) is zero.
// NOTE(review): the expression is also zero for n == 0, so 0 is not rejected here.
109 static_assert((num_cooperative_wg & (num_cooperative_wg - 1)) == 0,
110 "num_cooperative_wg should be power of 2");
111
112public:
// Dimensions of the source tile and its blocks, copied from tile_desc_t.
113 static constexpr uint32_t src_block_size_x = tile_desc_t::block_size_x;
114 static constexpr uint32_t src_block_size_y = tile_desc_t::block_size_y;
115 static constexpr uint32_t src_tile_size_x = tile_desc_t::tile_size_x;
116 static constexpr uint32_t src_tile_size_y = tile_desc_t::tile_size_y;
117
// Number of pieces along x: detail::gcd of the participant count and the source
// tile width (presumably the greatest common divisor; gcd is defined elsewhere).
118 static constexpr uint32_t coop_num_x
119 = gpu::xetla::subgroup::detail::gcd<num_cooperative_wg,
120 src_tile_size_x>::value;
// Participants left over for the y direction once the x split is fixed.
121 static constexpr uint32_t coop_remain_num_y
122 = num_cooperative_wg / coop_num_x;
// True when giving each remaining y participant a height of 16 would cover more
// than the source tile height. The constant 16 is hard-coded; its origin is not
// shown in this file.
123 static constexpr bool has_redundant_wg
124 = (coop_remain_num_y * 16) > src_tile_size_y;
// Per-participant tile: width is an even share of the source width; height is
// clamped to 16 in the redundant case, otherwise an even share of the height.
125 static constexpr uint32_t tile_size_x = src_tile_size_x / coop_num_x;
126 static constexpr uint32_t tile_size_y
127 = has_redundant_wg ? 16 : src_tile_size_y / coop_remain_num_y;
// Number of per-participant tiles that fit along y in the source tile.
128 static constexpr uint32_t coop_num_y = src_tile_size_y / tile_size_y;
129
130public:
// NOTE(review): the first line of this initializer (listing line 132) is absent
// from this extract; only its tail, `src_block_size_y>::value`, is visible.
131 static constexpr uint32_t block_size_y
133 src_block_size_y>::value;
// Block width is the smaller of the per-participant tile width and the
// source block width.
134 static constexpr uint32_t block_size_x
135 = (tile_size_x > src_block_size_x) ? src_block_size_x : tile_size_x;
136
// Tile descriptor for one participant, built from the sizes above with
// reg_layout::tiled.
137 using co_tile_desc_t = subgroup::tile_desc_t<tile_size_x, tile_size_y,
138 block_size_x, block_size_y, reg_layout::tiled>;
139
140public:
141 inline cooperative_load_helper_t() = default;
142
/// @brief X offset for a participant: (coop_id / coop_remain_num_y) * tile_size_x.
/// @param coop_id Participant index; presumably in [0, num_cooperative_wg) --
///        verify against callers.
/// @return Offset in the same units as tile_size_x; the unsigned product is
///         implicitly converted to int32_t.
143 inline static int32_t get_offset_x(uint32_t coop_id) {
144 return coop_id / coop_remain_num_y * tile_size_x;
145 }
146
/// @brief Y offset for a participant: (coop_id % coop_remain_num_y) * tile_size_y.
/// @param coop_id Participant index; presumably in [0, num_cooperative_wg) --
///        verify against callers.
/// @return Offset in the same units as tile_size_y; the unsigned product is
///         implicitly converted to int32_t.
/// NOTE(review): the divisor is coop_remain_num_y, not coop_num_y. When
/// has_redundant_wg is true, coop_remain_num_y * tile_size_y exceeds
/// src_tile_size_y, so the last participant's tile reaches past the source
/// height; presumably callers skip such participants -- confirm.
147 inline static int32_t get_offset_y(uint32_t coop_id) {
148 return coop_id % coop_remain_num_y * tile_size_y;
149 }
150};
151
152} // namespace gpu::xetla::subgroup
Helper to do the cooperative workgroups load.
Definition cooperative_load_helper.hpp:34
Definition limitation.hpp:457
gpu_arch
Definition common.hpp:73
mem_layout
Definition common.hpp:76
Definition common.hpp:80
Is to illustrate the tile information about a sub matrix.
Definition api.hpp:64
static constexpr uint32_t tile_size_y
Definition api.hpp:66
static constexpr uint32_t block_size_x
Definition api.hpp:68
static constexpr uint32_t tile_size_x
Definition api.hpp:65
static constexpr uint32_t block_size_y
Definition api.hpp:69
C++ API.