23#include "group/gemm/common.hpp"
/// Default GEMM pre-processing functor, Xe specialization: the
/// `std::enable_if_t` condition below selects it only when
/// `arch_tag == gpu_arch::Xe`. Every member visible here has an empty body, so
/// constructing, initializing and invoking this functor does nothing.
///
/// NOTE(review): this text looks like a Doxygen source listing — the leading
/// numbers ("31", "33", ...) match the "Definition pre_processing_xe.hpp:NN"
/// entries further down and are not C++ tokens. The numbering also skips
/// values (31 -> 33, 38 -> 43, 48 -> 51), so the class-head line, the
/// constructor head, the operator() head and the closing `};` appear to be
/// missing from this view — restore them from the real header before compiling.
31template <
typename tile_shape_, gpu_arch arch_tag>
33 std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
// Re-export the tile shape and take the work-group type from it.
34 using tile_shape = tile_shape_;
35 using work_group_t =
typename tile_shape::work_group_t;
// Empty argument bundle: this functor takes no runtime parameters.
38 struct arguments_t {};
// Tail of a parameter list with an empty body; presumably the
// (work_group_t &, arguments_t &) constructor — TODO confirm, its head is
// not visible here.
43 [[maybe_unused]] arguments_t &args) {}
// init(): both parameters are unused and the body is empty.
45 inline void init([[maybe_unused]] work_group_t &g,
46 [[maybe_unused]] arguments_t &args) {}
// Templated member whose visible parameters (matB_acc, matA, matB) are all
// unused and whose body is empty; presumably operator() — TODO confirm, the
// line carrying its name and first parameter is not visible here.
48 template <
typename matA_acc_t,
typename matB_acc_t,
typename matA_t,
51 [[maybe_unused]] matB_acc_t &matB_acc,
52 [[maybe_unused]] matA_t &matA, [[maybe_unused]] matB_t &matB) {}
/// GEMM pre-processing functor that filters negative values out of matA,
/// Xe specialization: the `std::enable_if_t` condition below selects it only
/// when `arch_tag == gpu_arch::Xe`. The visible body works on `matA_acc.reg`
/// only; `matB_acc`, `matA` and `matB` are marked `[[maybe_unused]]`.
///
/// NOTE(review): this text looks like a Doxygen source listing — the leading
/// numbers ("56", "58", ...) are listing line numbers, not C++ tokens. The
/// numbering skips values (56 -> 58, 63 -> 68, 73 -> 76, 80 -> 82, ...), so
/// the class-head line, the constructor head, the operator() head, the three
/// `mask` declaration lines and every closing brace appear to be missing from
/// this view — restore them from the real header before compiling.
56template <
typename tile_shape_, gpu_arch arch_tag>
58 std::enable_if_t<(arch_tag == gpu_arch::Xe)>> {
// Re-export the tile shape and take the work-group type from it.
59 using tile_shape = tile_shape_;
60 using work_group_t =
typename tile_shape::work_group_t;
// Empty argument bundle: this functor takes no runtime parameters.
63 struct arguments_t {};
// Tail of a parameter list with an empty body; presumably the
// (work_group_t &, arguments_t &) constructor — TODO confirm, its head is
// not visible here.
68 [[maybe_unused]] arguments_t &args) {}
// init(): both parameters are unused and the body is empty.
70 inline void init([[maybe_unused]] work_group_t &g,
71 [[maybe_unused]] arguments_t &args) {}
// Templated member taking matA_acc / matB_acc / matA / matB; presumably
// operator() — TODO confirm, the line carrying its name and the matA_acc
// parameter is not visible here.
73 template <
typename matA_acc_t,
typename matB_acc_t,
typename matA_t,
76 [[maybe_unused]] matB_acc_t &matB_acc,
77 [[maybe_unused]] matA_t &matA, [[maybe_unused]] matB_t &matB) {
// Element type of the matA accumulator; only its size is used below.
79 using data_t =
typename matA_acc_t::dtype;
// The three compile-time branches are identical except for the integer type
// used to view the register: int16_t for 2-byte elements, int8_t for 1-byte,
// int32_t for 4-byte. No branch is visible for any other element size.
//
// Each branch compares the signed-integer view of matA_acc.reg against 0 and
// passes the result, as `mask`, to xetla_merge(0, mask) on that same view.
// NOTE(review): presumably the `< 0` on the integer view is a sign-bit test,
// so it also covers floating-point element types, and the merge writes 0 into
// the lanes where the mask is set (i.e. clamps negatives to zero) — confirm
// against the xetla_format / xetla_merge documentation.
80 if constexpr (
sizeof(data_t) == 2) {
// Right-hand side of the (not visible) `mask` declaration for 2-byte elements.
82 = matA_acc.reg.xetla_format<int16_t>() < 0;
83 matA_acc.reg.xetla_format<int16_t>().
xetla_merge(0, mask);
85 if constexpr (
sizeof(data_t) == 1) {
// Right-hand side of the (not visible) `mask` declaration for 1-byte elements.
87 = matA_acc.reg.xetla_format<int8_t>() < 0;
88 matA_acc.reg.xetla_format<int8_t>().
xetla_merge(0, mask);
90 if constexpr (
sizeof(data_t) == 4) {
// Right-hand side of the (not visible) `mask` declaration for 4-byte elements.
92 = matA_acc.reg.xetla_format<int32_t>() < 0;
93 matA_acc.reg.xetla_format<int32_t>().
xetla_merge(0, mask);
void init(work_group_t &g, arguments_t &args)
Definition pre_processing_xe.hpp:45
pre_processing_default_t(work_group_t &g, arguments_t &args)
Definition pre_processing_xe.hpp:42
pre_processing_default_t()=default
KERNEL_FUNC void operator()(matA_acc_t &matA_acc, matB_acc_t &matB_acc, matA_t &matA, matB_t &matB)
Definition pre_processing_xe.hpp:50
pre_processing_matA_neg_filter_t()=default
pre_processing_matA_neg_filter_t(work_group_t &g, arguments_t &args)
Definition pre_processing_xe.hpp:67
void init(work_group_t &g, arguments_t &args)
Definition pre_processing_xe.hpp:70
KERNEL_FUNC void operator()(matA_acc_t &matA_acc, matB_acc_t &matB_acc, matA_t &matA, matB_t &matB)
Definition pre_processing_xe.hpp:75
#define xetla_merge
xetla merge.
Definition base_ops.hpp:60
__ESIMD_NS::simd_mask< N > xetla_mask
wrapper for xetla_mask.
Definition base_types.hpp:165
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
Gemm default pre_processing functor.
Definition api.hpp:33
Gemm pre_processing functor that applies a relu op to matA.
Definition api.hpp:39