XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval > Struct Template Reference

#include <kernel_func.hpp>

Public Types

using perf_tuning_knob = perf_tuning_knob_t< sg_tile_k, prefetch_distance, periodic_sync_interval >
 
using compute_attr = group::compute_attr_t< T, T, Act_T >
 
using compute_policy = compute_policy_default_xmx< compute_attr, perf_tuning_knob, gpu_arch::Xe >
 
using mem_desc_a_t = mem_desc_t< T, layout_input, mem_loc_input >
 
using mem_desc_b_t = mem_desc_t< T, layout_weight, mem_loc_weight >
 
using tile_shape = tile_shape_t< wg_tile_n, wg_tile_m, sg_tile_n, sg_tile_m >
 
using gemm_op = gemm_t< compute_policy, tile_shape, mem_desc_a_t, mem_desc_b_t >
 
using work_group_t = typename gemm_op::work_group_t
 
using gemm_arguments = typename gemm_op::arguments_t
 
using matAcc_t = typename gemm_op::matAcc_t
 
using mem_desc_c_t = mem_desc_t< T, layout_out, mem_loc_out >
 
using epilogue_t = epilogue_t< epilogue_policy_default< gpu_arch::Xe >, tile_shape, mem_desc_c_t >
 
using epilogue_args_t = typename epilogue_t::arguments_t
 
using matC_tile_desc_t = tile_desc_t< matAcc_t::tile_size_x, matAcc_t::tile_size_y, matAcc_t::block_size_x, matAcc_t::block_size_y, reg_layout::tiled >
 
using mat_hidden_t = tile_t< T, matC_tile_desc_t >
 
using matC_t = tile_t< T, matC_tile_desc_t >
 
using mat_hidden_payload_t = mem_payload_t< mem_desc_a_t, matC_tile_desc_t, msg_type_v< matC_tile_desc_t, mem_loc_input >, gpu_arch::Xe >
 
using matC_payload_t = mem_payload_t< mem_desc_c_t, matC_tile_desc_t, msg_type::block_2d, gpu_arch::Xe >
 
using sigmoid_t = typename subgroup::sigmoid_op_t
 
using tanh_t = typename subgroup::tanh_op_t
 

Static Public Member Functions

static void call (sycl::nd_item< 3 > &item, fused_config_t< T > *args)
 

Static Public Attributes

static constexpr uint32_t prefetch_distance = 3
 
static constexpr bool is_col_major_a = layout_input == mem_layout::col_major
 
static constexpr bool is_col_major_b = layout_weight == mem_layout::col_major
 

Member Typedef Documentation

◆ compute_attr

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::compute_attr = group::compute_attr_t<T, T, Act_T>

◆ compute_policy

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::compute_policy = compute_policy_default_xmx<compute_attr, perf_tuning_knob, gpu_arch::Xe>

◆ epilogue_args_t

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::epilogue_args_t = typename epilogue_t::arguments_t

◆ epilogue_t

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::epilogue_t = epilogue_t<epilogue_policy_default<gpu_arch::Xe>, tile_shape, mem_desc_c_t>

◆ gemm_arguments

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::gemm_arguments = typename gemm_op::arguments_t

◆ gemm_op

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::gemm_op = gemm_t<compute_policy, tile_shape, mem_desc_a_t, mem_desc_b_t>

◆ mat_hidden_payload_t

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::mat_hidden_payload_t = mem_payload_t<mem_desc_a_t, matC_tile_desc_t, msg_type_v<matC_tile_desc_t, mem_loc_input>, gpu_arch::Xe>

◆ mat_hidden_t

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::mat_hidden_t = tile_t<T, matC_tile_desc_t>

◆ matAcc_t

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::matAcc_t = typename gemm_op::matAcc_t

◆ matC_payload_t

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::matC_payload_t = mem_payload_t<mem_desc_c_t, matC_tile_desc_t, msg_type::block_2d, gpu_arch::Xe>

◆ matC_t

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::matC_t = tile_t<T, matC_tile_desc_t>

◆ matC_tile_desc_t

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::matC_tile_desc_t = tile_desc_t<matAcc_t::tile_size_x, matAcc_t::tile_size_y, matAcc_t::block_size_x, matAcc_t::block_size_y, reg_layout::tiled>

◆ mem_desc_a_t

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::mem_desc_a_t = mem_desc_t<T, layout_input, mem_loc_input>

◆ mem_desc_b_t

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::mem_desc_b_t = mem_desc_t<T, layout_weight, mem_loc_weight>

◆ mem_desc_c_t

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::mem_desc_c_t = mem_desc_t<T, layout_out, mem_loc_out>

◆ perf_tuning_knob

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::perf_tuning_knob = perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>

◆ sigmoid_t

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::sigmoid_t = typename subgroup::sigmoid_op_t

◆ tanh_t

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::tanh_t = typename subgroup::tanh_op_t

◆ tile_shape

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::tile_shape = tile_shape_t<wg_tile_n, wg_tile_m, sg_tile_n, sg_tile_m>

◆ work_group_t

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
using gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::work_group_t = typename gemm_op::work_group_t

Member Function Documentation

◆ call()

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
static void gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::call (sycl::nd_item< 3 > &item, fused_config_t< T > *args)
inline static

Member Data Documentation

◆ is_col_major_a

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
constexpr bool gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::is_col_major_a = layout_input == mem_layout::col_major
static constexpr

◆ is_col_major_b

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
constexpr bool gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::is_col_major_b = layout_weight == mem_layout::col_major
static constexpr

◆ prefetch_distance

template<typename T , typename Act_T , uint32_t wg_tile_m, uint32_t wg_tile_n, uint32_t sg_tile_m, uint32_t sg_tile_n, uint32_t sg_tile_k, mem_layout layout_input = mem_layout::row_major, mem_layout layout_weight = mem_layout::row_major, mem_layout layout_out = mem_layout::row_major, mem_space mem_loc_input = mem_space::global, mem_space mem_loc_weight = mem_space::global, mem_space mem_loc_out = mem_space::global, uint32_t periodic_sync_interval = 0>
constexpr uint32_t gru_layer< T, Act_T, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, layout_input, layout_weight, layout_out, mem_loc_input, mem_loc_weight, mem_loc_out, periodic_sync_interval >::prefetch_distance = 3
static constexpr