XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
gemm_polynomial.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#pragma once
18
19#include "xetla.hpp"
20
21namespace gpu::xetla::subgroup {
22
23template <typename dtype_, uint32_t N>
25 using dtype = dtype_;
27
// Argument bundle consumed by operator(): it carries the coefficient vector
// that operator() reads as args.coeff[i].
// NOTE(review): the member declaration itself (original line 29) is absent
// from this rendered listing; the cross-reference index at the end of the
// page gives it as `coeff_t coeff`, with coeff_t = xetla_vector<dtype, N>.
28 struct arguments_t {
// Default construction leaves initialization of coeff to coeff_t's own
// default behavior.
30 inline arguments_t() = default;
// Copies the supplied coefficient vector into the coeff member.
31 inline arguments_t(coeff_t coeff_) : coeff(coeff_) {}
32 };
// Applies an element-wise update to the accumulator tile matAcc.reg using the
// N coefficients in args.coeff.
//
// The tile is processed in chunks of `elems` (8 * 16 = 128) elements:
// `rounds` full chunks first, then one trailing chunk of
// `tile_elems % elems` elements when the tile size is not a multiple of 128.
// For every chunk the inner loop performs, for i = 0 .. N-1,
//     res = x * res + coeff[i]
// and the chunk is then assigned from res.
//
// NOTE(review): the declarations of `res` (original lines 44 and 55) are
// missing from this rendered listing. If res starts at zero, the loop is a
// Horner evaluation of a degree N-1 polynomial in x with coeff[0] as the
// highest-order coefficient -- confirm against the real source file.
//
// Parameters:
//   matAcc        - accumulator tile; its `reg` member is read and assigned.
//   coord         - not referenced in the visible body (marked maybe_unused).
//   args          - supplies the coefficient vector args.coeff.
//   slm_base      - not referenced in the visible body (marked maybe_unused).
//   nbarrier_base - not referenced in the visible body (marked maybe_unused).
33 template <typename matAcc_t, typename coord_t>
34 __XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc,
35 [[maybe_unused]] coord_t coord, arguments_t args,
36 [[maybe_unused]] uint32_t slm_base = 0,
37 [[maybe_unused]] uint32_t nbarrier_base = 0) {
// Coefficients are cast to the accumulator's element type before being added.
38 using dtype_acc = typename matAcc_t::dtype;
39 // total flag register
// Chunk width used for each pass over the tile, and the number of full chunks.
40 constexpr uint32_t elems = 8 * 16;
41 constexpr uint32_t rounds = matAcc_t::tile_elems / elems;
42#pragma unroll
43 for (uint32_t r = 0; r < rounds; ++r) {
// x selects the r-th run of `elems` consecutive (stride 1) elements of the tile.
// NOTE(review): presumably xetla_select returns a reference-like view, so the
// later `x = res` stores back into matAcc.reg -- confirm.
45 auto x = matAcc.reg.xetla_select<elems, 1>(elems * r);
46#pragma unroll
47 for (uint32_t i = 0; i < N; ++i) {
// One multiply-then-add step per coefficient.
48 res = x * res;
49 res += static_cast<dtype_acc>(args.coeff[i]);
50 }
51 x = res;
52 }
// Tail handling: elements left over after the full 128-element chunks.
53 constexpr uint32_t remained_elems = matAcc_t::tile_elems % elems;
54 if constexpr (remained_elems != 0) {
// The tail starts right after the last full chunk, at offset elems * rounds.
56 auto x = matAcc.reg.xetla_select<remained_elems, 1>(
57 elems * (matAcc_t::tile_elems / elems));
58#pragma unroll
59 for (uint32_t i = 0; i < N; ++i) {
// Same per-coefficient update as in the full-chunk loop above.
60 res = x * res;
61 res += static_cast<dtype_acc>(args.coeff[i]);
62 }
63 x = res;
64 }
65 }
66};
67
68} // namespace gpu::xetla::subgroup
#define __XETLA_API
Definition common.hpp:43
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:457
coeff_t coeff
Definition gemm_polynomial.hpp:29
arguments_t(coeff_t coeff_)
Definition gemm_polynomial.hpp:31
Definition gemm_polynomial.hpp:24
dtype_ dtype
Definition gemm_polynomial.hpp:25
__XETLA_API KERNEL_FUNC void operator()(matAcc_t &matAcc, coord_t coord, arguments_t args, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Definition gemm_polynomial.hpp:34
xetla_vector< dtype, N > coeff_t
Definition gemm_polynomial.hpp:26
C++ API.