26template <
typename compute_attr_,
typename perf_tuning_knob_,
27 typename dtype_scale_,
typename dtype_zero_pt_,
int dequant_s_,
31template <
typename compute_attr_,
typename perf_tuning_knob_,
32 typename dtype_scale_,
typename dtype_zero_pt_,
int dequant_s_>
34 dtype_scale_, dtype_zero_pt_, dequant_s_,
gpu_arch::
Xe> {
37 static constexpr int k_stride = perf_tuning_knob::k_stride;
38 static constexpr int stages = perf_tuning_knob::stages;
39 static constexpr int sync_freq = perf_tuning_knob::sync_freq;
45 static constexpr uint32_t block_bytes_x_a = 32;
46 static constexpr uint32_t block_size_y_a = 16;
48 static constexpr bool is_int4_matB_policy =
true;
50 static constexpr uint32_t block_size_x_b = 16;
51 static constexpr uint32_t block_bytes_y_b = 32;
52 static_assert(block_bytes_x_a == block_bytes_y_b,
53 "mat_a x need to match with mat_b y");
55 static constexpr uint32_t dequant_s = dequant_s_;
56 static_assert((dequant_s % (32 /
sizeof(
dtype_mma_b))) == 0,
57 "dequant_s should be a multiply of 32B");
Definition limitation.hpp:607
gpu_arch
Definition common.hpp:73
typename compute_attr::dtype_acc dtype_mma_acc
Definition compute_policy.hpp:41
typename compute_attr::dtype_b dtype_mma_b
Definition compute_policy.hpp:43
compute_attr_ compute_attr
Definition compute_policy.hpp:35
dtype_zero_pt_ dtype_zero_pt
Definition compute_policy.hpp:59
dtype_scale_ dtype_scale
Definition compute_policy.hpp:58
perf_tuning_knob_ perf_tuning_knob
Definition compute_policy.hpp:36
typename compute_attr::dtype_a dtype_mma_a
Definition compute_policy.hpp:42
Definition compute_policy.hpp:29