XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
pre_processing_xe.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
22#include "group/gemm/api.hpp"
23#include "group/gemm/common.hpp"
24
25namespace gpu::xetla::group {
26
29
31template <typename tile_shape_, gpu_arch arch_tag>
32class pre_processing_default_t<tile_shape_, arch_tag,
33 std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
34 using tile_shape = tile_shape_;
35 using work_group_t = typename tile_shape::work_group_t;
36
37public:
38 struct arguments_t {};
39
40 inline pre_processing_default_t() = default;
41
42 inline pre_processing_default_t([[maybe_unused]] work_group_t &g,
43 [[maybe_unused]] arguments_t &args) {}
44
45 inline void init([[maybe_unused]] work_group_t &g,
46 [[maybe_unused]] arguments_t &args) {}
47
48 template <typename matA_acc_t, typename matB_acc_t, typename matA_t,
49 typename matB_t>
50 inline KERNEL_FUNC void operator()([[maybe_unused]] matA_acc_t &matA_acc,
51 [[maybe_unused]] matB_acc_t &matB_acc,
52 [[maybe_unused]] matA_t &matA, [[maybe_unused]] matB_t &matB) {}
53};
54
56template <typename tile_shape_, gpu_arch arch_tag>
57class pre_processing_matA_neg_filter_t<tile_shape_, arch_tag,
58 std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
59 using tile_shape = tile_shape_;
60 using work_group_t = typename tile_shape::work_group_t;
61
62public:
63 struct arguments_t {};
64
66
67 inline pre_processing_matA_neg_filter_t([[maybe_unused]] work_group_t &g,
68 [[maybe_unused]] arguments_t &args) {}
69
70 inline void init([[maybe_unused]] work_group_t &g,
71 [[maybe_unused]] arguments_t &args) {}
72
73 template <typename matA_acc_t, typename matB_acc_t, typename matA_t,
74 typename matB_t>
75 inline KERNEL_FUNC void operator()([[maybe_unused]] matA_acc_t &matA_acc,
76 [[maybe_unused]] matB_acc_t &matB_acc,
77 [[maybe_unused]] matA_t &matA, [[maybe_unused]] matB_t &matB) {
78
79 using data_t = typename matA_acc_t::dtype;
80 if constexpr (sizeof(data_t) == 2) {
82 = matA_acc.reg.xetla_format<int16_t>() < 0;
83 matA_acc.reg.xetla_format<int16_t>().xetla_merge(0, mask);
84 }
85 if constexpr (sizeof(data_t) == 1) {
87 = matA_acc.reg.xetla_format<int8_t>() < 0;
88 matA_acc.reg.xetla_format<int8_t>().xetla_merge(0, mask);
89 }
90 if constexpr (sizeof(data_t) == 4) {
92 = matA_acc.reg.xetla_format<int32_t>() < 0;
93 matA_acc.reg.xetla_format<int32_t>().xetla_merge(0, mask);
94 }
95 }
96};
97
99
100} // namespace gpu::xetla::group
void init(work_group_t &g, arguments_t &args)
Definition pre_processing_xe.hpp:45
pre_processing_default_t(work_group_t &g, arguments_t &args)
Definition pre_processing_xe.hpp:42
KERNEL_FUNC void operator()(matA_acc_t &matA_acc, matB_acc_t &matB_acc, matA_t &matA, matB_t &matB)
Definition pre_processing_xe.hpp:50
pre_processing_matA_neg_filter_t(work_group_t &g, arguments_t &args)
Definition pre_processing_xe.hpp:67
KERNEL_FUNC void operator()(matA_acc_t &matA_acc, matB_acc_t &matB_acc, matA_t &matA, matB_t &matB)
Definition pre_processing_xe.hpp:75
C++ API.
#define xetla_merge
xetla merge.
Definition base_ops.hpp:60
__ESIMD_NS::simd_mask< N > xetla_mask
wrapper for xetla_mask.
Definition base_types.hpp:165
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
Gemm default pre_processing functor.
Definition api.hpp:33
Gemm pre_processing functor that applies a ReLU op to matA (negative elements are set to zero).
Definition api.hpp:39