XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
compute_policy.hpp
Go to the documentation of this file.
1/*******************************************************************************
2* Copyright (c) 2022-2023 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
19
20#pragma once
21
23
24namespace gpu::xetla::group {
25
26template <typename compute_attr_, typename perf_tuning_knob_,
27 typename dtype_scale_, typename dtype_zero_pt_, int dequant_s_,
28 gpu_arch arch_tag_ = gpu_arch::Xe>
30
31template <typename compute_attr_, typename perf_tuning_knob_,
32 typename dtype_scale_, typename dtype_zero_pt_, int dequant_s_>
33struct compute_policy_int4_dequantize_xmx<compute_attr_, perf_tuning_knob_,
34 dtype_scale_, dtype_zero_pt_, dequant_s_, gpu_arch::Xe> {
35 using compute_attr = compute_attr_;
36 using perf_tuning_knob = perf_tuning_knob_;
37 static constexpr int k_stride = perf_tuning_knob::k_stride;
38 static constexpr int stages = perf_tuning_knob::stages;
39 static constexpr int sync_freq = perf_tuning_knob::sync_freq;
40 static constexpr gpu_arch arch_tag = gpu_arch::Xe;
41 using dtype_mma_acc = typename compute_attr::dtype_acc;
42 using dtype_mma_a = typename compute_attr::dtype_a;
43 using dtype_mma_b = typename compute_attr::dtype_b;
44
45 static constexpr uint32_t block_bytes_x_a = 32;
46 static constexpr uint32_t block_size_y_a = 16;
47
48 static constexpr bool is_int4_matB_policy = true;
49
50 static constexpr uint32_t block_size_x_b = 16;
51 static constexpr uint32_t block_bytes_y_b = 32;
52 static_assert(block_bytes_x_a == block_bytes_y_b,
53 "mat_a x need to match with mat_b y");
54
55 static constexpr uint32_t dequant_s = dequant_s_;
56 static_assert((dequant_s % (32 / sizeof(dtype_mma_b))) == 0,
57 "dequant_s should be a multiply of 32B");
58 using dtype_scale = dtype_scale_;
59 using dtype_zero_pt = dtype_zero_pt_;
60};
61
62} // namespace gpu::xetla::group
C++ API.
Definition limitation.hpp:607
gpu_arch
Definition common.hpp:73