XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
kernel_xcoder_gru_fusion< input_T, Act_T, wg_tile_m_t, wg_tile_n_t, sg_tile_m_t, sg_tile_n_t, sg_tile_k_t > Struct Template Reference

#include <kernel_func.hpp>

Static Public Member Functions

static void run (sycl::nd_item< 3 > &item, input_T *layer_ptr, input_T *h0_ptr, input_T *W_ir_ptr, input_T *W_hr_ptr, input_T *W_iz_ptr, input_T *W_hz_ptr, input_T *W_in_ptr, input_T *W_hn_ptr, input_T *layer_out_ptr, input_T *hidden_out_ptr, input_T *ping_pong_buffer, input_T *ping_pong_cell, int batch_size, int input_size, int hidden_size, int sequence_length, int layer_size)
 

Member Function Documentation

◆ run()

template<typename input_T , typename Act_T , uint32_t wg_tile_m_t, uint32_t wg_tile_n_t, uint32_t sg_tile_m_t, uint32_t sg_tile_n_t, uint32_t sg_tile_k_t>
static void kernel_xcoder_gru_fusion< input_T, Act_T, wg_tile_m_t, wg_tile_n_t, sg_tile_m_t, sg_tile_n_t, sg_tile_k_t >::run ( sycl::nd_item< 3 > &  item,
input_T *  layer_ptr,
input_T *  h0_ptr,
input_T *  W_ir_ptr,
input_T *  W_hr_ptr,
input_T *  W_iz_ptr,
input_T *  W_hz_ptr,
input_T *  W_in_ptr,
input_T *  W_hn_ptr,
input_T *  layer_out_ptr,
input_T *  hidden_out_ptr,
input_T *  ping_pong_buffer,
input_T *  ping_pong_cell,
int  batch_size,
int  input_size,
int  hidden_size,
int  sequence_length,
int  layer_size 
)
inline static
Parameters
item: the sycl::nd_item.
layer_ptr: input from the previous layer, i.e. $X_t$.
h0_ptr: the hx_ptr input, i.e. $h_0$; shape = layer_size × batch_size × hidden_size.
Weights:
W_ir_ptr: input weights of the reset gate; shape = (input_weight_size, hidden_weight_size, ...).
W_hr_ptr: hidden-state weights of the reset gate; shape = layer_size × hidden_weight_size.
W_iz_ptr: input weights of the update (z) gate; shape = (input_weight_size, hidden_weight_size, ...).
W_hz_ptr: hidden-state weights of the update (z) gate; shape = layer_size × hidden_weight_size.
W_in_ptr: input weights of the new gate; shape = (input_weight_size, hidden_weight_size, ...).
W_hn_ptr: hidden-state weights of the new gate; shape = layer_size × hidden_weight_size.
Outputs:
layer_out_ptr: output of the last cell of each layer; shape = layer_size × batch_size × hidden_size.
hidden_out_ptr: output of the last layer for each GRU cell (time step); shape = sequence_length × batch_size × hidden_size.