XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen > Struct Template Reference

#include <mha_attn_reg.hpp>

Classes

struct  arguments_t
 Arguments for xetla_mha_attn_reg_bwd_t::call. More...
 

Public Types

using dtype_bin = dtype_bwd_bin_
 
using dtype_bot = dtype_bwd_bot_
 
using dtype_sfx = dtype_bwd_sfx_
 
using dtype_acc = dtype_bwd_acc_
 
using bgm_perf_tuning_knob = group::perf_tuning_knob_t< k_stride, prefetch_distance, periodic_sync_interval >
 
using tile_attr_128x128 = group::tile_shape_t< 128, 128, 32, 16 >
 
using tile_attr_128x256 = group::tile_shape_t< 256, 128, 64, 16 >
 
using tile_attr_64x384 = group::tile_shape_t< 384, 64, 48, 16 >
 
using tile_attr_64x512 = group::tile_shape_t< 512, 64, 64, 16 >
 
using tile_attr_32x1024 = group::tile_shape_t< 1024, 32, 64, 16 >
 
using tile_attr_16x2048 = group::tile_shape_t< 2048, 16, 64, 16 >
 
using tile_attr_256x64 = group::tile_shape_t< 64, 256, 16, 32 >
 
using tile_attr_128x64 = group::tile_shape_t< 64, 128, 16, 16 >
 
using mem_desc_a_QKT = mem_desc_t< dtype_bin, gemm_mem_layout_a, gemm_mem_space_a >
 
using mem_desc_b_QKT = mem_desc_t< dtype_bin, gemm_mem_layout_QKT_b, gemm_mem_space_b >
 
using compute_policy_QKT = group::compute_policy_default_xmx< group::compute_attr_t< dtype_bin, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe >
 
using mem_desc_a_out = mem_desc_t< dtype_sfx, gemm_mem_layout_a, gemm_mem_space_a >
 
using mem_desc_b_out = mem_desc_t< dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b >
 
using compute_policy_out = group::compute_policy_default_xmx< group::compute_attr_t< dtype_sfx, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe >
 
using mem_desc_a_out_b_trnp_a = mem_desc_t< dtype_sfx, gemm_mem_layout_trnp_a, gemm_mem_space_trnp_a >
 
using mem_desc_b_out_b_trnp_a = mem_desc_t< dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b >
 
using compute_policy_out_b_trnp_a = group::compute_policy_default_xmx< group::compute_attr_t< dtype_sfx, dtype_bin, dtype_acc >, bgm_perf_tuning_knob, gpu_arch::Xe >
 
using work_group_t = work_group_t< ThreadNum >
 
using pre_processing_128x128 = group::pre_processing_default_t< tile_attr_128x128, gpu_arch::Xe >
 
using pre_processing_128x256 = group::pre_processing_default_t< tile_attr_128x256, gpu_arch::Xe >
 
using pre_processing_64x384 = group::pre_processing_default_t< tile_attr_64x384, gpu_arch::Xe >
 
using pre_processing_64x512 = group::pre_processing_default_t< tile_attr_64x512, gpu_arch::Xe >
 
using pre_processing_32x1024 = group::pre_processing_default_t< tile_attr_32x1024, gpu_arch::Xe >
 
using pre_processing_16x2048 = group::pre_processing_default_t< tile_attr_16x2048, gpu_arch::Xe >
 
using pre_processing_128x64 = group::pre_processing_default_t< tile_attr_128x64, gpu_arch::Xe >
 
using pre_processing_256x64 = group::pre_processing_default_t< tile_attr_256x64, gpu_arch::Xe >
 
using pre_processing_128x64_af = group::pre_processing_matA_neg_filter_t< tile_attr_128x64, gpu_arch::Xe >
 
using pre_processing_256x64_af = group::pre_processing_matA_neg_filter_t< tile_attr_256x64, gpu_arch::Xe >
 
using gemm_op_128x128_t = group::gemm_t< compute_policy_QKT, tile_attr_128x128, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_128x128 >
 
using gemm_op_128x256_t = group::gemm_t< compute_policy_QKT, tile_attr_128x256, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_128x256 >
 
using gemm_op_64x384_t = group::gemm_t< compute_policy_QKT, tile_attr_64x384, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_64x384 >
 
using gemm_op_64x512_t = group::gemm_t< compute_policy_QKT, tile_attr_64x512, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_64x512 >
 
using gemm_op_32x1024_t = group::gemm_t< compute_policy_QKT, tile_attr_32x1024, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_32x1024 >
 
using gemm_op_16x2048_t = group::gemm_t< compute_policy_QKT, tile_attr_16x2048, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_16x2048 >
 
using gemm_op_128x64_t = group::gemm_t< compute_policy_out, tile_attr_128x64, mem_desc_a_out, mem_desc_b_out, pre_processing_128x64 >
 
