23#include "group/gemm/compute_policy.hpp"
28template <
typename dtype_a,
typename dtype_b,
int alignment_a,
int alignment_b,
32 arch_tag>::template load_store_attr<msg_type::block_2d>;
33 static constexpr int alignment_bytes = load_store_attr::alignment_in_bytes;
34 static constexpr int alignment_bytes_a = alignment_a *
sizeof(dtype_a);
35 static constexpr int alignment_bytes_b = alignment_b *
sizeof(dtype_b);
38 static constexpr bool value = (alignment_bytes_a % alignment_bytes == 0)
39 && (alignment_bytes_b % alignment_bytes == 0);
48template <
typename dtype_a,
typename dtype_b,
mem_layout mem_layout_a,
50 int alignment_a,
int alignment_b,
typename dtype_acc,
51 typename tile_shape,
int k_stride,
gpu_arch arch_tag,
int stages,
54 mem_space_b, alignment_a, alignment_b, dtype_acc, tile_shape, k_stride,
56 std::enable_if_t<detail::check_2d_block_pitch_alignment<dtype_a,
57 dtype_b, alignment_a, alignment_b, arch_tag>::value>> {
74template <
typename dtype_a,
typename dtype_b,
mem_layout mem_layout_a,
76 int alignment_a,
int alignment_b,
typename dtype_acc,
77 typename tile_shape,
int k_stride,
gpu_arch arch_tag,
int stages,
80 mem_space_b, alignment_a, alignment_b, dtype_acc, tile_shape, k_stride,
82 std::enable_if_t<!detail::check_2d_block_pitch_alignment<dtype_a,
83 dtype_b, alignment_a, alignment_b, arch_tag>::value>> {
100template <
typename dtype_a,
typename dtype_b,
mem_layout mem_layout_a,
102 int alignment_a,
int alignment_b,
typename dtype_acc,
103 typename tile_shape,
int k_stride,
gpu_arch arch_tag,
int stages,
106 mem_space_b, alignment_a, alignment_b, dtype_acc, tile_shape, k_stride,
108 std::enable_if_t<detail::check_2d_block_pitch_alignment<dtype_a,
109 dtype_b, alignment_a, alignment_b, arch_tag>::value>> {
110 static_assert(std::is_same<dtype_a, dtype_acc>::value
111 && std::is_same<dtype_b, dtype_acc>::value,
112 "When use gemm_selector, dtype_a and dtype_b in fpu based gemm"
113 "should be the same as dtype_acc");
Definition selector_xe.hpp:30
static constexpr bool value
Definition selector_xe.hpp:38
Gemm selection functor.
Definition api.hpp:75
Gemm functor.
Definition api.hpp:52
Definition limitation.hpp:607
mem_space
Definition common.hpp:77
mma_engine
Definition common.hpp:225
gpu_arch
Definition common.hpp:73
mem_layout
Definition common.hpp:76
Definition arch_config.hpp:72
Compute attribute for gemm.
Definition common.hpp:32
Compute policy for fpu engine.
Definition compute_policy.hpp:105
Compute policy for xmx engine.
Definition compute_policy.hpp:35
Compute policy for unaligned shape and xmx engine.
Definition compute_policy.hpp:70
Fine-tune knobs for gemm.
Definition common.hpp:43
Gemm default pre_processing functor.
Definition api.hpp:33
Definition memory_descriptor.hpp:139