XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
memory_descriptor.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
24
25namespace gpu::xetla {
26
/// @brief Coordinate into a memory surface.
/// Only the 2-D specialization below has members; other ranks resolve to an
/// empty struct.
template <int dim = 2>
struct mem_coord_t {};

/// @brief Signed 2-D coordinate (x, y).
///
/// Copy construction and copy assignment are intentionally left implicit
/// (Rule of Zero): memberwise copy is exactly the desired semantics,
/// self-assignment is trivially safe, and not user-providing them keeps the
/// type trivially copyable.
template <>
struct mem_coord_t<2> {
    int x;
    int y;

    /// Trivial default construction: x and y are left uninitialized.
    inline mem_coord_t() = default;

    /// Constructs the coordinate (x_, y_).
    inline mem_coord_t(int x_, int y_) : x(x_), y(y_) {}

    /// Overwrites both components in place.
    inline void init(int x_, int y_) {
        x = x_;
        y = y_;
    }
};
55
/// @brief Shape of a memory surface.
/// Only the 2-D specialization below has members; other ranks resolve to an
/// empty struct.
template <int dim = 2>
struct mem_shape_t {};

/// @brief 2-D surface extent plus row stride.
///
/// Copy construction and copy assignment are intentionally left implicit
/// (Rule of Zero): memberwise copy is exactly the desired semantics and
/// keeping them non-user-provided keeps the type trivially copyable.
template <>
struct mem_shape_t<2> {
    uint32_t x;
    uint32_t y;
    // Row stride. Its unit is not established here (presumably elements,
    // since it is forwarded as the pitch to xetla_get_tdesc<dtype>) — confirm.
    uint32_t stride;

    /// Trivial default construction: all fields are left uninitialized.
    inline mem_shape_t() = default;

    /// Constructs a shape of shape_x_ by shape_y_ with the given row stride.
    inline mem_shape_t(
            uint32_t shape_x_, uint32_t shape_y_, uint32_t row_stride_)
        : x(shape_x_), y(shape_y_), stride(row_stride_) {}

    /// Overwrites extent and stride in place.
    inline void init(
            uint32_t shape_x_, uint32_t shape_y_, uint32_t row_stride_) {
        x = shape_x_;
        y = shape_y_;
        stride = row_stride_;
    }
};
90
91template <typename dtype_, mem_space space_>
92struct mem_base_t {};
93template <typename dtype_>
94struct mem_base_t<dtype_, mem_space::global> {
95 using dtype = dtype_;
97 inline mem_base_t() = default;
98 inline mem_base_t(dtype *base_) : base(base_) {}
100 : base(mem_base.base) {}
102 const mem_base_t<dtype, mem_space::global> &mem_base) {
103 // Be aware of the risks:
104 // self_assign: No protection against the object assigning to itself.
105 // if (this == &mem_base){
106 // return *this;
107 // }
108 this->base = mem_base.base;
109 return *this;
110 }
111 inline void init(dtype *base_) { base = base_; }
112 inline void update(int offset) { base = base + offset; }
113};
114template <typename dtype_>
115struct mem_base_t<dtype_, mem_space::local> {
116 using dtype = dtype_;
117 uint32_t base;
118 inline mem_base_t() = default;
119 inline mem_base_t(uint32_t base_) { init(base_); }
121 init(mem_base.base);
122 }
124 const mem_base_t<dtype, mem_space::local> &mem_base) {
125 // Be aware of the risks:
126 // self_assign: No protection against the object assigning to itself.
127 // if (this == &mem_base){
128 // return *this;
129 // }
130 init(mem_base.base);
131 return *this;
132 }
133 inline void init(uint32_t base_) { base = base_; }
134 inline void update(int offset) { init(base + offset * sizeof(dtype)); }
135};
136
137template <typename dtype_, mem_layout layout_, mem_space space_,
138 uint32_t alignment_ = 8, int dim_ = 2>
139struct mem_desc_t {};
140
141template <typename dtype_, mem_layout layout_, mem_space space_,
142 uint32_t alignment_>
143struct mem_desc_t<dtype_, layout_, space_, alignment_, 2> {
144 using dtype = dtype_;
145 static constexpr mem_layout layout = layout_;
146 static constexpr mem_space space = space_;
147 static constexpr int dim = 2;
148 static constexpr uint32_t alignment = alignment_;
149 static constexpr uint32_t alignment_in_bytes = alignment_ * sizeof(dtype);
150
151 static constexpr bool is_col_major = layout == mem_layout::col_major;
152 static constexpr bool is_local = space == mem_space::local;
156
158
159 inline mem_desc_t() = default;
160 inline mem_desc_t(base_t base_, shape_t shape_, coord_t coord_)
161 : shape(shape_), coord(coord_), base(base_) {}
162 // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor)
163 // Please check if you need to add self-define destructor
164 // inline ~mem_desc_t(){}
165 inline mem_desc_t(const this_type_t &mem_desc)
166 : shape(mem_desc.shape), coord(mem_desc.coord), base(mem_desc.base) {}
167
168 inline this_type_t &operator=(const this_type_t &mem_desc) {
169 this->base = mem_desc.base;
170 this->shape = mem_desc.shape;
171 this->coord = mem_desc.coord;
172 return *this;
173 }
174 inline void init(base_t base_, shape_t shape_, coord_t coord_) {
175 base = base_;
176 shape = shape_;
177 coord = coord_;
178 }
179 inline void update_coord(int offset_x, int offset_y) {
180 coord.x += offset_x;
181 coord.y += offset_y;
182 }
183 inline void update_coord_x(int offset_x) { coord.x += offset_x; }
184 inline void update_coord_y(int offset_y) { coord.y += offset_y; }
186 uint32_t width = is_col_major ? shape.y : shape.x;
187 uint32_t height = is_col_major ? shape.x : shape.y;
188 uint32_t pitch = shape.stride;
189 int coord_x = is_col_major ? coord.y : coord.x;
190 int coord_y = is_col_major ? coord.x : coord.y;
191 return xetla_get_tdesc<dtype>(
192 base.base, width, height, pitch, coord_x, coord_y);
193 }
194
198};
199
200} // namespace gpu::xetla
C++ API.
xetla_vector< uint32_t, 16 > xetla_tdescriptor
Description of nd tensor descriptor for load and store.
Definition base_types.hpp:155
Definition arch_config.hpp:24
mem_space
Definition common.hpp:77
mem_layout
Definition common.hpp:76
void init(dtype *base_)
Definition memory_descriptor.hpp:111
mem_base_t(const mem_base_t< dtype, mem_space::global > &mem_base)
Definition memory_descriptor.hpp:99
dtype_ dtype
Definition memory_descriptor.hpp:95
void update(int offset)
Definition memory_descriptor.hpp:112
dtype * base
Definition memory_descriptor.hpp:96
mem_base_t< dtype, mem_space::global > & operator=(const mem_base_t< dtype, mem_space::global > &mem_base)
Definition memory_descriptor.hpp:101
mem_base_t(dtype *base_)
Definition memory_descriptor.hpp:98
dtype_ dtype
Definition memory_descriptor.hpp:116
mem_base_t< dtype, mem_space::local > & operator=(const mem_base_t< dtype, mem_space::local > &mem_base)
Definition memory_descriptor.hpp:123
mem_base_t(uint32_t base_)
Definition memory_descriptor.hpp:119
mem_base_t(const mem_base_t< dtype, mem_space::local > &mem_base)
Definition memory_descriptor.hpp:120
void init(uint32_t base_)
Definition memory_descriptor.hpp:133
void update(int offset)
Definition memory_descriptor.hpp:134
uint32_t base
Definition memory_descriptor.hpp:117
Definition memory_descriptor.hpp:92
void init(int x_, int y_)
Definition memory_descriptor.hpp:50
mem_coord_t(const mem_coord_t< 2 > &coord)
Definition memory_descriptor.hpp:35
int y
Definition memory_descriptor.hpp:32
int x
Definition memory_descriptor.hpp:31
mem_coord_t< 2 > & operator=(const mem_coord_t< 2 > &coord)
Definition memory_descriptor.hpp:39
mem_coord_t(int x_, int y_)
Definition memory_descriptor.hpp:33
Definition memory_descriptor.hpp:28
xetla_tdescriptor get_tdesc()
Definition memory_descriptor.hpp:185
void update_coord_x(int offset_x)
Definition memory_descriptor.hpp:183
shape_t shape
Definition memory_descriptor.hpp:195
mem_desc_t(base_t base_, shape_t shape_, coord_t coord_)
Definition memory_descriptor.hpp:160
mem_desc_t(const this_type_t &mem_desc)
Definition memory_descriptor.hpp:165
this_type_t & operator=(const this_type_t &mem_desc)
Definition memory_descriptor.hpp:168
base_t base
Definition memory_descriptor.hpp:197
dtype_ dtype
Definition memory_descriptor.hpp:144
coord_t coord
Definition memory_descriptor.hpp:196
void update_coord(int offset_x, int offset_y)
Definition memory_descriptor.hpp:179
void update_coord_y(int offset_y)
Definition memory_descriptor.hpp:184
void init(base_t base_, shape_t shape_, coord_t coord_)
Definition memory_descriptor.hpp:174
Definition memory_descriptor.hpp:139
void init(uint32_t shape_x_, uint32_t shape_y_, uint32_t row_stride_)
Definition memory_descriptor.hpp:83
mem_shape_t< 2 > & operator=(const mem_shape_t< 2 > &shape)
Definition memory_descriptor.hpp:72
mem_shape_t(uint32_t shape_x_, uint32_t shape_y_, uint32_t row_stride_)
Definition memory_descriptor.hpp:64
uint32_t x
Definition memory_descriptor.hpp:60
uint32_t y
Definition memory_descriptor.hpp:61
uint32_t stride
Definition memory_descriptor.hpp:62
mem_shape_t(const mem_shape_t< 2 > &shape)
Definition memory_descriptor.hpp:67
Definition memory_descriptor.hpp:57
Definition dict.hpp:59