XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
softmax.hpp
Go to the documentation of this file.
1/*******************************************************************************
2 * Copyright (c) 2022-2023 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17#pragma once
18
19#include "xetla.hpp"
20
21using namespace gpu::xetla;
22using namespace gpu::xetla::group;
23using namespace gpu::xetla::subgroup;
24
25template <typename dtype_in_, typename dtype_out_, typename tile_shape_, // row-wise softmax: each row of softmax_size elements is max-shifted, exponentiated and normalised
26 mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_,
27 uint32_t thread_num_, uint32_t softmax_size_>
29 using dtype_in = dtype_in_;
30 using dtype_out = dtype_out_;
31 using tile_shape = tile_shape_;
32
33 static constexpr mem_space mem_space_in = mem_space_in_;
34 static constexpr mem_space mem_space_out = mem_space_out_;
35
36 static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y;
37 static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x;
38 static constexpr uint32_t wg_size_x = tile_shape::wg_size_x;
39 static constexpr uint32_t wg_size_y = tile_shape::wg_size_y;
40 static constexpr uint32_t wg_tile_m = sg_tile_m * wg_size_y; // number of softmax rows handled by the work-group
41 static constexpr uint32_t wg_tile_n = sg_tile_n * wg_size_x;
42
43 static constexpr uint32_t SIMD = SIMD_;
44 static constexpr uint32_t thread_num = thread_num_;
45 static constexpr uint32_t softmax_size = softmax_size_;
46 static constexpr uint32_t block_height = softmax_size / SIMD; // NOTE(review): integer division; assumes softmax_size % SIMD == 0
47
48 // each tile loads one softmax row (softmax_size elements) from the input surface
49 // the row is viewed as a 2D surface so it can be fetched with a 2D-block load
50 // SIMD is the tile width
51 // block_height is the tile height
52 // SIMD * block_height == softmax_size, the number of elements in one row
54 block_height, reg_layout::tiled>;
59 subgroup::msg_type_v<softmax_tile_desc_t, mem_space_in>,
60 gpu_arch::Xe>;
61
62 // this tile stores the softmax result (dtype_out) to mem_space_out; NOTE(review): its payload is initialised from data_out_base (the SLM output base) -- confirm the intended destination
67 subgroup::msg_type_v<softmax_tile_desc_t, mem_space_out>,
68 gpu_arch::Xe>;
69
70 struct arguments_t {
71 // valid when the input data resides in SLM
72 uint32_t data_in_base;
73 // valid when the output data is written to SLM
74 uint32_t data_out_base;
75 // valid when the input data resides in global memory
77 // valid when the output data is written to global memory
79 };
80
82 sycl::nd_item<3> &item, arguments_t *args) {
83
84 softmax_load_t softmax_load;
85 softmax_load_payload_t softmax_load_payload;
86 softmax_store_t softmax_store;
87 softmax_store_payload_t softmax_store_payload;
88
89 uint32_t local_offset_y = block_height * item.get_local_linear_id(); // thread#i starts at logical row #i (each logical row spans block_height surface rows)
90
91 // read original data from SLM (payload base = args->data_in_base):
92 // each softmax row is viewed as block_height x SIMD, e.g. 1 * 512 -> 32 * 16
93 // each thread handles inner_loop_count rows, e.g. two rows when wg_tile_m == 2 * thread_num:
94 // thread#i loads row#i, then row#(i + thread_num), ... (NOTE(review): assumes the y stride passed to update_tdesc below is thread_num rows)
95
96 softmax_load_payload.init(args->data_in_base, SIMD,
97 block_height * wg_tile_m, SIMD, 0, local_offset_y);
98 softmax_store_payload.init(args->data_out_base, SIMD,
99 block_height * wg_tile_m, SIMD, 0, local_offset_y);
100
102
103 uint32_t inner_loop_count = (wg_tile_m % thread_num == 0) // presumably ceil(wg_tile_m / thread_num); NOTE(review): if not divisible, some threads' last iteration may address rows >= wg_tile_m
105 : (wg_tile_m / thread_num) + 1;
106
107 for (uint32_t row = 0; row < inner_loop_count; ++row) {
108 subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
109 softmax_load, softmax_load_payload);
110 softmax_load_payload.template update_tdesc<tdesc_update_dir::y_dir>( // advance the load payload to this thread's next row
112
113 row_data_32 = softmax_load.reg.xetla_select<softmax_size, 1>(0); // copy the loaded row into the float working vector
114
115 // row maximum, subtracted below for numerical stability
116 float xmax = hmax<float, float, softmax_size>(row_data_32);
117
118 // exponentiate the shifted row and sum it
119 row_data_32 -= xmax;
120 row_data_32 = exp(row_data_32);
121 float exp_sum = sum<float, float, softmax_size>(row_data_32);
122
123 // normalise: softmax(x)_j = exp(x_j - max) / sum_k exp(x_k - max)
124 row_data_32 /= exp_sum;
125
126 softmax_store.reg.xetla_select<softmax_size, 1>(0) // convert back to dtype_out for the store tile
127 = xetla_cvt<dtype_out, float, softmax_size>(row_data_32);
128
129 tile_store(softmax_store, softmax_store_payload);
130 softmax_store_payload
131 .template update_tdesc<tdesc_update_dir::y_dir>( // advance the store payload to this thread's next row
133 }
134 } // operator()
135}; // struct xetla_softmax_fwd_t
#define __XETLA_API
Definition common.hpp:43
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
wrapper for xetla_vector.
Definition base_types.hpp:149
#define KERNEL_FUNC
KERNEL_FUNC macro.
Definition common.hpp:39
Definition limitation.hpp:607
Definition limitation.hpp:457
__XETLA_API std::enable_if_t< detail::check_store_type< tile_t, payload_t >::is_global_2d_xe > tile_store(tile_t &tile, payload_t &payload)
Stores data from the register file to global memory.
Definition store_xe.hpp:91
Definition arch_config.hpp:24
mem_space
Definition common.hpp:77
Definition memory_descriptor.hpp:139
Describes the memory information.
Definition api.hpp:44
Describes the tile information of a sub-matrix.
Definition api.hpp:64
A struct that holds register-file storage.
Definition api.hpp:99
xetla_vector< dtype, tile_desc::tile_elems > reg
Definition api.hpp:102
Definition softmax.hpp:70
uint32_t data_in_base
Definition softmax.hpp:72
dtype_in * data_in_ptr
Definition softmax.hpp:76
dtype_out * data_out_ptr
Definition softmax.hpp:78
uint32_t data_out_base
Definition softmax.hpp:74
Definition softmax.hpp:28
dtype_out_ dtype_out
Definition softmax.hpp:30
static constexpr uint32_t wg_size_x
Definition softmax.hpp:38
dtype_in_ dtype_in
Definition softmax.hpp:29
tile_shape_ tile_shape
Definition softmax.hpp:31
static constexpr uint32_t block_height
Definition softmax.hpp:46
static constexpr mem_space mem_space_out
Definition softmax.hpp:34
subgroup::tile_desc_t< SIMD, block_height, SIMD, block_height, reg_layout::tiled > softmax_tile_desc_t
Definition softmax.hpp:54
static constexpr uint32_t sg_tile_n
Definition softmax.hpp:37
static constexpr uint32_t wg_tile_m
Definition softmax.hpp:40
static constexpr uint32_t SIMD
Definition softmax.hpp:43
static constexpr uint32_t thread_num
Definition softmax.hpp:44
__XETLA_API KERNEL_FUNC void operator()(sycl::nd_item< 3 > &item, arguments_t *args)
Definition softmax.hpp:81
static constexpr uint32_t sg_tile_m
Definition softmax.hpp:36
static constexpr uint32_t wg_size_y
Definition softmax.hpp:39
static constexpr uint32_t softmax_size
Definition softmax.hpp:45
static constexpr mem_space mem_space_in
Definition softmax.hpp:33
static constexpr uint32_t wg_tile_n
Definition softmax.hpp:41
C++ API.