XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen > Struct Template Reference

#include <mha_attn_reg.hpp>

Classes

struct  arguments_t
 Arguments for xetla_mha_attn_reg_fwd_t::call. More...
 

Public Types

using dtype_bin = dtype_bin_
 
using dtype_bot = dtype_bot_
 
using dtype_sfx = dtype_sfx_
 
using dtype_acc = dtype_acc_
 
using bgm_perf_tuning_knob = group::perf_tuning_knob_t< k_stride, prefetch_distance, periodic_sync_interval >
 
using tile_attr_128x128 = group::tile_shape_t< 128, 128, 32, 16 >
 
using tile_attr_128x256 = group::tile_shape_t< 256, 128, 64, 16 >
 
using tile_attr_64x384 = group::tile_shape_t< 384, 64, 48, 16 >
 
using tile_attr_64x512 = group::tile_shape_t< 512, 64, 64, 16 >
 
using tile_attr_32x1024 = group::tile_shape_t< 1024, 32, 64, 16 >
 
using tile_attr_16x2048 = group::tile_shape_t< 2048, 16, 64, 16 >
 
using tile_attr_128x64 = group::tile_shape_t< 64, 128, 16, 16 >
 
using mem_desc_a_QKT = mem_desc_t< dtype_bin, gemm_mem_layout_a, gemm_mem_space_a >
 
using mem_desc_b_QKT = mem_desc_t< dtype_bin, gemm_mem_layout_QKT_b, gemm_mem_space_b >
 
using compute_policy_QKT = group::compute_policy_default_xmx< group::compute_attr_t< dtype_bin, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe >
 
using mem_desc_a_out = mem_desc_t< dtype_sfx, gemm_mem_layout_a, gemm_mem_space_a >
 
using mem_desc_b_out = mem_desc_t< dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b >
 
using compute_policy_out = group::compute_policy_default_xmx< group::compute_attr_t< dtype_sfx, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe >
 
using work_group_t = work_group_t< ThreadNum >
 
using pre_processing_128x128 = group::pre_processing_default_t< tile_attr_128x128, gpu_arch::Xe >
 
using pre_processing_128x256 = group::pre_processing_default_t< tile_attr_128x256, gpu_arch::Xe >
 
using pre_processing_64x384 = group::pre_processing_default_t< tile_attr_64x384, gpu_arch::Xe >
 
using pre_processing_64x512 = group::pre_processing_default_t< tile_attr_64x512, gpu_arch::Xe >
 
using pre_processing_32x1024 = group::pre_processing_default_t< tile_attr_32x1024, gpu_arch::Xe >
 
using pre_processing_16x2048 = group::pre_processing_default_t< tile_attr_16x2048, gpu_arch::Xe >
 
using pre_processing_128x64 = group::pre_processing_matA_neg_filter_t< tile_attr_128x64, gpu_arch::Xe >
 
using gemm_op_128x128_t = group::gemm_t< compute_policy_QKT, tile_attr_128x128, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_128x128 >
 
using gemm_op_128x256_t = group::gemm_t< compute_policy_QKT, tile_attr_128x256, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_128x256 >
 
using gemm_op_64x384_t = group::gemm_t< compute_policy_QKT, tile_attr_64x384, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_64x384 >
 
using gemm_op_64x512_t = group::gemm_t< compute_policy_QKT, tile_attr_64x512, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_64x512 >
 
using gemm_op_32x1024_t = group::gemm_t< compute_policy_QKT, tile_attr_32x1024, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_32x1024 >
 
using gemm_op_16x2048_t = group::gemm_t< compute_policy_QKT, tile_attr_16x2048, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_16x2048 >
 
using gemm_op_128x64_t = group::gemm_t< compute_policy_out, tile_attr_128x64, mem_desc_a_out, mem_desc_b_out, pre_processing_128x64 >
 
using gemm_arguments_128x128 = typename gemm_op_128x128_t::arguments_t
 
using gemm_arguments_128x256 = typename gemm_op_128x256_t::arguments_t
 
using gemm_arguments_64x384 = typename gemm_op_64x384_t::arguments_t
 