using gemm_op_128x64_trnp_a_t = group::gemm_t< compute_policy_out_b_trnp_a, tile_attr_128x64, mem_desc_a_out_b_trnp_a, mem_desc_b_out_b_trnp_a, pre_processing_128x64 >
 
using gemm_op_256x64_trnp_a_t = group::gemm_t< compute_policy_out_b_trnp_a, tile_attr_256x64, mem_desc_a_out_b_trnp_a, mem_desc_b_out_b_trnp_a, pre_processing_256x64 >
 
using gemm_op_128x64_trnp_af_t = group::gemm_t< compute_policy_out_b_trnp_a, tile_attr_128x64, mem_desc_a_out_b_trnp_a, mem_desc_b_out_b_trnp_a, pre_processing_128x64_af >
 
using gemm_op_256x64_trnp_af_t = group::gemm_t< compute_policy_out_b_trnp_a, tile_attr_256x64, mem_desc_a_out_b_trnp_a, mem_desc_b_out_b_trnp_a, pre_processing_256x64_af >
 
using gemm_arguments_128x128 = typename gemm_op_128x128_t::arguments_t
 
using gemm_arguments_128x256 = typename gemm_op_128x256_t::arguments_t
 
using gemm_arguments_64x384 = typename gemm_op_64x384_t::arguments_t
 
using gemm_arguments_64x512 = typename gemm_op_64x512_t::arguments_t
 
using gemm_arguments_32x1024 = typename gemm_op_32x1024_t::arguments_t
 
using gemm_arguments_16x2048 = typename gemm_op_16x2048_t::arguments_t
 
using gemm_arguments_128x64 = typename gemm_op_128x64_t::arguments_t
 
using gemm_arguments_128x64_trnp_a = typename gemm_op_128x64_trnp_a_t::arguments_t
 
using gemm_arguments_256x64_trnp_a = typename gemm_op_256x64_trnp_a_t::arguments_t
 
using gemm_arguments_128x64_trnp_af = typename gemm_op_128x64_trnp_af_t::arguments_t
 
using gemm_arguments_256x64_trnp_af = typename gemm_op_256x64_trnp_af_t::arguments_t
 
using matAcc_128x128_t = typename gemm_op_128x128_t::matAcc_t
 
using matAcc_128x256_t = typename gemm_op_128x256_t::matAcc_t
 
using matAcc_64x384_t = typename gemm_op_64x384_t::matAcc_t
 
using matAcc_64x512_t = typename gemm_op_64x512_t::matAcc_t
 
using matAcc_32x1024_t = typename gemm_op_32x1024_t::matAcc_t
 
using matAcc_16x2048_t = typename gemm_op_16x2048_t::matAcc_t
 
using matAcc_128x64_t = typename gemm_op_128x64_t::matAcc_t
 
using matAcc_128x64_trnp_a_t = typename gemm_op_128x64_trnp_a_t::matAcc_t
 
using matAcc_256x64_trnp_a_t = typename gemm_op_256x64_trnp_a_t::matAcc_t
 
using matAcc_128x64_trnp_af_t = typename gemm_op_128x64_trnp_af_t::matAcc_t
 
using matAcc_256x64_trnp_af_t = typename gemm_op_256x64_trnp_af_t::matAcc_t
 
using matC_128x128_tile_desc_t = subgroup::tile_desc_t< matAcc_128x128_t::tile_desc::tile_size_x, matAcc_128x128_t::tile_desc::tile_size_y, matAcc_128x128_t::tile_desc::block_size_x, matAcc_128x128_t::tile_desc::block_size_y, reg_layout::tiled >
 
using matC_128x256_tile_desc_t = subgroup::tile_desc_t< matAcc_128x256_t::tile_desc::tile_size_x, matAcc_128x256_t::tile_desc::tile_size_y, matAcc_128x256_t::tile_desc::block_size_x, matAcc_128x256_t::tile_desc::block_size_y, reg_layout::tiled >
 
using matC_64x384_tile_desc_t = subgroup::tile_desc_t< matAcc_64x384_t::tile_desc::tile_size_x, matAcc_64x384_t::tile_desc::tile_size_y, matAcc_64x384_t::tile_desc::block_size_x, matAcc_64x384_t::tile_desc::block_size_y, reg_layout::tiled >
 
using matC_64x512_tile_desc_t = subgroup::tile_desc_t< matAcc_64x512_t::tile_desc::tile_size_x, matAcc_64x512_t::tile_desc::tile_size_y, matAcc_64x512_t::tile_desc::block_size_x, matAcc_64x512_t::tile_desc::block_size_y, reg_layout::tiled >
 
using matC_32x1024_tile_desc_t = subgroup::tile_desc_t< matAcc_32x1024_t::tile_desc::tile_size_x, matAcc_32x1024_t::tile_desc::tile_size_y, matAcc_32x1024_t::tile_desc::block_size_x, matAcc_32x1024_t::tile_desc::block_size_y, reg_layout::tiled >
 
