#include <mha_attn_reg.hpp>
Classes | |
| struct | arguments_t |
| Arguments for xetla_mha_attn_reg_bwd_t::call. More... | |
Static Public Member Functions | |
| static __XETLA_API void | call (sycl::nd_item< 3 > &item, arguments_t *args) |
| Main execution function for fused mha softmax. The basic process is GEMM -> Softmax -> GEMM. | |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::bgm_perf_tuning_knob = group::perf_tuning_knob_t<k_stride, prefetch_distance, periodic_sync_interval> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::compute_policy_out = group::compute_policy_default_xmx< group::compute_attr_t<dtype_sfx, dtype_bin, dtype_acc>, bgm_perf_tuning_knob, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::compute_policy_out_b_trnp_a = group::compute_policy_default_xmx< group::compute_attr_t<dtype_sfx, dtype_bin, dtype_acc>, bgm_perf_tuning_knob, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::compute_policy_QKT = group::compute_policy_default_xmx< group::compute_attr_t<dtype_bin, dtype_bin, dtype_acc>, bgm_perf_tuning_knob, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::dtype_acc = dtype_bwd_acc_ |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::dtype_bin = dtype_bwd_bin_ |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::dtype_bot = dtype_bwd_bot_ |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::dtype_sfx = dtype_bwd_sfx_ |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_128x128 = typename gemm_op_128x128_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_128x256 = typename gemm_op_128x256_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_128x64 = typename gemm_op_128x64_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_128x64_trnp_a = typename gemm_op_128x64_trnp_a_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_128x64_trnp_af = typename gemm_op_128x64_trnp_af_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_16x2048 = typename gemm_op_16x2048_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_256x64_trnp_a = typename gemm_op_256x64_trnp_a_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_256x64_trnp_af = typename gemm_op_256x64_trnp_af_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_32x1024 = typename gemm_op_32x1024_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_64x384 = typename gemm_op_64x384_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_arguments_64x512 = typename gemm_op_64x512_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_128x128_t = group::gemm_t<compute_policy_QKT, tile_attr_128x128, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_128x128> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_128x256_t = group::gemm_t<compute_policy_QKT, tile_attr_128x256, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_128x256> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_128x64_t = group::gemm_t<compute_policy_out, tile_attr_128x64, mem_desc_a_out, mem_desc_b_out, pre_processing_128x64> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_128x64_trnp_a_t = group::gemm_t<compute_policy_out_b_trnp_a, tile_attr_128x64, mem_desc_a_out_b_trnp_a, mem_desc_b_out_b_trnp_a, pre_processing_128x64> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_128x64_trnp_af_t = group::gemm_t<compute_policy_out_b_trnp_a, tile_attr_128x64, mem_desc_a_out_b_trnp_a, mem_desc_b_out_b_trnp_a, pre_processing_128x64_af> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_16x2048_t = group::gemm_t<compute_policy_QKT, tile_attr_16x2048, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_16x2048> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_256x64_trnp_a_t = group::gemm_t<compute_policy_out_b_trnp_a, tile_attr_256x64, mem_desc_a_out_b_trnp_a, mem_desc_b_out_b_trnp_a, pre_processing_256x64> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_256x64_trnp_af_t = group::gemm_t<compute_policy_out_b_trnp_a, tile_attr_256x64, mem_desc_a_out_b_trnp_a, mem_desc_b_out_b_trnp_a, pre_processing_256x64_af> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_32x1024_t = group::gemm_t<compute_policy_QKT, tile_attr_32x1024, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_32x1024> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_64x384_t = group::gemm_t<compute_policy_QKT, tile_attr_64x384, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_64x384> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::gemm_op_64x512_t = group::gemm_t<compute_policy_QKT, tile_attr_64x512, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_64x512> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_128x128_t = typename gemm_op_128x128_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_128x256_t = typename gemm_op_128x256_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_128x64_t = typename gemm_op_128x64_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_128x64_trnp_a_t = typename gemm_op_128x64_trnp_a_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_128x64_trnp_af_t = typename gemm_op_128x64_trnp_af_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_16x2048_t = typename gemm_op_16x2048_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_256x64_trnp_a_t = typename gemm_op_256x64_trnp_a_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_256x64_trnp_af_t = typename gemm_op_256x64_trnp_af_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_32x1024_t = typename gemm_op_32x1024_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_64x384_t = typename gemm_op_64x384_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matAcc_64x512_t = typename gemm_op_64x512_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x128_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_128x128_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<matC_128x128_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x128_t = subgroup::tile_t<dtype_sfx, matC_128x128_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x128_tile_desc_t = subgroup::tile_desc_t<matAcc_128x128_t::tile_desc::tile_size_x, matAcc_128x128_t::tile_desc::tile_size_y, matAcc_128x128_t::tile_desc::block_size_x, matAcc_128x128_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<matC_128x256_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x256_t = subgroup::tile_t<dtype_sfx, matC_128x256_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x256_tile_desc_t = subgroup::tile_desc_t<matAcc_128x256_t::tile_desc::tile_size_x, matAcc_128x256_t::tile_desc::tile_size_y, matAcc_128x256_t::tile_desc::block_size_x, matAcc_128x256_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_bot, mem_layout_c, mem_space_c>, matC_128x64_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v< matC_128x64_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_t = subgroup::tile_t<dtype_bot, matC_128x64_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_tile_desc_t = subgroup::tile_desc_t<matAcc_128x64_t::tile_desc::tile_size_x, matAcc_128x64_t::tile_desc::tile_size_y, matAcc_128x64_t::tile_desc::block_size_x, matAcc_128x64_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_trnp_a_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_bot, mem_layout_c, mem_space_c>, matC_128x64_trnp_a_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<matC_128x64_trnp_a_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_trnp_a_t = subgroup::tile_t<dtype_bot, matC_128x64_trnp_a_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_trnp_a_tile_desc_t = subgroup::tile_desc_t< matAcc_128x64_trnp_a_t::tile_desc::tile_size_x, matAcc_128x64_trnp_a_t::tile_desc::tile_size_y, matAcc_128x64_trnp_a_t::tile_desc::block_size_x, matAcc_128x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_trnp_af_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_bot, mem_layout_c, mem_space_c>, matC_128x64_trnp_af_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<matC_128x64_trnp_af_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_trnp_af_t = subgroup::tile_t<dtype_bot, matC_128x64_trnp_af_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_128x64_trnp_af_tile_desc_t = subgroup::tile_desc_t< matAcc_128x64_trnp_af_t::tile_desc::tile_size_x, matAcc_128x64_trnp_af_t::tile_desc::tile_size_y, matAcc_128x64_trnp_af_t::tile_desc::block_size_x, matAcc_128x64_trnp_af_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_16x2048_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<matC_16x2048_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_16x2048_t = subgroup::tile_t<dtype_sfx, matC_16x2048_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_16x2048_tile_desc_t = subgroup::tile_desc_t<matAcc_16x2048_t::tile_desc::tile_size_x, matAcc_16x2048_t::tile_desc::tile_size_y, matAcc_16x2048_t::tile_desc::block_size_x, matAcc_16x2048_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_256x64_trnp_a_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_bot, mem_layout_c, mem_space_c>, matC_256x64_trnp_a_tile_desc_t, subgroup::msg_type_v<matC_256x64_trnp_a_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_256x64_trnp_a_t = subgroup::tile_t<dtype_bot, matC_256x64_trnp_a_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_256x64_trnp_a_tile_desc_t = subgroup::tile_desc_t< matAcc_256x64_trnp_a_t::tile_desc::tile_size_x, matAcc_256x64_trnp_a_t::tile_desc::tile_size_y, matAcc_256x64_trnp_a_t::tile_desc::block_size_x, matAcc_256x64_trnp_a_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_256x64_trnp_af_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_bot, mem_layout_c, mem_space_c>, matC_256x64_trnp_af_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<matC_256x64_trnp_af_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_256x64_trnp_af_t = subgroup::tile_t<dtype_bot, matC_256x64_trnp_af_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_256x64_trnp_af_tile_desc_t = subgroup::tile_desc_t< matAcc_256x64_trnp_af_t::tile_desc::tile_size_x, matAcc_256x64_trnp_af_t::tile_desc::tile_size_y, matAcc_256x64_trnp_af_t::tile_desc::block_size_x, matAcc_256x64_trnp_af_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_32x1024_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<matC_32x1024_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_32x1024_t = subgroup::tile_t<dtype_sfx, matC_32x1024_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_32x1024_tile_desc_t = subgroup::tile_desc_t<matAcc_32x1024_t::tile_desc::tile_size_x, matAcc_32x1024_t::tile_desc::tile_size_y, matAcc_32x1024_t::tile_desc::block_size_x, matAcc_32x1024_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_64x384_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_64x384_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v< matC_64x384_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_64x384_t = subgroup::tile_t<dtype_sfx, matC_64x384_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_64x384_tile_desc_t = subgroup::tile_desc_t<matAcc_64x384_t::tile_desc::tile_size_x, matAcc_64x384_t::tile_desc::tile_size_y, matAcc_64x384_t::tile_desc::block_size_x, matAcc_64x384_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_64x512_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_64x512_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v< matC_64x512_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_64x512_t = subgroup::tile_t<dtype_sfx, matC_64x512_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matC_64x512_tile_desc_t = subgroup::tile_desc_t<matAcc_64x512_t::tile_desc::tile_size_x, matAcc_64x512_t::tile_desc::tile_size_y, matAcc_64x512_t::tile_desc::block_size_x, matAcc_64x512_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_128x128_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_128x128_tile_desc_t, subgroup::msg_type_v<matC_128x128_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_128x128_t = subgroup::tile_t<dtype_sfx, matC_128x128_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_128x256_tile_desc_t, subgroup::msg_type_v<matC_128x256_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_128x256_t = subgroup::tile_t<dtype_sfx, matC_128x256_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_16x2048_tile_desc_t, subgroup::msg_type_v<matC_16x2048_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_16x2048_t = subgroup::tile_t<dtype_sfx, matC_16x2048_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_32x1024_tile_desc_t, subgroup::msg_type_v<matC_32x1024_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_32x1024_t = subgroup::tile_t<dtype_sfx, matC_32x1024_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_64x384_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_64x384_tile_desc_t, subgroup::msg_type_v<matC_64x384_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_64x384_t = subgroup::tile_t<dtype_sfx, matC_64x384_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_64x512_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, matC_64x512_tile_desc_t, subgroup::msg_type_v<matC_64x512_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::matW_64x512_t = subgroup::tile_t<dtype_sfx, matC_64x512_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_desc_a_out = mem_desc_t<dtype_sfx, gemm_mem_layout_a, gemm_mem_space_a> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_desc_a_out_b_trnp_a = mem_desc_t<dtype_sfx, gemm_mem_layout_trnp_a, gemm_mem_space_trnp_a> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_desc_a_QKT = mem_desc_t<dtype_bin, gemm_mem_layout_a, gemm_mem_space_a> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_desc_b_out = mem_desc_t<dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_desc_b_out_b_trnp_a = mem_desc_t<dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::mem_desc_b_QKT = mem_desc_t<dtype_bin, gemm_mem_layout_QKT_b, gemm_mem_space_b> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_128x128 = group::pre_processing_default_t<tile_attr_128x128, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_128x256 = group::pre_processing_default_t<tile_attr_128x256, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_128x64 = group::pre_processing_default_t<tile_attr_128x64, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_128x64_af = group::pre_processing_matA_neg_filter_t<tile_attr_128x64, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_16x2048 = group::pre_processing_default_t<tile_attr_16x2048, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_256x64 = group::pre_processing_default_t<tile_attr_256x64, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_256x64_af = group::pre_processing_matA_neg_filter_t<tile_attr_256x64, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_32x1024 = group::pre_processing_default_t<tile_attr_32x1024, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_64x384 = group::pre_processing_default_t<tile_attr_64x384, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::pre_processing_64x512 = group::pre_processing_default_t<tile_attr_64x512, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::tile_attr_128x128 = group::tile_shape_t<128, 128, 32, 16> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::tile_attr_128x256 = group::tile_shape_t<256, 128, 64, 16> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::tile_attr_128x64 = group::tile_shape_t<64, 128, 16, 16> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::tile_attr_16x2048 = group::tile_shape_t<2048, 16, 64, 16> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::tile_attr_256x64 = group::tile_shape_t<64, 256, 16, 32> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::tile_attr_32x1024 = group::tile_shape_t<1024, 32, 64, 16> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::tile_attr_64x384 = group::tile_shape_t<384, 64, 48, 16> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::tile_attr_64x512 = group::tile_shape_t<512, 64, 64, 16> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_bwd_t< dtype_bwd_bin_, dtype_bwd_bot_, dtype_bwd_sfx_, dtype_bwd_acc_, HWThreadNum, Dopt_RandGenflag, Mkin_flag, Max_SeqLen >::work_group_t = work_group_t<ThreadNum> |
|
inline static |
Main execution function for fused mha softmax. The basic process is GEMM -> Softmax -> GEMM.
| args | [in] Includes base descriptors and tid info. |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |