#include <kernel_func.hpp>
Public Types | |
| using | perf_tuning_knob = perf_tuning_knob_t< sg_tile_k, prefetch_distance, periodic_sync_interval > |
| using | compute_attr = group::compute_attr_t< T, T, Act_T > |
| using | compute_policy = compute_policy_default_xmx< compute_attr, perf_tuning_knob, gpu_arch::Xe > |
| using | mem_desc_a_t = mem_desc_t< T, layout_input, mem_loc_input > |
| using | mem_desc_b_t = mem_desc_t< T, layout_weight, mem_loc_weight > |
| using | tile_shape = tile_shape_t< wg_tile_n, wg_tile_m, sg_tile_n, sg_tile_m > |
| using | gemm_op = gemm_t< compute_policy, tile_shape, mem_desc_a_t, mem_desc_b_t > |
| using | work_group_t = typename gemm_op::work_group_t |
| using | gemm_arguments = typename gemm_op::arguments_t |
| using | matAcc_t = typename gemm_op::matAcc_t |
| using | mem_desc_c_t = mem_desc_t< T, layout_out, mem_loc_out > |
| using | epilogue_t = epilogue_t< epilogue_policy_default< gpu_arch::Xe >, tile_shape, mem_desc_c_t > |
| using | epilogue_args_t = typename epilogue_t::arguments_t |
| using | matC_tile_desc_t = tile_desc_t< matAcc_t::tile_size_x, matAcc_t::tile_size_y, matAcc_t::block_size_x, matAcc_t::block_size_y, reg_layout::tiled > |
| using | mat_hidden_t = tile_t< T, matC_tile_desc_t > |
| using | matC_t = tile_t< T, matC_tile_desc_t > |
| using | mat_hidden_payload_t = mem_payload_t< mem_desc_a_t, matC_tile_desc_t, msg_type_v< matC_tile_desc_t, mem_loc_input >, gpu_arch::Xe > |
| using | matC_payload_t = mem_payload_t< mem_desc_c_t, matC_tile_desc_t, msg_type::block_2d, gpu_arch::Xe > |
| using | sigmoid_t = typename subgroup::sigmoid_op_t |
| using | tanh_t = typename subgroup::tanh_op_t |
Static Public Member Functions | |
| static void | call (sycl::nd_item< 3 > &item, fused_config_t< T > *args) |
Static Public Attributes | |
| static constexpr uint32_t | prefetch_distance = 3 |
| static constexpr bool | is_col_major_a = layout_input == mem_layout::col_major |
| static constexpr bool | is_col_major_b = layout_weight == mem_layout::col_major |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::compute_attr = group::compute_attr_t<T, T, Act_T> |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::compute_policy = compute_policy_default_xmx<compute_attr, perf_tuning_knob, gpu_arch::Xe> |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::epilogue_args_t = typename epilogue_t::arguments_t |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::epilogue_t = epilogue_t<epilogue_policy_default<gpu_arch::Xe>, tile_shape, mem_desc_c_t> |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::gemm_arguments = typename gemm_op::arguments_t |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::gemm_op = gemm_t<compute_policy, tile_shape, mem_desc_a_t, mem_desc_b_t> |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::mat_hidden_payload_t = mem_payload_t<mem_desc_a_t, matC_tile_desc_t, msg_type_v<matC_tile_desc_t, mem_loc_input>, gpu_arch::Xe> |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::mat_hidden_t = tile_t<T, matC_tile_desc_t> |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::matAcc_t = typename gemm_op::matAcc_t |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::matC_payload_t = mem_payload_t<mem_desc_c_t, matC_tile_desc_t, msg_type::block_2d, gpu_arch::Xe> |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::matC_t = tile_t<T, matC_tile_desc_t> |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::matC_tile_desc_t = tile_desc_t<matAcc_t::tile_size_x, matAcc_t::tile_size_y, matAcc_t::block_size_x, matAcc_t::block_size_y, reg_layout::tiled> |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::mem_desc_a_t = mem_desc_t<T, layout_input, mem_loc_input> |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::mem_desc_b_t = mem_desc_t<T, layout_weight, mem_loc_weight> |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::mem_desc_c_t = mem_desc_t<T, layout_out, mem_loc_out> |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::perf_tuning_knob = perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval> |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::sigmoid_t = typename subgroup::sigmoid_op_t |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::tanh_t = typename subgroup::tanh_op_t |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::tile_shape = tile_shape_t<wg_tile_n, wg_tile_m, sg_tile_n, sg_tile_m> |
| using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::work_group_t = typename gemm_op::work_group_t |
|
inline static |
|
static constexpr |
|
static constexpr |
|
static constexpr |