using matC_16x2048_tile_desc_t = subgroup::tile_desc_t< matAcc_16x2048_t::tile_desc::tile_size_x, matAcc_16x2048_t::tile_desc::tile_size_y, matAcc_16x2048_t::tile_desc::block_size_x, matAcc_16x2048_t::tile_desc::block_size_y, reg_layout::tiled >
 
using matC_128x128_t = subgroup::tile_t< dtype_sfx, matC_128x128_tile_desc_t >
 
using matC_128x256_t = subgroup::tile_t< dtype_sfx, matC_128x256_tile_desc_t >
 
using matC_64x384_t = subgroup::tile_t< dtype_sfx, matC_64x384_tile_desc_t >
 
using matC_64x512_t = subgroup::tile_t< dtype_sfx, matC_64x512_tile_desc_t >
 
using matC_32x1024_t = subgroup::tile_t< dtype_sfx, matC_32x1024_tile_desc_t >
 
using matC_16x2048_t = subgroup::tile_t< dtype_sfx, matC_16x2048_tile_desc_t >
 
using matC_128x128_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, matC_128x128_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< matC_128x128_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matC_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, matC_128x256_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< matC_128x256_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matC_64x384_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, matC_64x384_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< matC_64x384_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matC_64x512_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, matC_64x512_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< matC_64x512_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matC_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, matC_32x1024_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< matC_32x1024_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matC_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, matC_16x2048_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< matC_16x2048_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matC_128x64_tile_desc_t = subgroup::tile_desc_t< matAcc_128x64_t::tile_desc::tile_size_x, matAcc_128x64_t::tile_desc::tile_size_y, matAcc_128x64_t::tile_desc::block_size_x, matAcc_128x64_t::tile_desc::block_size_y, reg_layout::tiled >
 
using matC_128x64_trnp_a_tile_desc_t = subgroup::tile_desc_t< matAcc_128x64_trnp_a_t::tile_desc::tile_size_x, matAcc_128x64_trnp_a_t::tile_desc::tile_size_y, matAcc_128x64_trnp_a_t::tile_desc::block_size_x, matAcc_128x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled >
 
using matC_256x64_trnp_a_tile_desc_t = subgroup::tile_desc_t< matAcc_256x64_trnp_a_t::tile_desc::tile_size_x, matAcc_256x64_trnp_a_t::tile_desc::tile_size_y, matAcc_256x64_trnp_a_t::tile_desc::block_size_x, matAcc_256x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled >
 
using matC_128x64_trnp_af_tile_desc_t = subgroup::tile_desc_t< matAcc_128x64_trnp_af_t::tile_desc::tile_size_x, matAcc_128x64_trnp_af_t::tile_desc::tile_size_y, matAcc_128x64_trnp_af_t::tile_desc::block_size_x, matAcc_128x64_trnp_af_t::tile_desc::block_size_y, reg_layout::tiled >
 
using matC_256x64_trnp_af_tile_desc_t = subgroup::tile_desc_t< matAcc_256x64_trnp_af_t::tile_desc::tile_size_x, matAcc_256x64_trnp_af_t::tile_desc::tile_size_y, matAcc_256x64_trnp_af_t::tile_desc::block_size_x, matAcc_256x64_trnp_af_t::tile_desc::block_size_y, reg_layout::tiled >
 
using matC_128x64_t = subgroup::tile_t< dtype_bot, matC_128x64_tile_desc_t >
 
using matC_128x64_trnp_a_t = subgroup::tile_t< dtype_bot, matC_128x64_trnp_a_tile_desc_t >
 
using matC_256x64_trnp_a_t = subgroup::tile_t< dtype_bot, matC_256x64_trnp_a_tile_desc_t >
 
using matC_128x64_trnp_af_t = subgroup::tile_t< dtype_bot, matC_128x64_trnp_af_tile_desc_t >
 
using matC_256x64_trnp_af_t = subgroup::tile_t< dtype_bot, matC_256x64_trnp_af_tile_desc_t >
 
using matC_128x64_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_bot, mem_layout_c, mem_space_c >, matC_128x64_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< matC_128x64_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matC_128x64_trnp_a_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_bot, mem_layout_c, mem_space_c >, matC_128x64_trnp_a_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< matC_128x64_trnp_a_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matC_256x64_trnp_a_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_bot, mem_layout_c, mem_space_c >, matC_256x64_trnp_a_tile_desc_t, subgroup::msg_type_v< matC_256x64_trnp_a_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matC_128x64_trnp_af_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_bot, mem_layout_c, mem_space_c >, matC_128x64_trnp_af_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< matC_128x64_trnp_af_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matC_256x64_trnp_af_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_bot, mem_layout_c, mem_space_c >, matC_256x64_trnp_af_tile_desc_t,(global_kslicing > 1) ? msg_type::atomic_add :subgroup::msg_type_v< matC_256x64_trnp_af_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matW_128x128_t = subgroup::tile_t< dtype_sfx, matC_128x128_tile_desc_t >
 
