XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
softmax_fwd_xe.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
22#include "group/reduction/reduction.hpp"
23#include "group/softmax/api.hpp"
26
27namespace gpu::xetla::group {
28
29template <typename dtype_acc_, typename tile_shape_>
30class softmax_t<softmax_policy_fwd<dtype_acc_, gpu_arch::Xe>, tile_shape_> {
31
32public:
33 using tile_shape = tile_shape_;
34 using dtype_acc = dtype_acc_;
35 static constexpr gpu_arch arch_tag = gpu_arch::Xe;
36 struct arguments_t {
38 inline arguments_t() = default;
39 inline arguments_t(dtype_acc sqrt_dk_inv_)
40 : sqrt_dk_inv(sqrt_dk_inv_) {}
41 };
42
43private:
44 using coord_t = mem_coord_t<2>;
45 using work_group_t = typename tile_shape::work_group_t;
46 static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
47 static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x;
48 static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
49 static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
50
51 using wg_reduce_max_t = group_reduce_t<dtype_acc, 1, sg_tile_m,
52 reduce_op::max, wg_size_x, true, gpu_arch::Xe>;
53 using wg_reduce_sum_t = group_reduce_t<dtype_acc, 1, sg_tile_m,
54 reduce_op::sum, wg_size_x, true, gpu_arch::Xe>;
55
56public:
57 struct get_barrier_count {
58 static constexpr uint32_t count = (wg_size_x > 1) ? wg_size_y : 0;
59 };
60
61 struct get_slm_size {
62 static constexpr uint32_t size = (wg_size_x > 1)
63 ? wg_size_y * wg_size_x * sg_tile_m * sizeof(dtype_acc)
64 : 0;
65 };
66
67 template <typename matAcc_t>
68 __XETLA_API KERNEL_FUNC void operator()(work_group_t &g, matAcc_t &matAcc,
69 [[maybe_unused]] [[maybe_unused]] [[maybe_unused]] [[maybe_unused]] [[maybe_unused]] [[maybe_unused]] [[maybe_unused]] coord_t
70 coord,
71 const arguments_t &args, uint32_t slm_base = 0,
72 uint32_t nbarrier_base = 0) {
73 static_assert(std::is_same<typename matAcc_t::dtype, dtype_acc>::value,
74 "matAcc dtype_acc should match with dtype_acc");
75 int32_t sg_idx = g.get_id() % wg_size_x;
76 int32_t sg_idy = g.get_id() / wg_size_x;
77 uint32_t nbarrier_id = nbarrier_base + sg_idy;
78 uint32_t slm_base_addr
79 = slm_base + sg_idy * wg_size_x * sg_tile_m * sizeof(dtype_acc);
84 1>(matAcc);
85 wg_reduce_max_t wg_reduce_max(sg_idx, nbarrier_id, slm_base_addr);
86 xetla_vector<dtype_acc, sg_tile_m> group_max = wg_reduce_max(local_max);
87 if constexpr (wg_size_x > 1) { nbarrier.arrive(); }
88 subgroup::tile_broadcast_op<subgroup::tile_minus, matAcc_t>(
89 matAcc, group_max);
90 matAcc.reg = matAcc.reg * args.sqrt_dk_inv;
91 matAcc.reg = xetla_exp<dtype_acc>(matAcc.reg);
94 1>(matAcc);
95 wg_reduce_sum_t wg_reduce_sum(sg_idx, nbarrier_id, slm_base_addr);
96 if constexpr (wg_size_x > 1) { nbarrier.wait(); }
97 xetla_vector<dtype_acc, sg_tile_m> group_sum = wg_reduce_sum(local_sum);
98 subgroup::tile_broadcast_op<subgroup::tile_div, matAcc_t>(
99 matAcc, group_sum);
100 }
101};
102
103} // namespace gpu::xetla::group
__XETLA_API KERNEL_FUNC void operator()(work_group_t &g, matAcc_t &matAcc, coord_t coord, const arguments_t &args, uint32_t slm_base=0, uint32_t nbarrier_base=0)
Definition softmax_fwd_xe.hpp:68
Definition api.hpp:27
#define __XETLA_API
Definition common.hpp:43
C++ API.
C++ API.
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
__XETLA_API std::enable_if_t<(dim==1), xetla_vector< dtype_out, mat_t::tile_size_y > > tile_reduce(mat_t &src)
Definition reduction.hpp:33
gpu_arch
Definition common.hpp:73
This is the group reduction.
Definition reduction_api.hpp:36
Definition softmax_policy.hpp:27
Definition memory_descriptor.hpp:30
Definition memory_descriptor.hpp:28
Define a workgroup scope for a specific problem shape.
Definition work_group.hpp:34
xetla nbarrier definition API.
Definition raw_send_nbarrier.hpp:43
__XETLA_API void init_nbarrier(uint8_t nbarrier_id, nbarrier_role role=nbarrier_role::producer_consumer)
Definition raw_send_nbarrier.hpp:55