XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ > Struct Template Reference

#include <softmax.hpp>

Classes

struct  arguments_t
 

Public Types

using dtype_in = dtype_in_
 
using dtype_out = dtype_out_
 
using tile_shape = tile_shape_
 
using softmax_tile_desc_t = subgroup::tile_desc_t< SIMD, block_height, SIMD, block_height, reg_layout::tiled >
 
using softmax_load_t = subgroup::tile_t< dtype_in, softmax_tile_desc_t >
 
using softmax_load_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_in, mem_layout::row_major, mem_space_in >, softmax_tile_desc_t, subgroup::msg_type_v< softmax_tile_desc_t, mem_space_in >, gpu_arch::Xe >
 
using softmax_store_t = subgroup::tile_t< dtype_out, softmax_tile_desc_t >
 
using softmax_store_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_out, mem_layout::row_major, mem_space_out >, softmax_tile_desc_t, subgroup::msg_type_v< softmax_tile_desc_t, mem_space_out >, gpu_arch::Xe >
 

Public Member Functions

__XETLA_API KERNEL_FUNC void operator() (sycl::nd_item< 3 > &item, arguments_t *args)
 

Static Public Attributes

static constexpr mem_space mem_space_in = mem_space_in_
 
static constexpr mem_space mem_space_out = mem_space_out_
 
static constexpr uint32_t sg_tile_m = tile_shape::sg_tile_size_y
 
static constexpr uint32_t sg_tile_n = tile_shape::sg_tile_size_x
 
static constexpr uint32_t wg_size_x = tile_shape::wg_size_x
 
static constexpr uint32_t wg_size_y = tile_shape::wg_size_y
 
static constexpr uint32_t wg_tile_m = sg_tile_m * wg_size_y
 
static constexpr uint32_t wg_tile_n = sg_tile_n * wg_size_x
 
static constexpr uint32_t SIMD = SIMD_
 
static constexpr uint32_t thread_num = thread_num_
 
static constexpr uint32_t softmax_size = softmax_size_
 
static constexpr uint32_t block_height = softmax_size / SIMD
 

Member Typedef Documentation

◆ dtype_in

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
using xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::dtype_in = dtype_in_

◆ dtype_out

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
using xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::dtype_out = dtype_out_

◆ softmax_load_payload_t

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
using xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::softmax_load_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_in, mem_layout::row_major, mem_space_in>, softmax_tile_desc_t, subgroup::msg_type_v<softmax_tile_desc_t, mem_space_in>, gpu_arch::Xe>

◆ softmax_load_t

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
using xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::softmax_load_t = subgroup::tile_t<dtype_in, softmax_tile_desc_t>

◆ softmax_store_payload_t

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
using xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::softmax_store_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_out, mem_layout::row_major, mem_space_out>, softmax_tile_desc_t, subgroup::msg_type_v<softmax_tile_desc_t, mem_space_out>, gpu_arch::Xe>

◆ softmax_store_t

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
using xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::softmax_store_t = subgroup::tile_t<dtype_out, softmax_tile_desc_t>

◆ softmax_tile_desc_t

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
using xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::softmax_tile_desc_t = subgroup::tile_desc_t<SIMD, block_height, SIMD, block_height, reg_layout::tiled>

◆ tile_shape

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
using xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::tile_shape = tile_shape_

Member Function Documentation

◆ operator()()

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
__XETLA_API KERNEL_FUNC void xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::operator() ( sycl::nd_item< 3 > &  item,
arguments_t *  args 
)
inline

Member Data Documentation

◆ block_height

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
constexpr uint32_t xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::block_height = softmax_size / SIMD
static constexpr

◆ mem_space_in

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
constexpr mem_space xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::mem_space_in = mem_space_in_
static constexpr

◆ mem_space_out

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
constexpr mem_space xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::mem_space_out = mem_space_out_
static constexpr

◆ sg_tile_m

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
constexpr uint32_t xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::sg_tile_m = tile_shape::sg_tile_size_y
static constexpr

◆ sg_tile_n

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
constexpr uint32_t xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::sg_tile_n = tile_shape::sg_tile_size_x
static constexpr

◆ SIMD

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
constexpr uint32_t xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::SIMD = SIMD_
static constexpr

◆ softmax_size

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
constexpr uint32_t xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::softmax_size = softmax_size_
static constexpr

◆ thread_num

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
constexpr uint32_t xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::thread_num = thread_num_
static constexpr

◆ wg_size_x

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
constexpr uint32_t xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::wg_size_x = tile_shape::wg_size_x
static constexpr

◆ wg_size_y

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
constexpr uint32_t xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::wg_size_y = tile_shape::wg_size_y
static constexpr

◆ wg_tile_m

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
constexpr uint32_t xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::wg_tile_m = sg_tile_m * wg_size_y
static constexpr

◆ wg_tile_n

template<typename dtype_in_ , typename dtype_out_ , typename tile_shape_ , mem_space mem_space_in_, mem_space mem_space_out_, uint32_t SIMD_, uint32_t thread_num_, uint32_t softmax_size_>
constexpr uint32_t xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::wg_tile_n = sg_tile_n * wg_size_x
static constexpr