using gemm_arguments_64x512 = typename gemm_op_64x512_t::arguments_t
 
using gemm_arguments_32x1024 = typename gemm_op_32x1024_t::arguments_t
 
using gemm_arguments_16x2048 = typename gemm_op_16x2048_t::arguments_t
 
using gemm_arguments_128x64 = typename gemm_op_128x64_t::arguments_t
 
using matAcc_128x128_t = typename gemm_op_128x128_t::matAcc_t
 
using matAcc_128x256_t = typename gemm_op_128x256_t::matAcc_t
 
using matAcc_64x384_t = typename gemm_op_64x384_t::matAcc_t
 
using matAcc_64x512_t = typename gemm_op_64x512_t::matAcc_t
 
using matAcc_32x1024_t = typename gemm_op_32x1024_t::matAcc_t
 
using matAcc_16x2048_t = typename gemm_op_16x2048_t::matAcc_t
 
using matAcc_128x64_t = typename gemm_op_128x64_t::matAcc_t
 
using mat_128x128_tile_desc_t = subgroup::tile_desc_t< matAcc_128x128_t::tile_desc::tile_size_x, matAcc_128x128_t::tile_desc::tile_size_y, matAcc_128x128_t::tile_desc::block_size_x, matAcc_128x128_t::tile_desc::block_size_y, reg_layout::tiled >
 
using mat_128x256_tile_desc_t = subgroup::tile_desc_t< matAcc_128x256_t::tile_desc::tile_size_x, matAcc_128x256_t::tile_desc::tile_size_y, matAcc_128x256_t::tile_desc::block_size_x, matAcc_128x256_t::tile_desc::block_size_y, reg_layout::tiled >
 
using mat_64x384_tile_desc_t = subgroup::tile_desc_t< matAcc_64x384_t::tile_desc::tile_size_x, matAcc_64x384_t::tile_desc::tile_size_y, matAcc_64x384_t::tile_desc::block_size_x, matAcc_64x384_t::tile_desc::block_size_y, reg_layout::tiled >
 
using mat_64x512_tile_desc_t = subgroup::tile_desc_t< matAcc_64x512_t::tile_desc::tile_size_x, matAcc_64x512_t::tile_desc::tile_size_y, matAcc_64x512_t::tile_desc::block_size_x, matAcc_64x512_t::tile_desc::block_size_y, reg_layout::tiled >
 
using mat_32x1024_tile_desc_t = subgroup::tile_desc_t< matAcc_32x1024_t::tile_desc::tile_size_x, matAcc_32x1024_t::tile_desc::tile_size_y, matAcc_32x1024_t::tile_desc::block_size_x, matAcc_32x1024_t::tile_desc::block_size_y, reg_layout::tiled >
 
using mat_16x2048_tile_desc_t = subgroup::tile_desc_t< matAcc_16x2048_t::tile_desc::tile_size_x, matAcc_16x2048_t::tile_desc::tile_size_y, matAcc_16x2048_t::tile_desc::block_size_x, matAcc_16x2048_t::tile_desc::block_size_y, reg_layout::tiled >
 
using mat_128x64_tile_desc_t = subgroup::tile_desc_t< matAcc_128x64_t::tile_desc::tile_size_x, matAcc_128x64_t::tile_desc::tile_size_y, matAcc_128x64_t::tile_desc::block_size_x, matAcc_128x64_t::tile_desc::block_size_y, reg_layout::tiled >
 
using matC_128x128_t = subgroup::tile_t< dtype_sfx, mat_128x128_tile_desc_t >
 
using matC_128x256_t = subgroup::tile_t< dtype_sfx, mat_128x256_tile_desc_t >
 
using matC_64x384_t = subgroup::tile_t< dtype_sfx, mat_64x384_tile_desc_t >
 
using matC_64x512_t = subgroup::tile_t< dtype_sfx, mat_64x512_tile_desc_t >
 
using matC_32x1024_t = subgroup::tile_t< dtype_sfx, mat_32x1024_tile_desc_t >
 
using matC_16x2048_t = subgroup::tile_t< dtype_sfx, mat_16x2048_tile_desc_t >
 