using matW_128x256_t = subgroup::tile_t< dtype_sfx, matC_128x256_tile_desc_t >
 
using matW_64x384_t = subgroup::tile_t< dtype_sfx, matC_64x384_tile_desc_t >
 
using matW_64x512_t = subgroup::tile_t< dtype_sfx, matC_64x512_tile_desc_t >
 
using matW_32x1024_t = subgroup::tile_t< dtype_sfx, matC_32x1024_tile_desc_t >
 
using matW_16x2048_t = subgroup::tile_t< dtype_sfx, matC_16x2048_tile_desc_t >
 
using matW_128x128_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, matC_128x128_tile_desc_t, subgroup::msg_type_v< matC_128x128_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matW_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, matC_128x256_tile_desc_t, subgroup::msg_type_v< matC_128x256_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matW_64x384_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, matC_64x384_tile_desc_t, subgroup::msg_type_v< matC_64x384_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matW_64x512_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, matC_64x512_tile_desc_t, subgroup::msg_type_v< matC_64x512_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matW_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, matC_32x1024_tile_desc_t, subgroup::msg_type_v< matC_32x1024_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 
using matW_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_t< dtype_sfx, mem_layout_c, mem_space_c >, matC_16x2048_tile_desc_t, subgroup::msg_type_v< matC_16x2048_tile_desc_t, mem_space_c >, gpu_arch::Xe >
 

Static Public Member Functions

static __XETLA_API void call (sycl::nd_item< 3 > &item, arguments_t *args)
 Main execution function for fused MHA softmax. The basic process is GEMM -> Softmax -> GEMM.
 

Static Public Attributes

static constexpr int ThreadNum = HWThreadNum
 
static constexpr mem_space mem_space_a = mem_space::global
 
static constexpr mem_space mem_space_b = mem_space::global
 
static constexpr mem_space mem_space_c = mem_space::global
 
static constexpr mem_layout mem_layout_a = mem_layout::row_major
 
static constexpr mem_layout mem_layout_trnp_a = mem_layout::col_major
 
static constexpr mem_layout mem_layout_QKT_b = mem_layout::col_major
 
static constexpr mem_layout mem_layout_out_b = mem_layout::row_major
 
static constexpr mem_layout mem_layout_c = mem_layout::row_major
 
static constexpr mem_space gemm_mem_space_a = mem_space_a
 
static constexpr mem_space gemm_mem_space_trnp_a = mem_space_a
 
static constexpr mem_layout gemm_mem_layout_a = mem_layout_a
 
static constexpr mem_layout gemm_mem_layout_trnp_a = mem_layout_trnp_a
 
static constexpr mem_space gemm_mem_space_b = mem_space_b
 
static constexpr mem_layout gemm_mem_layout_QKT_b = mem_layout_QKT_b
 
static constexpr mem_layout gemm_mem_layout_out_b = mem_layout_out_b
 
static constexpr uint32_t periodic_sync_interval = 0
 
static constexpr uint32_t prefetch_distance = 3
 
static constexpr uint32_t k_stride = 32 / sizeof(dtype_bin)
 
static constexpr uint32_t global_kslicing = 1
 
static constexpr uint16_t sfx_type_size = sizeof(dtype_sfx)
 

Member Typedef Documentation

◆ bgm_perf_tuning_knob

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::bgm_perf_tuning_knob = group::perf_tuning_knob_t<k_stride, prefetch_distance, periodic_sync_interval>

◆ compute_policy_out

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::compute_policy_out = group::compute_policy_default_xmx< group::compute_attr_t<dtype_sfx, dtype_bin, dtype_acc>, bgm_perf_tuning_knob, gpu_arch::Xe>

◆ compute_policy_out_b_trnp_a

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::compute_policy_out_b_trnp_a = group::compute_policy_default_xmx< group::compute_attr_t<dtype_sfx, dtype_bin, dtype_acc>, bgm_perf_tuning_knob, gpu_arch::Xe>

◆ compute_policy_QKT

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::compute_policy_QKT = group::compute_policy_default_xmx< group::compute_attr_t<dtype_bin, dtype_bin, dtype_acc>, bgm_perf_tuning_knob, gpu_arch::Xe>

◆ dtype_acc

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::dtype_acc = dtype_bwd_acc_

◆ dtype_bin

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::dtype_bin = dtype_bwd_bin_

◆ dtype_bot

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::dtype_bot = dtype_bwd_bot_

◆ dtype_sfx

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::dtype_sfx = dtype_bwd_sfx_

◆ gemm_arguments_128x128

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_128x128 = typename gemm_op_128x128_t::arguments_t

◆ gemm_arguments_128x256

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_128x256 = typename gemm_op_128x256_t::arguments_t

◆ gemm_arguments_128x64

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_128x64 = typename gemm_op_128x64_t::arguments_t

