XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
tile_broadcast_op.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
22#include "subgroup/subgroup.hpp"
23
24namespace gpu::xetla::subgroup {
25
26struct tile_minus {
27 template <typename dtype, int vec_len>
29 xetla_vector<dtype, vec_len> vec_data, dtype data) {
30 return vec_data - data;
31 }
32};
33
34struct tile_div {
35 template <typename dtype, int vec_len>
37 xetla_vector<dtype, vec_len> vec_data, dtype data) {
38 return vec_data / data;
39 }
40};
41
42template <typename op, typename matAcc_t>
43inline void tile_broadcast_op(matAcc_t &matAcc,
45 static constexpr uint32_t tile_size_y = matAcc_t::tile_size_y;
46 static constexpr uint32_t tile_size_x = matAcc_t::tile_size_x;
47 static constexpr uint32_t block_size_y = matAcc_t::block_size_y;
48 static constexpr uint32_t block_size_x = matAcc_t::block_size_x;
49 static constexpr uint32_t block_elems = matAcc_t::block_elems;
50 static constexpr int32_t num_block_x = matAcc_t::num_block_x;
51 using dtype = typename matAcc_t::dtype;
52#pragma unroll
53 for (uint32_t i = 0; i < tile_size_y / block_size_y; i++) {
54#pragma unroll
55 for (uint32_t j = 0; j < num_block_x; j++) {
56 auto acc_reg = (matAcc.reg)
57 .xetla_select<block_elems, 1>(
58 (i * num_block_x + j) * block_elems);
59 auto acc_reg_2d
60 = acc_reg.xetla_format<dtype, block_size_y, block_size_x>();
61#pragma unroll
62 for (uint32_t row_i = 0; row_i < block_size_y; row_i++) {
63 acc_reg_2d.row(row_i) = op::template func<dtype, block_size_x>(
64 acc_reg_2d.row(row_i), data[block_size_y * i + row_i]);
65 }
66 }
67 }
68 // process the tail
69 if constexpr ((tile_size_y % block_size_y) != 0) {
70 constexpr uint32_t tail_start_y
71 = tile_size_y / block_size_y * block_size_y;
72 constexpr uint32_t tail_size_y = tile_size_y % block_size_y;
73 constexpr uint32_t tail_block_elems = tail_size_y * block_size_x;
74#pragma unroll
75 for (uint32_t j = 0; j < num_block_x; j++) {
76 auto acc_reg = (matAcc.reg)
77 .xetla_select<tail_block_elems, 1>(
78 tail_start_y * tile_size_x
79 + j * tail_block_elems);
80 auto acc_reg_2d
81 = acc_reg.xetla_format<dtype, tail_size_y, block_size_x>();
82#pragma unroll
83 for (uint32_t row_i = 0; row_i < tail_size_y; row_i++) {
84 acc_reg_2d.row(row_i) = op::template func<dtype, block_size_x>(
85 acc_reg_2d.row(row_i), data[tail_start_y + row_i]);
86 }
87 }
88 }
89}
90
91} // namespace gpu::xetla::subgroup
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
Definition limitation.hpp:457
void tile_broadcast_op(matAcc_t &matAcc, xetla_vector< typename matAcc_t::dtype, matAcc_t::tile_size_y > data)
Definition tile_broadcast_op.hpp:43
Definition tile_broadcast_op.hpp:34
static xetla_vector< dtype, vec_len > func(xetla_vector< dtype, vec_len > vec_data, dtype data)
Definition tile_broadcast_op.hpp:36
Definition tile_broadcast_op.hpp:26
static xetla_vector< dtype, vec_len > func(xetla_vector< dtype, vec_len > vec_data, dtype data)
Definition tile_broadcast_op.hpp:28