using matC_128x64_t = subgroup::tile_t< dtype_sfx, mat_128x64_tile_desc_t >
 
using matC_128x128_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, mat_128x128_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< mat_128x128_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matC_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, mat_128x256_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< mat_128x256_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matC_64x384_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, mat_64x384_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< mat_64x384_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matC_64x512_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, mat_64x512_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< mat_64x512_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matC_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, mat_32x1024_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< mat_32x1024_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matC_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, mat_16x2048_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< mat_16x2048_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matC_128x64_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, mat_128x64_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< mat_128x64_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matDpotMk_128x128_t = subgroup::tile_t< uint8_t, mat_128x128_tile_desc_t >
 
using matDpotMk_128x256_t = subgroup::tile_t< uint8_t, mat_128x256_tile_desc_t >
 
using matDpotMk_64x384_t = subgroup::tile_t< uint8_t, mat_64x384_tile_desc_t >
 
using matDpotMk_64x512_t = subgroup::tile_t< uint8_t, mat_64x512_tile_desc_t >
 
using matDpotMk_32x1024_t = subgroup::tile_t< uint8_t, mat_32x1024_tile_desc_t >
 
using matDpotMk_16x2048_t = subgroup::tile_t< uint8_t, mat_16x2048_tile_desc_t >
 
using matDpotMk_128x64_t = subgroup::tile_t< uint8_t, mat_128x64_tile_desc_t >
 
using matDpotMk_128x128_payload_t = subgroup::mem_payload_t< mem_desc_t< uint8_t, mem_layout_c, mem_space_c >, mat_128x128_tile_desc_t, subgroup::msg_type_v< mat_128x128_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matDpotMk_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t< uint8_t, mem_layout_c, mem_space_c >, mat_128x256_tile_desc_t, subgroup::msg_type_v< mat_128x256_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matDpotMk_64x384_payload_t = subgroup::mem_payload_t< mem_desc_t< uint8_t, mem_layout_c, mem_space_c >, mat_64x384_tile_desc_t, subgroup::msg_type_v< mat_64x384_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matDpotMk_64x512_payload_t = subgroup::mem_payload_t< mem_desc_t< uint8_t, mem_layout_c, mem_space_c >, mat_64x512_tile_desc_t, subgroup::msg_type_v< mat_64x512_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matDpotMk_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_t< uint8_t, mem_layout_c, mem_space_c >, mat_32x1024_tile_desc_t, subgroup::msg_type_v< mat_32x1024_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matDpotMk_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_t< uint8_t, mem_layout_c, mem_space_c >, mat_16x2048_tile_desc_t, subgroup::msg_type_v< mat_16x2048_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matDpotMk_128x64_payload_t = subgroup::mem_payload_t< mem_desc_t< uint8_t, mem_layout_c, mem_space_c >, mat_128x64_tile_desc_t, subgroup::msg_type_v< mat_128x64_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 

Static Public Member Functions

static __XETLA_API void call (sycl::nd_item< 3 > &item, arguments_t *args)
 Main execution function for fused MHA softmax. The basic process is GEMM -> Softmax -> GEMM.
 

Static Public Attributes

static constexpr int ThreadNum = HWThreadNum
 
static constexpr int max_seqlen = Max_SeqLen
 
static constexpr mem_space mem_space_a = mem_space::global
 
static constexpr mem_space mem_space_b = mem_space::global
 
static constexpr mem_space mem_space_c = mem_space::global
 
static constexpr uint16_t Rand_SIMD = RandSIMD
 
static constexpr mem_layout mem_layout_a = mem_layout::row_major
 
static constexpr mem_layout mem_layout_QKT_b = mem_layout::col_major
 
static constexpr mem_layout mem_layout_out_b = mem_layout::row_major
 
static constexpr mem_layout mem_layout_c = mem_layout::row_major
 
static constexpr mem_space gemm_mem_space_a = mem_space_a
 
static constexpr mem_layout gemm_mem_layout_a = mem_layout_a
 
static constexpr mem_space gemm_mem_space_b = mem_space_b
 
static constexpr mem_layout gemm_mem_layout_QKT_b = mem_layout_QKT_b
 
static constexpr mem_layout gemm_mem_layout_out_b = mem_layout_out_b
 
static constexpr uint32_t periodic_sync_interval = 0
 
static constexpr uint32_t prefetch_distance = 3
 
static constexpr uint32_t k_stride = 32 / sizeof(dtype_bin)
 
static constexpr uint32_t global_kslicing = 1
 
static constexpr uint16_t sfx_type_size = sizeof(dtype_sfx)
 

Member Typedef Documentation

◆ bgm_perf_tuning_knob

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::bgm_perf_tuning_knob = group::perf_tuning_knob_t<k_stride, prefetch_distance, periodic_sync_interval>

◆ compute_policy_out

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::compute_policy_out = group::compute_policy_default_xmx< group::compute_attr_t<dtype_sfx, dtype_bin, dtype_acc>, bgm_perf_tuning_knob, gpu_arch::Xe>

◆ compute_policy_QKT

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::compute_policy_QKT = group::compute_policy_default_xmx< group::compute_attr_t<dtype_bin, dtype_bin, dtype_acc>, bgm_perf_tuning_knob, gpu_arch::Xe>

◆ dtype_acc

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::dtype_acc = dtype_acc_

◆ dtype_bin

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::dtype_bin = dtype_bin_

◆ dtype_bot

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::dtype_bot = dtype_bot_

◆ dtype_sfx

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::dtype_sfx = dtype_sfx_

◆ gemm_arguments_128x128

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_arguments_128x128 = typename gemm_op_128x128_t::arguments_t

◆ gemm_arguments_128x256

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_arguments_128x256 = typename gemm_op_128x256_t::arguments_t

◆ gemm_arguments_128x64

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_arguments_128x64 = typename gemm_op_128x64_t::arguments_t

◆ gemm_arguments_16x2048

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_arguments_16x2048 = typename gemm_op_16x2048_t::arguments_t

◆ gemm_arguments_32x1024

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_arguments_32x1024 = typename gemm_op_32x1024_t::arguments_t

◆ gemm_arguments_64x384

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_arguments_64x384 = typename gemm_op_64x384_t::arguments_t

◆ gemm_arguments_64x512

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_arguments_64x512 = typename gemm_op_64x512_t::arguments_t

◆ gemm_op_128x128_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_op_128x128_t = group::gemm_t<compute_policy_QKT, tile_attr_128x128, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_128x128>

◆ gemm_op_128x256_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_op_128x256_t = group::gemm_t<compute_policy_QKT, tile_attr_128x256, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_128x256>

◆ gemm_op_128x64_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_op_128x64_t = group::gemm_t<compute_policy_out, tile_attr_128x64, mem_desc_a_out, mem_desc_b_out, pre_processing_128x64>

◆ gemm_op_16x2048_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_op_16x2048_t = group::gemm_t<compute_policy_QKT, tile_attr_16x2048, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_16x2048>

◆ gemm_op_32x1024_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_op_32x1024_t = group::gemm_t<compute_policy_QKT, tile_attr_32x1024, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_32x1024>

◆ gemm_op_64x384_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_op_64x384_t = group::gemm_t<compute_policy_QKT, tile_attr_64x384, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_64x384>

◆ gemm_op_64x512_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_op_64x512_t = group::gemm_t<compute_policy_QKT, tile_attr_64x512, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_64x512>

◆ mat_128x128_tile_desc_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mat_128x128_tile_desc_t = subgroup::tile_desc_t<matAcc_128x128_t::tile_desc::tile_size_x, matAcc_128x128_t::tile_desc::tile_size_y, matAcc_128x128_t::tile_desc::block_size_x, matAcc_128x128_t::tile_desc::block_size_y, reg_layout::tiled>

◆ mat_128x256_tile_desc_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mat_128x256_tile_desc_t = subgroup::tile_desc_t<matAcc_128x256_t::tile_desc::tile_size_x, matAcc_128x256_t::tile_desc::tile_size_y, matAcc_128x256_t::tile_desc::block_size_x, matAcc_128x256_t::tile_desc::block_size_y, reg_layout::tiled>

◆ mat_128x64_tile_desc_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mat_128x64_tile_desc_t = subgroup::tile_desc_t<matAcc_128x64_t::tile_desc::tile_size_x, matAcc_128x64_t::tile_desc::tile_size_y, matAcc_128x64_t::tile_desc::block_size_x, matAcc_128x64_t::tile_desc::block_size_y, reg_layout::tiled>

◆ mat_16x2048_tile_desc_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mat_16x2048_tile_desc_t = subgroup::tile_desc_t<matAcc_16x2048_t::tile_desc::tile_size_x, matAcc_16x2048_t::tile_desc::tile_size_y, matAcc_16x2048_t::tile_desc::block_size_x, matAcc_16x2048_t::tile_desc::block_size_y, reg_layout::tiled>

◆ mat_32x1024_tile_desc_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mat_32x1024_tile_desc_t = subgroup::tile_desc_t<matAcc_32x1024_t::tile_desc::tile_size_x, matAcc_32x1024_t::tile_desc::tile_size_y, matAcc_32x1024_t::tile_desc::block_size_x, matAcc_32x1024_t::tile_desc::block_size_y, reg_layout::tiled>

◆ mat_64x384_tile_desc_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mat_64x384_tile_desc_t = subgroup::tile_desc_t<matAcc_64x384_t::tile_desc::tile_size_x, matAcc_64x384_t::tile_desc::tile_size_y, matAcc_64x384_t::tile_desc::block_size_x, matAcc_64x384_t::tile_desc::block_size_y, reg_layout::tiled>

◆ mat_64x512_tile_desc_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mat_64x512_tile_desc_t = subgroup::tile_desc_t<matAcc_64x512_t::tile_desc::tile_size_x, matAcc_64x512_t::tile_desc::tile_size_y, matAcc_64x512_t::tile_desc::block_size_x, matAcc_64x512_t::tile_desc::block_size_y, reg_layout::tiled>

◆ matAcc_128x128_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matAcc_128x128_t = typename gemm_op_128x128_t::matAcc_t

◆ matAcc_128x256_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matAcc_128x256_t = typename gemm_op_128x256_t::matAcc_t

◆ matAcc_128x64_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matAcc_128x64_t = typename gemm_op_128x64_t::matAcc_t

◆ matAcc_16x2048_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matAcc_16x2048_t = typename gemm_op_16x2048_t::matAcc_t

◆ matAcc_32x1024_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matAcc_32x1024_t = typename gemm_op_32x1024_t::matAcc_t

◆ matAcc_64x384_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matAcc_64x384_t = typename gemm_op_64x384_t::matAcc_t

◆ matAcc_64x512_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matAcc_64x512_t = typename gemm_op_64x512_t::matAcc_t

◆ matC_128x128_payload_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_128x128_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, mat_128x128_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v< mat_128x128_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_128x128_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_128x128_t = subgroup::tile_t<dtype_sfx, mat_128x128_tile_desc_t>

◆ matC_128x256_payload_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, mat_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v< mat_128x256_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_128x256_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_128x256_t = subgroup::tile_t<dtype_sfx, mat_128x256_tile_desc_t>

◆ matC_128x64_payload_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_128x64_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, mat_128x64_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<mat_128x64_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_128x64_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_128x64_t = subgroup::tile_t<dtype_sfx, mat_128x64_tile_desc_t>

◆ matC_16x2048_payload_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, mat_16x2048_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v< mat_16x2048_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_16x2048_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_16x2048_t = subgroup::tile_t<dtype_sfx, mat_16x2048_tile_desc_t>

◆ matC_32x1024_payload_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, mat_32x1024_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v< mat_32x1024_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_32x1024_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_32x1024_t = subgroup::tile_t<dtype_sfx, mat_32x1024_tile_desc_t>

◆ matC_64x384_payload_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_64x384_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, mat_64x384_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<mat_64x384_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_64x384_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_64x384_t = subgroup::tile_t<dtype_sfx, mat_64x384_tile_desc_t>