◆ gemm_arguments_128x64_trnp_a

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_128x64_trnp_a = typename gemm_op_128x64_trnp_a_t::arguments_t

◆ gemm_arguments_128x64_trnp_af

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_128x64_trnp_af = typename gemm_op_128x64_trnp_af_t::arguments_t

◆ gemm_arguments_16x2048

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_16x2048 = typename gemm_op_16x2048_t::arguments_t

◆ gemm_arguments_256x64_trnp_a

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_256x64_trnp_a = typename gemm_op_256x64_trnp_a_t::arguments_t

◆ gemm_arguments_256x64_trnp_af

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_256x64_trnp_af = typename gemm_op_256x64_trnp_af_t::arguments_t

◆ gemm_arguments_32x1024

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_32x1024 = typename gemm_op_32x1024_t::arguments_t

◆ gemm_arguments_64x384

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_64x384 = typename gemm_op_64x384_t::arguments_t

◆ gemm_arguments_64x512

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_64x512 = typename gemm_op_64x512_t::arguments_t

◆ gemm_op_128x128_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_128x128_t = group::gemm_t<compute_policy_QKT, tile_attr_128x128, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_128x128>

◆ gemm_op_128x256_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_128x256_t = group::gemm_t<compute_policy_QKT, tile_attr_128x256, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_128x256>

◆ gemm_op_128x64_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_128x64_t = group::gemm_t<compute_policy_out, tile_attr_128x64, mem_desc_a_out, mem_desc_b_out, pre_processing_128x64>

◆ gemm_op_128x64_trnp_a_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_128x64_trnp_a_t = group::gemm_t<compute_policy_out_b_trnp_a, tile_attr_128x64, mem_desc_a_out_b_trnp_a, mem_desc_b_out_b_trnp_a, pre_processing_128x64>

◆ gemm_op_128x64_trnp_af_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_128x64_trnp_af_t = group::gemm_t<compute_policy_out_b_trnp_a, tile_attr_128x64, mem_desc_a_out_b_trnp_a, mem_desc_b_out_b_trnp_a, pre_processing_128x64_af>

◆ gemm_op_16x2048_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_16x2048_t = group::gemm_t<compute_policy_QKT, tile_attr_16x2048, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_16x2048>

◆ gemm_op_256x64_trnp_a_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_256x64_trnp_a_t = group::gemm_t<compute_policy_out_b_trnp_a, tile_attr_256x64, mem_desc_a_out_b_trnp_a, mem_desc_b_out_b_trnp_a, pre_processing_256x64>

◆ gemm_op_256x64_trnp_af_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_256x64_trnp_af_t = group::gemm_t<compute_policy_out_b_trnp_a, tile_attr_256x64, mem_desc_a_out_b_trnp_a, mem_desc_b_out_b_trnp_a, pre_processing_256x64_af>

◆ gemm_op_32x1024_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_32x1024_t = group::gemm_t<compute_policy_QKT, tile_attr_32x1024, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_32x1024>

◆ gemm_op_64x384_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_64x384_t = group::gemm_t<compute_policy_QKT, tile_attr_64x384, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_64x384>

◆ gemm_op_64x512_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_64x512_t = group::gemm_t<compute_policy_QKT, tile_attr_64x512, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_64x512>

◆ matAcc_128x128_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_128x128_t = typename gemm_op_128x128_t::matAcc_t

◆ matAcc_128x256_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_128x256_t = typename gemm_op_128x256_t::matAcc_t

◆ matAcc_128x64_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_128x64_t = typename gemm_op_128x64_t::matAcc_t

◆ matAcc_128x64_trnp_a_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_128x64_trnp_a_t = typename gemm_op_128x64_trnp_a_t::matAcc_t

◆ matAcc_128x64_trnp_af_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_128x64_trnp_af_t = typename gemm_op_128x64_trnp_af_t::matAcc_t

◆ matAcc_16x2048_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_16x2048_t = typename gemm_op_16x2048_t::matAcc_t

◆ matAcc_256x64_trnp_a_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_256x64_trnp_a_t = typename gemm_op_256x64_trnp_a_t::matAcc_t

◆ matAcc_256x64_trnp_af_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_256x64_trnp_af_t = typename gemm_op_256x64_trnp_af_t::matAcc_t

◆ matAcc_32x1024_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_32x1024_t = typename gemm_op_32x1024_t::matAcc_t

◆ matAcc_64x384_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_64x384_t = typename gemm_op_64x384_t::matAcc_t

◆ matAcc_64x512_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_64x512_t = typename gemm_op_64x512_t::matAcc_t

◆ matC_128x128_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x128_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_128x128_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<matC_128x128_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_128x128_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x128_t = subgroup::tile_t<dtype_sfx, matC_128x128_tile_desc_t>

◆ matC_128x128_tile_desc_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x128_tile_desc_t = subgroup::tile_desc_t<matAcc_128x128_t::tile_desc::tile_size_x, matAcc_128x128_t::tile_desc::tile_size_y, matAcc_128x128_t::tile_desc::block_size_x, matAcc_128x128_t::tile_desc::block_size_y, reg_layout::tiled>

