#include <softmax.hpp>
Classes | |
| struct | arguments_t |
Public Types | |
| using | dtype_in = dtype_in_ |
| using | dtype_out = dtype_out_ |
| using | tile_shape = tile_shape_ |
| using | softmax_tile_desc_t = subgroup::tile_desc_t< SIMD, block_height, SIMD, block_height, reg_layout::tiled > |
| using | softmax_load_t = subgroup::tile_t< dtype_in, softmax_tile_desc_t > |
| using | softmax_load_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_in, mem_layout::row_major, mem_space_in >, softmax_tile_desc_t, subgroup::msg_type_v< softmax_tile_desc_t, mem_space_in >, gpu_arch::Xe > |
| using | softmax_store_t = subgroup::tile_t< dtype_out, softmax_tile_desc_t > |
| using | softmax_store_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_out, mem_layout::row_major, mem_space_out >, softmax_tile_desc_t, subgroup::msg_type_v< softmax_tile_desc_t, mem_space_out >, gpu_arch::Xe > |
Public Member Functions | |
| __XETLA_API KERNEL_FUNC void | operator() (sycl::nd_item< 3 > &item, arguments_t *args) |
Static Public Attributes | |
| static constexpr mem_space | mem_space_in = mem_space_in_ |
| static constexpr mem_space | mem_space_out = mem_space_out_ |
| static constexpr uint32_t | sg_tile_m = tile_shape::sg_tile_size_y |
| static constexpr uint32_t | sg_tile_n = tile_shape::sg_tile_size_x |
| static constexpr uint32_t | wg_size_x = tile_shape::wg_size_x |
| static constexpr uint32_t | wg_size_y = tile_shape::wg_size_y |
| static constexpr uint32_t | wg_tile_m = sg_tile_m * wg_size_y |
| static constexpr uint32_t | wg_tile_n = sg_tile_n * wg_size_x |
| static constexpr uint32_t | SIMD = SIMD_ |
| static constexpr uint32_t | thread_num = thread_num_ |
| static constexpr uint32_t | softmax_size = softmax_size_ |
| static constexpr uint32_t | block_height = softmax_size / SIMD |
| using xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::dtype_in = dtype_in_ |
| using xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::dtype_out = dtype_out_ |
| using xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::softmax_load_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_in, mem_layout::row_major, mem_space_in>, softmax_tile_desc_t, subgroup::msg_type_v<softmax_tile_desc_t, mem_space_in>, gpu_arch::Xe> |
| using xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::softmax_load_t = subgroup::tile_t<dtype_in, softmax_tile_desc_t> |
| using xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::softmax_store_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_out, mem_layout::row_major, mem_space_out>, softmax_tile_desc_t, subgroup::msg_type_v<softmax_tile_desc_t, mem_space_out>, gpu_arch::Xe> |
| using xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::softmax_store_t = subgroup::tile_t<dtype_out, softmax_tile_desc_t> |
| using xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::softmax_tile_desc_t = subgroup::tile_desc_t<SIMD, block_height, SIMD, block_height, reg_layout::tiled> |
| using xetla_softmax_fwd_t< dtype_in_, dtype_out_, tile_shape_, mem_space_in_, mem_space_out_, SIMD_, thread_num_, softmax_size_ >::tile_shape = tile_shape_ |
|
inline |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |