#include <dropout_mask_gen.hpp>
Classes | |
| struct | arguments_t |
Public Types | |
| using | dtype_mask = dtype_mask_ |
| using | load_store_attr = typename arch_attr_t< arch_ >::template load_store_attr< msg_type::block_2d > |
| using | mask_out_tile_desc_t = subgroup::tile_desc_t< tile_size_x, tile_size_y, block_size_x, block_size_y, reg_layout::tiled > |
| using | mask_out_tile_t = subgroup::tile_t< dtype_mask, mask_out_tile_desc_t > |
| using | mask_out_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_mask, mem_layout::row_major, mem_space::global >, mask_out_tile_desc_t, (sg_tile_m == 1) ? msg_type::block_1d : msg_type::block_2d, gpu_arch::Xe > |
Public Member Functions | |
| __XETLA_API KERNEL_FUNC void | operator() (arguments_t *args, uint32_t wg_idx, uint32_t wg_idy, uint32_t sg_idx, uint32_t sg_idy, uint32_t linear_idx) |
Static Public Attributes | |
| static constexpr uint32_t | wg_tile_n = wg_tile_n_ |
| static constexpr uint32_t | wg_tile_m = wg_tile_m_ |
| static constexpr uint32_t | sg_tile_n = sg_tile_n_ |
| static constexpr uint32_t | sg_tile_m = sg_tile_m_ |
| static constexpr uint32_t | random_simd = random_simd_ |
| static constexpr uint32_t | wg_size_x = (wg_tile_n + sg_tile_n - 1) / sg_tile_n |
| static constexpr uint32_t | wg_size_y = (wg_tile_m + sg_tile_m - 1) / sg_tile_m |
| static constexpr uint32_t | max_store_width_in_bytes = load_store_attr::max_store_width_in_bytes |
| static constexpr uint32_t | max_store_width_in_elem = max_store_width_in_bytes / sizeof(dtype_mask) |
| static constexpr uint32_t | max_store_height_in_elem = load_store_attr::max_store_height_in_elem |
| static constexpr uint32_t | tile_size_x = sg_tile_n |
| static constexpr uint32_t | tile_size_y = sg_tile_m |
| static constexpr uint32_t | block_size_x |
| static constexpr uint32_t | block_size_y |
| static constexpr uint32_t | tile_size = tile_size_x * tile_size_y |
| dtype_mask_ | |
| wg_tile_n_ | |
| wg_tile_m_ | |
| sg_tile_n_ | |
| sg_tile_m_ | |
| random_simd_ | |
| arch_ |
| using gpu::xetla::group::mask_gen_t< dtype_mask_, wg_tile_n_, wg_tile_m_, sg_tile_n_, sg_tile_m_, random_simd_, arch_ >::dtype_mask = dtype_mask_ |
| using gpu::xetla::group::mask_gen_t< dtype_mask_, wg_tile_n_, wg_tile_m_, sg_tile_n_, sg_tile_m_, random_simd_, arch_ >::load_store_attr = typename arch_attr_t< arch_>::template load_store_attr<msg_type::block_2d> |
| using gpu::xetla::group::mask_gen_t< dtype_mask_, wg_tile_n_, wg_tile_m_, sg_tile_n_, sg_tile_m_, random_simd_, arch_ >::mask_out_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_mask, mem_layout::row_major, mem_space::global>, mask_out_tile_desc_t, (sg_tile_m == 1) ? msg_type::block_1d : msg_type::block_2d, gpu_arch::Xe> |
| using gpu::xetla::group::mask_gen_t< dtype_mask_, wg_tile_n_, wg_tile_m_, sg_tile_n_, sg_tile_m_, random_simd_, arch_ >::mask_out_tile_desc_t = subgroup::tile_desc_t<tile_size_x, tile_size_y, block_size_x, block_size_y, reg_layout::tiled> |
| using gpu::xetla::group::mask_gen_t< dtype_mask_, wg_tile_n_, wg_tile_m_, sg_tile_n_, sg_tile_m_, random_simd_, arch_ >::mask_out_tile_t = subgroup::tile_t<dtype_mask, mask_out_tile_desc_t> |
|
inline |
| args | |
| wg_idx | |
| wg_idy | |
| sg_idx | |
| sg_idy | |
| linear_idx |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |