#include <mha_attn_reg.hpp>
Classes | |
| struct | arguments_t |
| Arguments for xetla_mha_attn_reg_fwd_t::call. More... | |
Static Public Member Functions | |
| static __XETLA_API void | call (sycl::nd_item< 3 > &item, arguments_t *args) |
| Main execution function for fused mha softmax. The basic process is GEMM -> Softmax -> GEMM. | |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::bgm_perf_tuning_knob = group::perf_tuning_knob_t<k_stride, prefetch_distance, periodic_sync_interval> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::compute_policy_out = group::compute_policy_default_xmx< group::compute_attr_t<dtype_sfx, dtype_bin, dtype_acc>, bgm_perf_tuning_knob, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::compute_policy_QKT = group::compute_policy_default_xmx< group::compute_attr_t<dtype_bin, dtype_bin, dtype_acc>, bgm_perf_tuning_knob, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::dtype_acc = dtype_acc_ |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::dtype_bin = dtype_bin_ |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::dtype_bot = dtype_bot_ |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::dtype_sfx = dtype_sfx_ |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_arguments_128x128 = typename gemm_op_128x128_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_arguments_128x256 = typename gemm_op_128x256_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_arguments_128x64 = typename gemm_op_128x64_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_arguments_16x2048 = typename gemm_op_16x2048_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_arguments_32x1024 = typename gemm_op_32x1024_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_arguments_64x384 = typename gemm_op_64x384_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_arguments_64x512 = typename gemm_op_64x512_t::arguments_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_op_128x128_t = group::gemm_t<compute_policy_QKT, tile_attr_128x128, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_128x128> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_op_128x256_t = group::gemm_t<compute_policy_QKT, tile_attr_128x256, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_128x256> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_op_128x64_t = group::gemm_t<compute_policy_out, tile_attr_128x64, mem_desc_a_out, mem_desc_b_out, pre_processing_128x64> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_op_16x2048_t = group::gemm_t<compute_policy_QKT, tile_attr_16x2048, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_16x2048> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_op_32x1024_t = group::gemm_t<compute_policy_QKT, tile_attr_32x1024, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_32x1024> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_op_64x384_t = group::gemm_t<compute_policy_QKT, tile_attr_64x384, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_64x384> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::gemm_op_64x512_t = group::gemm_t<compute_policy_QKT, tile_attr_64x512, mem_desc_a_QKT, mem_desc_b_QKT, pre_processing_64x512> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mat_128x128_tile_desc_t = subgroup::tile_desc_t<matAcc_128x128_t::tile_desc::tile_size_x, matAcc_128x128_t::tile_desc::tile_size_y, matAcc_128x128_t::tile_desc::block_size_x, matAcc_128x128_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mat_128x256_tile_desc_t = subgroup::tile_desc_t<matAcc_128x256_t::tile_desc::tile_size_x, matAcc_128x256_t::tile_desc::tile_size_y, matAcc_128x256_t::tile_desc::block_size_x, matAcc_128x256_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mat_128x64_tile_desc_t = subgroup::tile_desc_t<matAcc_128x64_t::tile_desc::tile_size_x, matAcc_128x64_t::tile_desc::tile_size_y, matAcc_128x64_t::tile_desc::block_size_x, matAcc_128x64_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mat_16x2048_tile_desc_t = subgroup::tile_desc_t<matAcc_16x2048_t::tile_desc::tile_size_x, matAcc_16x2048_t::tile_desc::tile_size_y, matAcc_16x2048_t::tile_desc::block_size_x, matAcc_16x2048_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mat_32x1024_tile_desc_t = subgroup::tile_desc_t<matAcc_32x1024_t::tile_desc::tile_size_x, matAcc_32x1024_t::tile_desc::tile_size_y, matAcc_32x1024_t::tile_desc::block_size_x, matAcc_32x1024_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mat_64x384_tile_desc_t = subgroup::tile_desc_t<matAcc_64x384_t::tile_desc::tile_size_x, matAcc_64x384_t::tile_desc::tile_size_y, matAcc_64x384_t::tile_desc::block_size_x, matAcc_64x384_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mat_64x512_tile_desc_t = subgroup::tile_desc_t<matAcc_64x512_t::tile_desc::tile_size_x, matAcc_64x512_t::tile_desc::tile_size_y, matAcc_64x512_t::tile_desc::block_size_x, matAcc_64x512_t::tile_desc::block_size_y, reg_layout::tiled> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matAcc_128x128_t = typename gemm_op_128x128_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matAcc_128x256_t = typename gemm_op_128x256_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matAcc_128x64_t = typename gemm_op_128x64_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matAcc_16x2048_t = typename gemm_op_16x2048_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matAcc_32x1024_t = typename gemm_op_32x1024_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matAcc_64x384_t = typename gemm_op_64x384_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matAcc_64x512_t = typename gemm_op_64x512_t::matAcc_t |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_128x128_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, mat_128x128_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v< mat_128x128_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_128x128_t = subgroup::tile_t<dtype_sfx, mat_128x128_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, mat_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v< mat_128x256_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_128x256_t = subgroup::tile_t<dtype_sfx, mat_128x256_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_128x64_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, mat_128x64_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<mat_128x64_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_128x64_t = subgroup::tile_t<dtype_sfx, mat_128x64_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, mat_16x2048_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v< mat_16x2048_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_16x2048_t = subgroup::tile_t<dtype_sfx, mat_16x2048_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, mat_32x1024_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v< mat_32x1024_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_32x1024_t = subgroup::tile_t<dtype_sfx, mat_32x1024_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_64x384_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, mat_64x384_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<mat_64x384_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_64x384_t = subgroup::tile_t<dtype_sfx, mat_64x384_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_64x512_payload_t = subgroup::mem_payload_t< mem_desc_t<dtype_sfx, mem_layout_c, mem_space_c>, mat_64x512_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add : subgroup::msg_type_v<mat_64x512_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matC_64x512_t = subgroup::tile_t<dtype_sfx, mat_64x512_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_128x128_payload_t = subgroup::mem_payload_t< mem_desc_t<uint8_t, mem_layout_c, mem_space_c>, mat_128x128_tile_desc_t, subgroup::msg_type_v<mat_128x128_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_128x128_t = subgroup::tile_t<uint8_t, mat_128x128_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_128x256_payload_t = subgroup::mem_payload_t< mem_desc_t<uint8_t, mem_layout_c, mem_space_c>, mat_128x256_tile_desc_t, subgroup::msg_type_v<mat_128x256_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_128x256_t = subgroup::tile_t<uint8_t, mat_128x256_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_128x64_payload_t = subgroup::mem_payload_t< mem_desc_t<uint8_t, mem_layout_c, mem_space_c>, mat_128x64_tile_desc_t, subgroup::msg_type_v<mat_128x64_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_128x64_t = subgroup::tile_t<uint8_t, mat_128x64_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_16x2048_payload_t = subgroup::mem_payload_t< mem_desc_t<uint8_t, mem_layout_c, mem_space_c>, mat_16x2048_tile_desc_t, subgroup::msg_type_v<mat_16x2048_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_16x2048_t = subgroup::tile_t<uint8_t, mat_16x2048_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_32x1024_payload_t = subgroup::mem_payload_t< mem_desc_t<uint8_t, mem_layout_c, mem_space_c>, mat_32x1024_tile_desc_t, subgroup::msg_type_v<mat_32x1024_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_32x1024_t = subgroup::tile_t<uint8_t, mat_32x1024_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_64x384_payload_t = subgroup::mem_payload_t< mem_desc_t<uint8_t, mem_layout_c, mem_space_c>, mat_64x384_tile_desc_t, subgroup::msg_type_v<mat_64x384_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_64x384_t = subgroup::tile_t<uint8_t, mat_64x384_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_64x512_payload_t = subgroup::mem_payload_t< mem_desc_t<uint8_t, mem_layout_c, mem_space_c>, mat_64x512_tile_desc_t, subgroup::msg_type_v<mat_64x512_tile_desc_t, mem_space_c>, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::matDpotMk_64x512_t = subgroup::tile_t<uint8_t, mat_64x512_tile_desc_t> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mem_desc_a_out = mem_desc_t<dtype_sfx, gemm_mem_layout_a, gemm_mem_space_a> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mem_desc_a_QKT = mem_desc_t<dtype_bin, gemm_mem_layout_a, gemm_mem_space_a> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mem_desc_b_out = mem_desc_t<dtype_bin, gemm_mem_layout_out_b, gemm_mem_space_b> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::mem_desc_b_QKT = mem_desc_t<dtype_bin, gemm_mem_layout_QKT_b, gemm_mem_space_b> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::pre_processing_128x128 = group::pre_processing_default_t<tile_attr_128x128, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::pre_processing_128x256 = group::pre_processing_default_t<tile_attr_128x256, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::pre_processing_128x64 = group::pre_processing_matA_neg_filter_t<tile_attr_128x64, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::pre_processing_16x2048 = group::pre_processing_default_t<tile_attr_16x2048, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::pre_processing_32x1024 = group::pre_processing_default_t<tile_attr_32x1024, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::pre_processing_64x384 = group::pre_processing_default_t<tile_attr_64x384, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::pre_processing_64x512 = group::pre_processing_default_t<tile_attr_64x512, gpu_arch::Xe> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::tile_attr_128x128 = group::tile_shape_t<128, 128, 32, 16> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::tile_attr_128x256 = group::tile_shape_t<256, 128, 64, 16> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::tile_attr_128x64 = group::tile_shape_t<64, 128, 16, 16> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::tile_attr_16x2048 = group::tile_shape_t<2048, 16, 64, 16> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::tile_attr_32x1024 = group::tile_shape_t<1024, 32, 64, 16> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::tile_attr_64x384 = group::tile_shape_t<384, 64, 48, 16> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::tile_attr_64x512 = group::tile_shape_t<512, 64, 64, 16> |
| using gpu::xetla::kernel::xetla_mha_attn_reg_fwd_t< dtype_bin_, dtype_bot_, dtype_sfx_, dtype_acc_, HWThreadNum, Dopt_RandGenflag, RandSIMD, Max_SeqLen >::work_group_t = work_group_t<ThreadNum> |
|
inline static |
Main execution function for fused mha softmax. The basic process is GEMM -> Softmax -> GEMM.
| args | [in] Includes base descriptors and tid info. |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |
|
static constexpr |