◆ matC_64x512_payload_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_64x512_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, mat_64x512_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<mat_64x512_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_64x512_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_64x512_t = subgroup::tile_t<dtype_sfx, mat_64x512_tile_desc_t>

◆ matDpotMk_128x128_payload_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_128x128_payload_t = subgroup::mem_payload_t< mem_desc_t<uint8_t, mem_layout_c, mem_space_c>, mat_128x128_tile_desc_t, subgroup::msg_type_v<mat_128x128_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matDpotMk_128x128_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_128x128_t = subgroup::tile_t<uint8_t, mat_128x128_tile_desc_t>

◆ matDpotMk_128x256_payload_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t<uint8_t, mem_layout_c, mem_space_c>, mat_128x256_tile_desc_t, subgroup::msg_type_v<mat_128x256_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matDpotMk_128x256_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_128x256_t = subgroup::tile_t<uint8_t, mat_128x256_tile_desc_t>

◆ matDpotMk_128x64_payload_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_128x64_payload_t = subgroup::mem_payload_t< mem_desc_t<uint8_t, mem_layout_c, mem_space_c>, mat_128x64_tile_desc_t, subgroup::msg_type_v<mat_128x64_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matDpotMk_128x64_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_128x64_t = subgroup::tile_t<uint8_t, mat_128x64_tile_desc_t>

◆ matDpotMk_16x2048_payload_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_t<uint8_t, mem_layout_c, mem_space_c>, mat_16x2048_tile_desc_t, subgroup::msg_type_v<mat_16x2048_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matDpotMk_16x2048_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_16x2048_t = subgroup::tile_t<uint8_t, mat_16x2048_tile_desc_t>

◆ matDpotMk_32x1024_payload_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_t<uint8_t, mem_layout_c, mem_space_c>, mat_32x1024_tile_desc_t, subgroup::msg_type_v<mat_32x1024_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matDpotMk_32x1024_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_32x1024_t = subgroup::tile_t<uint8_t, mat_32x1024_tile_desc_t>

◆ matDpotMk_64x384_payload_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_64x384_payload_t = subgroup::mem_payload_t< mem_desc_t<uint8_t, mem_layout_c, mem_space_c>, mat_64x384_tile_desc_t, subgroup::msg_type_v<mat_64x384_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matDpotMk_64x384_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_64x384_t = subgroup::tile_t<uint8_t, mat_64x384_tile_desc_t>

◆ matDpotMk_64x512_payload_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_64x512_payload_t = subgroup::mem_payload_t< mem_desc_t<uint8_t, mem_layout_c, mem_space_c>, mat_64x512_tile_desc_t, subgroup::msg_type_v<mat_64x512_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matDpotMk_64x512_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_64x512_t = subgroup::tile_t<uint8_t, mat_64x512_tile_desc_t>

◆ mem_desc_a_out

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mem_desc_a_out = mem_desc_t<dtype_sfx, gemm_mem_layout_a, gemm_mem_space_a>

◆ mem_desc_a_QKT

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mem_desc_a_QKT = mem_desc_t<dtype_bin, gemm_mem_layout_a, gemm_mem_space_a>

◆ mem_desc_b_out

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mem_desc_b_out = mem_desc_t<dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b>

◆ mem_desc_b_QKT

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mem_desc_b_QKT = mem_desc_t<dtype_bin, gemm_mem_layout_QKT_b, gemm_mem_space_b>

◆ pre_processing_128x128

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::pre_processing_128x128 = group::pre_processing_default_t<tile_attr_128x128, gpu_arch::Xe>

◆ pre_processing_128x256

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::pre_processing_128x256 = group::pre_processing_default_t<tile_attr_128x256, gpu_arch::Xe>

◆ pre_processing_128x64

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::pre_processing_128x64 = group::pre_processing_matA_neg_filter_t<tile_attr_128x64, gpu_arch::Xe>

◆ pre_processing_16x2048

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::pre_processing_16x2048 = group::pre_processing_default_t<tile_attr_16x2048, gpu_arch::Xe>