◆ matC_128x256_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<matC_128x256_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_128x256_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x256_t = subgroup::tile_t<dtype_sfx, matC_128x256_tile_desc_t>

◆ matC_128x256_tile_desc_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x256_tile_desc_t = subgroup::tile_desc_t<matAcc_128x256_t::tile_desc::tile_size_x, matAcc_128x256_t::tile_desc::tile_size_y, matAcc_128x256_t::tile_desc::block_size_x, matAcc_128x256_t::tile_desc::block_size_y, reg_layout::tiled>

◆ matC_128x64_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_bot, mem_layout_c, mem_space_c>, matC_128x64_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v< matC_128x64_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_128x64_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_t = subgroup::tile_t<dtype_bot, matC_128x64_tile_desc_t>

◆ matC_128x64_tile_desc_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_tile_desc_t = subgroup::tile_desc_t<matAcc_128x64_t::tile_desc::tile_size_x, matAcc_128x64_t::tile_desc::tile_size_y, matAcc_128x64_t::tile_desc::block_size_x, matAcc_128x64_t::tile_desc::block_size_y, reg_layout::tiled>

◆ matC_128x64_trnp_a_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_trnp_a_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_bot, mem_layout_c, mem_space_c>, matC_128x64_trnp_a_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<matC_128x64_trnp_a_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_128x64_trnp_a_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_trnp_a_t = subgroup::tile_t<dtype_bot, matC_128x64_trnp_a_tile_desc_t>

◆ matC_128x64_trnp_a_tile_desc_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_trnp_a_tile_desc_t = subgroup::tile_desc_t< matAcc_128x64_trnp_a_t::tile_desc::tile_size_x, matAcc_128x64_trnp_a_t::tile_desc::tile_size_y, matAcc_128x64_trnp_a_t::tile_desc::block_size_x, matAcc_128x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled>

◆ matC_128x64_trnp_af_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_trnp_af_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_bot, mem_layout_c, mem_space_c>, matC_128x64_trnp_af_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<matC_128x64_trnp_af_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_128x64_trnp_af_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_trnp_af_t = subgroup::tile_t<dtype_bot, matC_128x64_trnp_af_tile_desc_t>

◆ matC_128x64_trnp_af_tile_desc_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_trnp_af_tile_desc_t = subgroup::tile_desc_t< matAcc_128x64_trnp_af_t::tile_desc::tile_size_x, matAcc_128x64_trnp_af_t::tile_desc::tile_size_y, matAcc_128x64_trnp_af_t::tile_desc::block_size_x, matAcc_128x64_trnp_af_t::tile_desc::block_size_y, reg_layout::tiled>

◆ matC_16x2048_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_16x2048_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<matC_16x2048_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_16x2048_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_16x2048_t = subgroup::tile_t<dtype_sfx, matC_16x2048_tile_desc_t>

◆ matC_16x2048_tile_desc_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_16x2048_tile_desc_t = subgroup::tile_desc_t<matAcc_16x2048_t::tile_desc::tile_size_x, matAcc_16x2048_t::tile_desc::tile_size_y, matAcc_16x2048_t::tile_desc::block_size_x, matAcc_16x2048_t::tile_desc::block_size_y, reg_layout::tiled>

◆ matC_256x64_trnp_a_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_256x64_trnp_a_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_bot, mem_layout_c, mem_space_c>, matC_256x64_trnp_a_tile_desc_t, subgroup::msg_type_v<matC_256x64_trnp_a_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_256x64_trnp_a_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_256x64_trnp_a_t = subgroup::tile_t<dtype_bot, matC_256x64_trnp_a_tile_desc_t>

◆ matC_256x64_trnp_a_tile_desc_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_256x64_trnp_a_tile_desc_t = subgroup::tile_desc_t< matAcc_256x64_trnp_a_t::tile_desc::tile_size_x, matAcc_256x64_trnp_a_t::tile_desc::tile_size_y, matAcc_256x64_trnp_a_t::tile_desc::block_size_x, matAcc_256x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled>

◆ matC_256x64_trnp_af_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_256x64_trnp_af_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_bot, mem_layout_c, mem_space_c>, matC_256x64_trnp_af_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<matC_256x64_trnp_af_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_256x64_trnp_af_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_256x64_trnp_af_t = subgroup::tile_t<dtype_bot, matC_256x64_trnp_af_tile_desc_t>

◆ matC_256x64_trnp_af_tile_desc_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_256x64_trnp_af_tile_desc_t = subgroup::tile_desc_t< matAcc_256x64_trnp_af_t::tile_desc::tile_size_x, matAcc_256x64_trnp_af_t::tile_desc::tile_size_y, matAcc_256x64_trnp_af_t::tile_desc::block_size_x, matAcc_256x64_trnp_af_t::tile_desc::block_size_y, reg_layout::tiled>

