XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
quant_op_functor.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
22#include "subgroup/tile/api.hpp"
29
30namespace gpu::xetla::subgroup {
31
35template <typename tile_op_t, gpu_arch arch_tag, class enable = void>
36struct dequant_op_t {};
38template <typename tile_op_t_, gpu_arch arch_tag>
39struct dequant_op_t<tile_op_t_, arch_tag,
40 std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
41 //may need to add some limitations to tile_op used in dequant_op
42 using tile_op_t = tile_op_t_;
43 struct arguments_t {
44 typename tile_op_t::arguments_t tile_op_args;
45 inline arguments_t() = default;
46 inline arguments_t(typename tile_op_t::arguments_t tile_op_args_)
47 : tile_op_args(tile_op_args_) {}
48 };
49 template <typename mat_out_t, typename mat_in_t, typename coord_t>
50 __XETLA_API KERNEL_FUNC void operator()(mat_out_t &mat_out,
51 mat_in_t &mat_in, const coord_t &coord, const arguments_t &args) {
52 elemwise_cvt(mat_out, mat_in);
53 tile_op_t tile_op;
54 tile_op(mat_out, coord, args.tile_op_args);
55 }
56};
57
// Quantization op functor: generic (empty) fallback. Only the partial
// specialization below, enabled when arch_tag == gpu_arch::Xe, defines
// arguments_t and operator().
61template <typename tile_op_t, gpu_arch arch_tag, class enable = void>
62struct quant_op_t {};
// Quantization op functor, Xe specialization. Pipeline per call:
//   mat_in --elemwise_cvt--> matAcc --tile_op_t--> matAcc
//          --elemwise_cvt--> mat_sat --elemwise_cvt--> mat_out
64template <typename tile_op_t_, gpu_arch arch_tag>
65struct quant_op_t<tile_op_t_, arch_tag,
66 std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
67 //may need to add some limitations to tile_op used in quant_op
68 using tile_op_t = tile_op_t_;
69
    // Runtime arguments: a thin wrapper whose only member, tile_op_args, is
    // forwarded unchanged to tile_op_t.
70 struct arguments_t {
71 typename tile_op_t::arguments_t tile_op_args;
72 inline arguments_t() = default;
73 inline arguments_t(typename tile_op_t::arguments_t tile_op_args_)
74 : tile_op_args(tile_op_args_) {}
75 };
    // Two-tile form: quantizes mat_in into mat_out.
    //   dtype_sat - element type used for the intermediate "saturation" tile;
    //               defaults to mat_out_t::dtype.
    //   mat_out   - destination tile, written by the final elemwise_cvt.
    //   mat_in    - source tile; only read, as the source of the first
    //               elemwise_cvt into a separate accumulator tile.
    //   coord     - forwarded unchanged to tile_op_t.
    //   args      - only args.tile_op_args is used.
76 template <typename mat_out_t, typename mat_in_t,
77 typename dtype_sat = typename mat_out_t::dtype, typename coord_t>
78 __XETLA_API KERNEL_FUNC void operator()(mat_out_t &mat_out,
79 mat_in_t &mat_in, const coord_t &coord, const arguments_t &args) {
        // NOTE(review): listing line 80 (the start of this static_assert,
        // including its condition) is missing from this extraction; only the
        // message tail survives below.
81 " mat_in and mat_out should be the same layout");
        // Accumulator tile: mat_in's element type laid out with mat_out's
        // tile_desc.
82 using matAcc_t = subgroup::tile_t<typename mat_in_t::dtype,
83 typename mat_out_t::tile_desc>;
        // NOTE(review): the right-hand side of this alias (listing line 85)
        // is missing from this extraction; presumably a tile_t with element
        // type dtype_sat and mat_out's tile_desc. Confirm against the header.
84 using mat_sat_t
86
87 matAcc_t matAcc;
88 // Convert into a separate accumulator so mat_in is never modified in place;
89 // the compiler is expected to optimize this copy away if mat_in is not used afterwards.
90 elemwise_cvt(matAcc, mat_in);
91 tile_op_t tile_op;
92 tile_op(matAcc, coord, args.tile_op_args);
        // Route the result through the intermediate mat_sat_t tile before
        // writing mat_out. Presumably this is where values are saturated to
        // dtype_sat. Confirm against the elemwise_cvt overloads.
93 mat_sat_t mat_sat;
94 elemwise_cvt(mat_sat, matAcc);
95 elemwise_cvt(mat_out, mat_sat);
96 }
97
    // In-place form: quantizes matAcc in place by forwarding to the two-tile
    // overload with mat_out == mat_in == matAcc. dtype_sat is the first
    // template parameter and does not appear in the function parameters, so
    // callers must specify it explicitly.
    // NOTE(review): listing line 99 (the
    // `__XETLA_API KERNEL_FUNC void operator()(` header line) is missing
    // from this extraction.
98 template <typename dtype_sat, typename matAcc_t, typename coord_t>
100 matAcc_t &matAcc, const coord_t &coord, const arguments_t &args) {
101 operator()<matAcc_t, matAcc_t, dtype_sat>(matAcc, matAcc, coord, args);
102 }
103};
104
105} // namespace gpu::xetla::subgroup
#define __XETLA_API
Definition common.hpp:43
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
C++ API.
Definition limitation.hpp:457
__XETLA_API std::enable_if_t<(T_src::register_layout !=reg_layout::linear) &&(T_dst::register_layout !=reg_layout::linear) &&is_same_layout< T_dst, T_src >::value &&(!is_floating_to_integer< T_dst, T_src >::value)> elemwise_cvt(T_dst &dst, T_src &src)
Performs element-wise data conversion; the src and dst tiles must have the same layout.
Definition op_function.hpp:40
C++ API.
arguments_t(typename tile_op_t::arguments_t tile_op_args_)
Definition quant_op_functor.hpp:46
__XETLA_API KERNEL_FUNC void operator()(mat_out_t &mat_out, mat_in_t &mat_in, const coord_t &coord, const arguments_t &args)
Definition quant_op_functor.hpp:50
Is the dequantization op functor.
Definition quant_op_functor.hpp:36
Definition common.hpp:213
__XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc, const coord_t &coord, const arguments_t &args)
Definition quant_op_functor.hpp:99
__XETLA_API KERNEL_FUNC void operator()(mat_out_t &mat_out, mat_in_t &mat_in, const coord_t &coord, const arguments_t &args)
Definition quant_op_functor.hpp:78
arguments_t(typename tile_op_t::arguments_t tile_op_args_)
Definition quant_op_functor.hpp:73
Is the quantization op functor.
Definition quant_op_functor.hpp:62
A struct that holds tile data in the register file.
Definition api.hpp:99
C++ API.
C++ API.