◆ pre_processing_32x1024

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::pre_processing_32x1024 = group::pre_processing_default_t<tile_attr_32x1024, gpu_arch::Xe>

◆ pre_processing_64x384

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::pre_processing_64x384 = group::pre_processing_default_t<tile_attr_64x384, gpu_arch::Xe>

◆ pre_processing_64x512

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::pre_processing_64x512 = group::pre_processing_default_t<tile_attr_64x512, gpu_arch::Xe>

◆ tile_attr_128x128

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::tile_attr_128x128 = group::tile_shape_t<128, 128, 32, 16>

◆ tile_attr_128x256

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::tile_attr_128x256 = group::tile_shape_t<256, 128, 64, 16>

◆ tile_attr_128x64

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::tile_attr_128x64 = group::tile_shape_t<64, 128, 16, 16>

◆ tile_attr_16x2048

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::tile_attr_16x2048 = group::tile_shape_t<2048, 16, 64, 16>

◆ tile_attr_32x1024

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::tile_attr_32x1024 = group::tile_shape_t<1024, 32, 64, 16>

◆ tile_attr_64x384

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::tile_attr_64x384 = group::tile_shape_t<384, 64, 48, 16>

◆ tile_attr_64x512

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::tile_attr_64x512 = group::tile_shape_t<512, 64, 64, 16>

◆ work_group_t

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::work_group_t = work_group_t<ThreadNum>

Member Function Documentation

◆ call()

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
static __XETLA_API void gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::call ( sycl::nd_item< 3 > &  item,
arguments_t args 
)
inline static

Main execution function for fused MHA softmax. The basic process is GEMM -> Softmax -> GEMM.

Parameters
[in] args: Includes base descriptors and tid info.

Member Data Documentation

◆ gemm_mem_layout_a

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr mem_layout gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_mem_layout_a = mem_layout_a
static constexpr

◆ gemm_mem_layout_out_b

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr mem_layout gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_mem_layout_out_b = mem_layout_out_b
static constexpr

◆ gemm_mem_layout_QKT_b

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr mem_layout gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_mem_layout_QKT_b = mem_layout_QKT_b
static constexpr

◆ gemm_mem_space_a

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr mem_space gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_mem_space_a = mem_space_a
static constexpr

◆ gemm_mem_space_b

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr mem_space gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_mem_space_b = mem_space_b
static constexpr

◆ global_kslicing

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr uint32_t gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::global_kslicing = 1
static constexpr

◆ k_stride

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr uint32_t gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::k_stride = 32 / sizeof(dtype_bin)
static constexpr

◆ max_seqlen

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr int gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::max_seqlen = Max_SeqLen
static constexpr

◆ mem_layout_a

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr mem_layout gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mem_layout_a = mem_layout::row_major
static constexpr

◆ mem_layout_c

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr mem_layout gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mem_layout_c = mem_layout::row_major
static constexpr

◆ mem_layout_out_b

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr mem_layout gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mem_layout_out_b = mem_layout::row_major
static constexpr

◆ mem_layout_QKT_b

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr mem_layout gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mem_layout_QKT_b = mem_layout::col_major
static constexpr

◆ mem_space_a

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr mem_space gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mem_space_a = mem_space::global
static constexpr

◆ mem_space_b

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr mem_space gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mem_space_b = mem_space::global
static constexpr

◆ mem_space_c

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr mem_space gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mem_space_c = mem_space::global
static constexpr

◆ periodic_sync_interval

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr uint32_t gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::periodic_sync_interval = 0
static constexpr

◆ prefetch_distance

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr uint32_t gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::prefetch_distance = 3
static constexpr

◆ Rand_SIMD

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr uint16_t gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::Rand_SIMD = RandSIMD
static constexpr

◆ sfx_type_size

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr uint16_t gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::sfx_type_size = sizeof(dtype_sfx)
static constexpr

◆ ThreadNum

template<typename dtype_bin_ , typename dtype_bot_ , typename dtype_sfx_ , typename dtype_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, uint16_t RandSIMD = 16, int Max_SeqLen = 2048>
constexpr int gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::ThreadNum = HWThreadNum
static constexpr