XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
quant_tile_op_xe.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
25
26namespace gpu::xetla::group {
27
30
32template <typename dequant_op_t_, typename dtype_dequant_, typename tile_op_t_,
33 typename quant_op_t_, typename tile_shape_, typename mem_desc_c_t_,
34 gpu_arch arch_tag_>
35class epilogue_t<epilogue_policy_quant_op<dequant_op_t_, tile_op_t_,
36 quant_op_t_, arch_tag_, dtype_dequant_>,
37 tile_shape_, mem_desc_c_t_,
38 std::enable_if_t<(arch_tag_ == gpu_arch::Xe)>> {
39public:
40 using epilogue_policy = epilogue_policy_quant_op<dequant_op_t_, tile_op_t_,
41 quant_op_t_, arch_tag_, dtype_dequant_>;
46 using tile_shape = tile_shape_;
47 using mem_desc_c_t = mem_desc_c_t_;
48 static constexpr gpu_arch arch_tag = arch_tag_;
49 static constexpr uint32_t barrier_count = 0;
50 static constexpr uint32_t slm_size = mem_desc_c_t::is_local
51 ? tile_shape::wg_tile_size_x * tile_shape::wg_tile_size_y
52 : 0;
53
55 struct arguments_t {
57 typename dequant_op_t::arguments_t dequant_op_args;
58
61 typename tile_op_t::arguments_t tile_op_args;
62
64 typename quant_op_t::arguments_t quant_op_args;
65
67 inline arguments_t() = default;
68
74 inline arguments_t(typename dequant_op_t::arguments_t dequant_op_args_,
75 typename tile_op_t::arguments_t tile_op_args_,
76 typename quant_op_t::arguments_t quant_op_args_)
77 : dequant_op_args(dequant_op_args_)
78 , tile_op_args(tile_op_args_)
79 , quant_op_args(quant_op_args_) {}
80 inline arguments_t(const arguments_t &args)
81 : dequant_op_args(args.dequant_op_args)
82 , tile_op_args(args.tile_op_args)
83 , quant_op_args(args.quant_op_args) {}
84 // Be aware of the risks: Rule of three (copy constructor, copy assignment, destructor)
85 // Please check if you need to add self-define destructor
86 // ~arguments_t(){}
87 inline arguments_t &operator=(const arguments_t &args) {
88 this->dequant_op_args = args.dequant_op_args;
89 this->tile_op_args = args.tile_op_args;
90 this->quant_op_args = args.quant_op_args;
91 return *this;
92 }
93
99 inline void init(typename dequant_op_t::arguments_t dequant_op_args_,
100 typename tile_op_t::arguments_t tile_op_args_,
101 typename quant_op_t::arguments_t quant_op_args_) {
102 dequant_op_args = dequant_op_args_;
103 tile_op_args = tile_op_args_;
104 quant_op_args = quant_op_args_;
105 }
106 };
107
108private:
109 using work_group_t = typename tile_shape::work_group_t;
110 static constexpr uint32_t sg_tile_y = tile_shape::sg_tile_size_y;
111 static constexpr uint32_t sg_tile_x = tile_shape::sg_tile_size_x;
112 static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
113 static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
114 using dtype_c = typename mem_desc_c_t::dtype;
115 static constexpr mem_layout mem_layout_c = mem_desc_c_t::layout;
116 static constexpr mem_space mem_space_c = mem_desc_c_t::space;
117
119 __XETLA_API static void update_sg_tile_tdesc(
120 work_group_t &g, mem_desc_c_t &mem_desc_c) {
121 int32_t sg_idx = g.get_id() % wg_size_x;
122 int32_t sg_idy = g.get_id() / wg_size_x;
123 int32_t tile_offset_x = sg_idx * sg_tile_x;
124 int32_t tile_offset_y = sg_idy * sg_tile_y;
125 mem_desc_c.update_coord(tile_offset_x, tile_offset_y);
126 }
127
128public:
129 static constexpr msg_type msg_type_c
130 = (mem_space_c == mem_space::global ? msg_type::block_2d
132
143 template <typename matAcc_t>
144 __XETLA_API KERNEL_FUNC void operator()(work_group_t &g, matAcc_t &matAcc,
145 mem_desc_c_t mem_desc_c, arguments_t args = {},
146 uint32_t slm_base = 0, uint32_t nbarrier_base = 0) {
147 update_sg_tile_tdesc(g, mem_desc_c);
148 using mat_tile_desc = typename matAcc_t::tile_desc;
151
152 tile_op_t tile_op;
153 quant_op_t quant_op;
154 dequant_op_t dequant_op;
155 //dequantize
156 mat_dequant_t mat_dequant;
157 dequant_op(mat_dequant, matAcc, mem_desc_c.coord, args.dequant_op_args);
158 //post-op
159 tile_op(mat_dequant, mem_desc_c.coord, args.tile_op_args, slm_base,
160 nbarrier_base);
161 //quantize
162 matC_t matC;
163 quant_op(matC, mat_dequant, mem_desc_c.coord, args.quant_op_args);
164
165 using matC_payload_t = subgroup::mem_payload_t<mem_desc_c_t,
166 mat_tile_desc, msg_type_c, arch_tag>;
167 matC_payload_t matC_payload(mem_desc_c);
168 subgroup::tile_store<cache_hint::streaming, cache_hint::write_back>(
169 matC, matC_payload);
170 }
171};
172
174
175} // namespace gpu::xetla::group
__XETLA_API KERNEL_FUNC void operator()(work_group_t &g, matAcc_t &matAcc, mem_desc_c_t mem_desc_c, arguments_t args={}, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Runs the quantization epilogue: dequantizes the accumulator, applies the tile op, quantizes, and stores C.
Definition quant_tile_op_xe.hpp:144
Is the epilogue functor.
Definition api.hpp:35
#define __XETLA_API
Definition common.hpp:43
C++ API.
C++ API.
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
mem_space
Definition common.hpp:77
gpu_arch
Definition common.hpp:73
msg_type
Definition common.hpp:78
mem_layout
Definition common.hpp:76
Epilogue functor, specialized for quantization operator.
Definition epilogue_policy.hpp:53
quant_op_t_ quant_op_t
Definition epilogue_policy.hpp:56
dequant_op_t_ dequant_op_t
Definition epilogue_policy.hpp:54
dtype_dequant_ dtype_dequant
Definition epilogue_policy.hpp:58
tile_op_t_ tile_op_t
Definition epilogue_policy.hpp:55
arguments_t(typename dequant_op_t::arguments_t dequant_op_args_, typename tile_op_t::arguments_t tile_op_args_, typename quant_op_t::arguments_t quant_op_args_)
Constructs a new arguments_t object.
Definition quant_tile_op_xe.hpp:74
void init(typename dequant_op_t::arguments_t dequant_op_args_, typename tile_op_t::arguments_t tile_op_args_, typename quant_op_t::arguments_t quant_op_args_)
Explicit initialization function.
Definition quant_tile_op_xe.hpp:99
Is to illustrate the memory information.
Definition api.hpp:44
Is a struct that contains some register-file storage.
Definition api.hpp:99
Define a workgroup scope for a specific problem shape.
Definition work_group.hpp:34
__XETLA_API uint32_t get_id()
Definition work_group.hpp:41