◆ matC_32x1024_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_32x1024_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<matC_32x1024_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_32x1024_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_32x1024_t = subgroup::tile_t<dtype_sfx, matC_32x1024_tile_desc_t>

◆ matC_32x1024_tile_desc_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_32x1024_tile_desc_t = subgroup::tile_desc_t<matAcc_32x1024_t::tile_desc::tile_size_x, matAcc_32x1024_t::tile_desc::tile_size_y, matAcc_32x1024_t::tile_desc::block_size_x, matAcc_32x1024_t::tile_desc::block_size_y, reg_layout::tiled>

◆ matC_64x384_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_64x384_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_64x384_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v< matC_64x384_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_64x384_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_64x384_t = subgroup::tile_t<dtype_sfx, matC_64x384_tile_desc_t>

◆ matC_64x384_tile_desc_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_64x384_tile_desc_t = subgroup::tile_desc_t<matAcc_64x384_t::tile_desc::tile_size_x, matAcc_64x384_t::tile_desc::tile_size_y, matAcc_64x384_t::tile_desc::block_size_x, matAcc_64x384_t::tile_desc::block_size_y, reg_layout::tiled>

◆ matC_64x512_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_64x512_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_64x512_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v< matC_64x512_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matC_64x512_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_64x512_t = subgroup::tile_t<dtype_sfx, matC_64x512_tile_desc_t>

◆ matC_64x512_tile_desc_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_64x512_tile_desc_t = subgroup::tile_desc_t<matAcc_64x512_t::tile_desc::tile_size_x, matAcc_64x512_t::tile_desc::tile_size_y, matAcc_64x512_t::tile_desc::block_size_x, matAcc_64x512_t::tile_desc::block_size_y, reg_layout::tiled>

◆ matW_128x128_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_128x128_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_128x128_tile_desc_t, subgroup::msg_type_v<matC_128x128_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matW_128x128_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_128x128_t = subgroup::tile_t<dtype_sfx, matC_128x128_tile_desc_t>

◆ matW_128x256_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_128x256_tile_desc_t, subgroup::msg_type_v<matC_128x256_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matW_128x256_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_128x256_t = subgroup::tile_t<dtype_sfx, matC_128x256_tile_desc_t>

◆ matW_16x2048_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_16x2048_tile_desc_t, subgroup::msg_type_v<matC_16x2048_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matW_16x2048_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_16x2048_t = subgroup::tile_t<dtype_sfx, matC_16x2048_tile_desc_t>

◆ matW_32x1024_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_32x1024_tile_desc_t, subgroup::msg_type_v<matC_32x1024_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matW_32x1024_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_32x1024_t = subgroup::tile_t<dtype_sfx, matC_32x1024_tile_desc_t>

◆ matW_64x384_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_64x384_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_64x384_tile_desc_t, subgroup::msg_type_v<matC_64x384_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matW_64x384_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_64x384_t = subgroup::tile_t<dtype_sfx, matC_64x384_tile_desc_t>

◆ matW_64x512_payload_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_64x512_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_64x512_tile_desc_t, subgroup::msg_type_v<matC_64x512_tile_desc_t, mem_space_c>, gpu_arch::Xe>

◆ matW_64x512_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_64x512_t = subgroup::tile_t<dtype_sfx, matC_64x512_tile_desc_t>

◆ mem_desc_a_out

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_desc_a_out = mem_desc_t<dtype_sfx, gemm_mem_layout_a, gemm_mem_space_a>

◆ mem_desc_a_out_b_trnp_a

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_desc_a_out_b_trnp_a = mem_desc_t<dtype_sfx, gemm_mem_layout_trnp_a, gemm_mem_space_trnp_a>

◆ mem_desc_a_QKT

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_desc_a_QKT = mem_desc_t<dtype_bin, gemm_mem_layout_a, gemm_mem_space_a>

◆ mem_desc_b_out

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_desc_b_out = mem_desc_t<dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b>

◆ mem_desc_b_out_b_trnp_a

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_desc_b_out_b_trnp_a = mem_desc_t<dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b>

◆ mem_desc_b_QKT

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_desc_b_QKT = mem_desc_t<dtype_bin, gemm_mem_layout_QKT_b, gemm_mem_space_b>

◆ pre_processing_128x128

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_128x128 = group::pre_processing_default_t<tile_attr_128x128, gpu_arch::Xe>

◆ pre_processing_128x256

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_128x256 = group::pre_processing_default_t<tile_attr_128x256, gpu_arch::Xe>

◆ pre_processing_128x64

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_128x64 = group::pre_processing_default_t<tile_attr_128x64, gpu_arch::Xe>

◆ pre_processing_128x64_af

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_128x64_af = group::pre_processing_matA_neg_filter_t<tile_attr_128x64, gpu_arch::Xe>

◆ pre_processing_16x2048

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_16x2048 = group::pre_processing_default_t<tile_attr_16x2048, gpu_arch::Xe>

◆ pre_processing_256x64

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_256x64 = group::pre_processing_default_t<tile_attr_256x64, gpu_arch::Xe>

◆ pre_processing_256x64_af

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_256x64_af = group::pre_processing_matA_neg_filter_t<tile_attr_256x64, gpu_arch::Xe>

◆ pre_processing_32x1024

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_32x1024 = group::pre_processing_default_t<tile_attr_32x1024, gpu_arch::Xe>

◆ pre_processing_64x384

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_64x384 = group::pre_processing_default_t<tile_attr_64x384, gpu_arch::Xe>

◆ pre_processing_64x512

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_64x512 = group::pre_processing_default_t<tile_attr_64x512, gpu_arch::Xe>

◆ tile_attr_128x128

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::tile_attr_128x128 = group::tile_shape_t<128, 128, 32, 16>

◆ tile_attr_128x256

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::tile_attr_128x256 = group::tile_shape_t<256, 128, 64, 16>

◆ tile_attr_128x64

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::tile_attr_128x64 = group::tile_shape_t<64, 128, 16, 16>

◆ tile_attr_16x2048

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::tile_attr_16x2048 = group::tile_shape_t<2048, 16, 64, 16>

◆ tile_attr_256x64

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::tile_attr_256x64 = group::tile_shape_t<64, 256, 16, 32>

◆ tile_attr_32x1024

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::tile_attr_32x1024 = group::tile_shape_t<1024, 32, 64, 16>

◆ tile_attr_64x384

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::tile_attr_64x384 = group::tile_shape_t<384, 64, 48, 16>

◆ tile_attr_64x512

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::tile_attr_64x512 = group::tile_shape_t<512, 64, 64, 16>

◆ work_group_t

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::work_group_t = work_group_t<ThreadNum>

Member Function Documentation

◆ call()

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
static __XETLA_API void gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::call ( sycl::nd_item< 3 > &  item,
arguments_t args 
)
inline static

Main execution function for fused MHA softmax. The basic process is GEMM -> Softmax -> GEMM.

Parameters
item  The sycl::nd_item<3> of the calling work-item.
args [in] Includes base descriptors and tid info.

Member Data Documentation

◆ gemm_mem_layout_a

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr mem_layout gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_mem_layout_a = mem_layout_a
static constexpr

◆ gemm_mem_layout_out_b

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr mem_layout gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_mem_layout_out_b = mem_layout_out_b
static constexpr

◆ gemm_mem_layout_QKT_b

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr mem_layout gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_mem_layout_QKT_b = mem_layout_QKT_b
static constexpr

◆ gemm_mem_layout_trnp_a

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr mem_layout gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_mem_layout_trnp_a = mem_layout_trnp_a
static constexpr

◆ gemm_mem_space_a

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr mem_space gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_mem_space_a = mem_space_a
static constexpr

◆ gemm_mem_space_b

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr mem_space gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_mem_space_b = mem_space_b
static constexpr

◆ gemm_mem_space_trnp_a

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr mem_space gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_mem_space_trnp_a = mem_space_a
static constexpr

◆ global_kslicing

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr uint32_t gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::global_kslicing = 1
static constexpr

◆ k_stride

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr uint32_t gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::k_stride = 32 / sizeof(dtype_bin)
static constexpr

◆ mem_layout_a

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr mem_layout gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_layout_a = mem_layout::row_major
static constexpr

◆ mem_layout_c

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr mem_layout gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_layout_c = mem_layout::row_major
static constexpr

◆ mem_layout_out_b

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr mem_layout gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_layout_out_b = mem_layout::row_major
static constexpr

◆ mem_layout_QKT_b

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr mem_layout gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_layout_QKT_b = mem_layout::col_major
static constexpr

◆ mem_layout_trnp_a

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr mem_layout gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_layout_trnp_a = mem_layout::col_major
static constexpr

◆ mem_space_a

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr mem_space gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_space_a = mem_space::global
static constexpr

◆ mem_space_b

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr mem_space gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_space_b = mem_space::global
static constexpr

◆ mem_space_c

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr mem_space gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_space_c = mem_space::global
static constexpr

◆ periodic_sync_interval

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr uint32_t gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::periodic_sync_interval = 0
static constexpr

◆ prefetch_distance

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr uint32_t gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::prefetch_distance = 3
static constexpr

◆ sfx_type_size

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr uint16_t gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::sfx_type_size = sizeof(dtype_sfx)
static constexpr

◆ ThreadNum

template<typename dtype_bwd_bin_ , typename dtype_bwd_bot_ , typename dtype_bwd_sfx_ , typename dtype_bwd_acc_ , int HWThreadNum, bool Dopt_RandGenflag = true, bool Mkin_flag = false, int Max_SeqLen = 512>
constexpr int gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::ThreadNum = HWThreadNum
static constexpr