#include "kernel/gemm/common.hpp"
// Default GROUP_SWIZZLE implementation: group ids are used unchanged.
template <gpu_arch arch_tag_>
struct group_swizzle_default {
  static constexpr gpu_arch arch_tag = arch_tag_;

  template <int idx>
  static __XETLA_API int get_tile_idx(sycl::nd_item<3> &item) {
    return item.get_group(idx);
  }

  // The default swizzle dispatches the given group range as-is; no padding.
  static __XETLA_API void update_group_range(
      [[maybe_unused]] uint32_t &group_range_m,
      [[maybe_unused]] uint32_t &group_range_n) {}
};
// GROUP_SWIZZLE implementation of a snake curve: workgroups are laid out in
// blocks of max_wg_num groups, and consecutive blocks traverse the tile grid
// in a serpentine order to improve cache reuse between neighboring groups.
template <int wg_num_n_, gpu_arch arch_tag_>
struct group_swizzle_snake {
  static constexpr gpu_arch arch_tag = arch_tag_;

  // dim 0 (slice) id is passed through unchanged
  template <int idx>
  static __XETLA_API std::enable_if_t<idx == 0, int> get_tile_idx(
      sycl::nd_item<3> &item) {
    return item.get_group(idx);
  }

  // remapped m coordinate of the workgroup
  template <int idx>
  static __XETLA_API std::enable_if_t<idx == 1, int> get_tile_idx(
      sycl::nd_item<3> &item) {
    uint32_t group_range_n = item.get_group_range(2);
    uint32_t wg_repeat_n = group_range_n / wg_num_n;
    uint32_t repeat_id = get_2d_group_linear_id(item) / max_wg_num;
    uint32_t repeat_id_m = repeat_id / wg_repeat_n;
    uint32_t repeat_start_m = repeat_id_m * wg_num_m;
    uint32_t wg_inner_id = get_2d_group_linear_id(item) % max_wg_num;
    uint32_t wg_coord_m = wg_inner_id / wg_num_n;
    int start_m_id = repeat_start_m + wg_coord_m;
    return start_m_id;
  }

  // remapped n coordinate of the workgroup; every other row of repeats is
  // traversed in reverse, which produces the snake order
  template <int idx>
  static __XETLA_API std::enable_if_t<idx == 2, int> get_tile_idx(
      sycl::nd_item<3> &item) {
    uint32_t group_range_n = item.get_group_range(2);
    uint32_t wg_repeat_n = group_range_n / wg_num_n;
    uint32_t repeat_id = get_2d_group_linear_id(item) / max_wg_num;
    uint32_t repeat_id_n = repeat_id % wg_repeat_n;
    uint32_t repeat_id_m = repeat_id / wg_repeat_n;
    uint32_t repeat_start_n_0 = repeat_id_n * wg_num_n;
    uint32_t repeat_start_n_1 = (wg_repeat_n - repeat_id_n - 1) * wg_num_n;
    uint32_t repeat_start_n
        = (repeat_id_m & 1) == 0 ? repeat_start_n_0 : repeat_start_n_1;
    uint32_t wg_inner_id = get_2d_group_linear_id(item) % max_wg_num;
    uint32_t wg_coord_n = wg_inner_id % wg_num_n;
    int start_n_id = repeat_start_n + wg_coord_n;
    return start_n_id;
  }

  // pad the group range so it is divisible by the swizzle block shape
  static __XETLA_API void update_group_range(
      uint32_t &group_range_m, uint32_t &group_range_n) {
    group_range_m = (group_range_m + wg_num_m - 1) / wg_num_m * wg_num_m;
    group_range_n = (group_range_n + wg_num_n - 1) / wg_num_n * wg_num_n;
  }

 private:
  static constexpr uint32_t max_wg_num = arch_attr_t<arch_tag>::max_wg_num;
  static constexpr uint32_t wg_num_n = wg_num_n_;
  static_assert(!(max_wg_num % wg_num_n),
      "max_wg_num must be divisible by the given wg_num_n!");
  static constexpr uint32_t wg_num_m = max_wg_num / wg_num_n;
};
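// Illustration (editor's sketch, not part of the original header): with
// wg_num_n = 4 and max_wg_num = 16 (so wg_num_m = 4), the snake swizzle walks
// the padded tile grid in 4x4 blocks of workgroups; blocks in even rows of
// repeats advance left-to-right in n, blocks in odd rows right-to-left. A
// host-side re-computation of the same mapping, purely for illustration:
//
//   constexpr uint32_t wg_num_n = 4, wg_num_m = 4, max_wg_num = 16;  // assumed
//   uint32_t linear_id = /* linear group id over the last two dims */ 0;
//   uint32_t group_range_n = 8;                      // already padded to wg_num_n
//   uint32_t wg_repeat_n = group_range_n / wg_num_n;
//   uint32_t repeat_id   = linear_id / max_wg_num;
//   uint32_t inner_id    = linear_id % max_wg_num;
//   uint32_t repeat_m    = repeat_id / wg_repeat_n;
//   uint32_t repeat_n    = repeat_id % wg_repeat_n;
//   uint32_t start_n     = ((repeat_m & 1) == 0)
//       ? repeat_n * wg_num_n
//       : (wg_repeat_n - repeat_n - 1) * wg_num_n;   // reversed on odd rows
//   uint32_t m = repeat_m * wg_num_m + inner_id / wg_num_n;
//   uint32_t n = start_n + inner_id % wg_num_n;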
// Default GEMM_UNIVERSAL dispatch policy.
template <typename group_swizzle_policy_>
struct dispatch_policy_default {
  using group_swizzle_policy = group_swizzle_policy_;
  static constexpr gpu_arch arch_tag = group_swizzle_policy::arch_tag;
};

// K-slicing GEMM_UNIVERSAL dispatch policy: the K dimension is additionally
// split global_ratio-way across workgroups and local_ratio-way inside a
// workgroup.
template <typename group_swizzle_policy_, int global_ratio_ = 1,
    int local_ratio_ = 1>
struct dispatch_policy_kslicing {
  using group_swizzle_policy = group_swizzle_policy_;
  static constexpr int global_ratio = global_ratio_;
  static constexpr int local_ratio = local_ratio_;
  static constexpr gpu_arch arch_tag = group_swizzle_policy::arch_tag;
};
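// Usage sketch (editor's illustration): a dispatch policy is composed from a
// group-swizzle policy and then handed to the GEMM kernel as a template
// parameter; the ratios shown are arbitrary example values.
//
//   using swizzle         = group_swizzle_snake<4, gpu_arch::Xe>;
//   using policy_default  = dispatch_policy_default<swizzle>;
//   using policy_kslicing = dispatch_policy_kslicing<swizzle,
//       /*global_ratio_=*/2, /*local_ratio_=*/1>;
//   static_assert(policy_kslicing::arch_tag == gpu_arch::Xe);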
// Stream-K GEMM dispatch policy: the K iterations of the leftover output
// tiles are split across a fixed set of "SK" workgroups so that the available
// Xe cores stay evenly loaded.
template <gpu_arch arch_tag_ = gpu_arch::Xe>
struct dispatch_policy_stream_k {
  static constexpr gpu_arch arch_tag = arch_tag_;

  // Problem and tile shape.
  uint32_t matrix_m;
  uint32_t matrix_k;
  uint32_t matrix_n;
  uint32_t wg_tile_m;
  uint32_t wg_tile_k;
  uint32_t wg_tile_n;
  uint32_t sg_tile_m;
  uint32_t sg_tile_n;
  // Number of Xe cores available for stream-k load balancing.
  uint32_t avail_xecores;

  // Derived schedule.
  uint32_t num_workgroups;
  uint32_t dp_groups; // number of data-parallel workgroups
  uint32_t sk_tiles;  // output tiles covered by stream-k groups
  uint32_t sk_waves;
  uint32_t sk_big_groups_per_region;
  uint32_t sk_iters_per_region;
  uint32_t sk_regions;
  uint32_t sk_groups_per_region;

  // Fast division/modulus helpers: the host pre-computes magic numbers so the
  // kernel avoids expensive integer division.
  FastDivMod div_mod_tiles_m;
  FastDivMod div_mod_tiles_n;
  FastDivMod div_mod_iters_per_tile;
  FastDivMod div_mod_sk_regions;
  FastDivMod div_mod_sk_groups_per_region;
  FastDivMod div_mod_sk_iters_per_normal_group;
  FastDivMod div_mod_sk_iters_per_region;
  FastDivMod div_mod_sk_iters_per_big_group;

  // Minimum number of MAC iterations per stream-k group.
  static int const kMinItersPerSkGroup;

  dispatch_policy_stream_k() = default;

  // Host helper: group range expected under the current stream-k GEMM config;
  // all groups are laid out along the last nd_range dimension.
  cl::sycl::range<3> get_group_range() const {
    cl::sycl::range<3> group_range
        = cl::sycl::range<3> {1, 1, num_workgroups};
    return group_range;
  }
  // Host helper: choose how many SK groups to dispatch for the given number
  // of sk_tiles, and report the estimated iteration savings relative to a
  // pure data-parallel dispatch.
  void get_sk_workgroups(int &sk_groups, int &savings_iters, int sk_tiles,
      int iters_per_tile, int avail_xecores, bool allow_partial_wave) const {
    savings_iters = INT_MIN;
    sk_groups = 0;
    if (sk_tiles == 0) {
      return;
    }

    int sk_iters = sk_tiles * iters_per_tile;
    // Iterations an equivalent pure data-parallel dispatch would need.
    int dp_equiv_waves = (sk_tiles + avail_xecores - 1) / avail_xecores;
    int dp_equiv_iters = iters_per_tile * dp_equiv_waves;

    // Candidate range of SK group counts to evaluate.
    int min_sk_groups = (allow_partial_wave)
        ? std::min(avail_xecores, sk_tiles + 1)
        : avail_xecores;
    int max_sk_groups = std::min(avail_xecores,
        (sk_iters + kMinItersPerSkGroup - 1) / kMinItersPerSkGroup);

    for (int trial_sk_groups = min_sk_groups;
        trial_sk_groups <= max_sk_groups; trial_sk_groups++) {
      // Waves needed to run this many SK groups.
      int sk_waves = (trial_sk_groups + avail_xecores - 1) / avail_xecores;
      int max_sk_iters_per_group
          = (sk_iters + trial_sk_groups - 1) / trial_sk_groups;
      int sk_iter_equiv = max_sk_iters_per_group * sk_waves;

      // Groups cooperating on the same output tile ("peers") pay a fixup
      // cost; one extra peer accounts for group/tile misalignment.
      int num_peers = ((trial_sk_groups + sk_tiles - 1) / sk_tiles) + 1;
      float iter_cost = 0.02f * float(num_peers) * float(sk_iter_equiv);

      if (trial_sk_groups % sk_tiles == 0) {
        // Aligned case: groups split tiles evenly, no skew cost.
        num_peers = (trial_sk_groups / sk_tiles);
        iter_cost = 0.0f;
      }

      float peer_cost = 2.0f * float(num_peers);
      float base_cost = 2.0f * float(sk_waves);
      int fixup_iter_equiv = int(base_cost + iter_cost + peer_cost);

      int trial_savings_iter
          = dp_equiv_iters - sk_iter_equiv - fixup_iter_equiv;

      if (trial_savings_iter >= savings_iters) {
        savings_iters = trial_savings_iter;
        sk_groups = trial_sk_groups;
      }
    }
  }
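  // Worked example of the heuristic above (editor's sketch, using the formula
  // as shown here): sk_tiles = 6, iters_per_tile = 32, avail_xecores = 64.
  //   sk_iters       = 6 * 32 = 192
  //   dp_equiv_waves = ceil(6 / 64) = 1, so dp_equiv_iters = 32
  //   trying trial_sk_groups = 12 (one wave):
  //     max_sk_iters_per_group = ceil(192 / 12) = 16, sk_iter_equiv = 16
  //     12 % 6 == 0, so num_peers = 2 and iter_cost = 0
  //     peer_cost = 4, base_cost = 2, fixup_iter_equiv = 6
  //     trial_savings_iter = 32 - 16 - 6 = 10
  // Splitting the leftover tiles across 12 SK groups saves about 10
  // iteration-equivalents versus running a second data-parallel wave.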
  // Determine how many output tiles stay data-parallel (dp_tiles) and how
  // many SK groups to launch for the remainder.
  void get_groups(int &dp_tiles, int &sk_groups, int output_tiles,
      int iters_per_tile, int avail_xecores) {
    int full_waves = output_tiles / avail_xecores;
    int full_wave_tiles = full_waves * avail_xecores;
    int partial_wave_tiles = output_tiles - full_wave_tiles;

    int score = -1;
    sk_groups = 0;

    if (partial_wave_tiles == 0) {
      // Perfect quantization: every wave is full, no stream-k split needed.
      dp_tiles = output_tiles;
      return;
    }

    if (full_waves < 1) {
      // Less than one full wave: try stream-k over the partial wave only.
      dp_tiles = full_wave_tiles;
      get_sk_workgroups(sk_groups, score, partial_wave_tiles, iters_per_tile,
          avail_xecores, /*allow_partial_wave=*/true);
      if (score < 0) {
        // Not profitable; fall back to pure data-parallel.
        sk_groups = 0;
        dp_tiles = output_tiles;
      }
      return;
    }

    // At least one full wave: evaluate stream-k over the trailing waves.
    // ... SK candidate evaluation elided in this listing ...
    std::cout << "SK Score: " << score << "\n\n";
    if (score < 0) {
      sk_groups = 0;
      dp_tiles = output_tiles;
    }
  }
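  // Worked example (editor's sketch): output_tiles = 70 and avail_xecores = 64
  // give full_waves = 1, full_wave_tiles = 64 and partial_wave_tiles = 6, so
  // the quantization is imperfect and get_sk_workgroups() is consulted; if the
  // returned savings score is negative, all 70 tiles stay data-parallel
  // (dp_tiles = output_tiles, sk_groups = 0).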
  // Host constructor: derives the stream-k schedule from the problem shape.
  dispatch_policy_stream_k(uint32_t matrix_m_, uint32_t matrix_k_,
      uint32_t matrix_n_, uint32_t wg_tile_m_, uint32_t wg_tile_k_,
      uint32_t wg_tile_n_, uint32_t sg_tile_m_, uint32_t sg_tile_n_,
      uint32_t avail_xecores_ = arch_attr_t<arch_tag>::max_wg_num) {
    // ... member initialization and derivation of num_tiles_m, num_tiles_n
    // and iters_per_tile from the matrix and tile shapes elided ...
    int sk_iters_per_normal_group = 0;
    int sk_iters_per_big_group = 0;

    int output_tiles = num_tiles_m * num_tiles_n;
    int sk_groups = 0;

    int dp_tiles = output_tiles;
    // Decide the data-parallel / stream-k split.
    get_groups(dp_tiles, sk_groups, output_tiles, iters_per_tile,
        avail_xecores_);
    sk_tiles = output_tiles - dp_tiles;

    int sk_iters = sk_tiles * iters_per_tile;
    sk_groups = std::min(sk_groups, sk_iters);

    if (sk_groups > 0) {
      sk_iters_per_normal_group = sk_iters / sk_groups;
      int extra_sk_iters
          = sk_iters - (sk_iters_per_normal_group * sk_groups);
      // Remainder iterations are absorbed by "big" groups that run one extra
      // iteration each.
      int sk_big_groups = extra_sk_iters;
      sk_iters_per_big_group = sk_iters_per_normal_group + 1;
      uint32_t current_sk_groups = sk_groups;
      // ... region / wave bookkeeping and FastDivMod setup elided ...
    }
    // Debug summary of the derived schedule (leading fields elided).
    uint32_t total_tiles = num_tiles_m * num_tiles_n;
    std::cout << ", tiled_shape: (" << num_tiles_m << "," << num_tiles_n << ")"
              << ", tiles: " << total_tiles
              << ", dp_tiles: " << total_tiles - sk_tiles
              << ", iters_per_tile: " << iters_per_tile
              << ", sk_iters_per_normal_group: " << sk_iters_per_normal_group
              << "\n";
    // ... remaining fields of the debug print elided ...
  }
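  // Worked example of the split above (editor's sketch): sk_tiles = 6 and
  // iters_per_tile = 32 give sk_iters = 192; with sk_groups = 18,
  //   sk_iters_per_normal_group = 192 / 18 = 10, extra_sk_iters = 12,
  // so 12 "big" groups run sk_iters_per_big_group = 11 iterations and the
  // remaining 6 groups run 10 (12 * 11 + 6 * 10 = 192).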
  // Small kernel/host helpers whose bodies are elided in this listing:
  //   get_num_active_groups()         - groups launched after the stream-k split
  //   get_iters_per_tile()            - K iterations per output tile
  //   get_sk_iters_per_normal_group() - K iterations of a normal SK group
  //   get_sk_regions(), get_sk_groups_per_region()
  //   get_sk_tile_idx(iter)           - output tile index owning an SK iteration

  // Kernel function: map a linear tile index to its (m, n) tile offsets; the
  // traversal order follows the aspect ratio of the tiled shape.
  __XETLA_API KERNEL_FUNC void get_tile_offsets(
      int tile_idx, int &tile_offset_m, int &tile_offset_n) const {
    // ... tiles_m / tiles_n recovered from the tiled shape (elided) ...
    if (tiles_m > tiles_n) {
      div_mod_tiles_n.fast_divmod(tile_offset_m, tile_offset_n, tile_idx);
    } else {
      div_mod_tiles_m.fast_divmod(tile_offset_n, tile_offset_m, tile_idx);
    }
  }
  // Kernel function: iteration extents [group_iter_begin, group_iter_end)
  // assigned to SK group sk_group_idx.
  __XETLA_API KERNEL_FUNC void get_iter_extents(int sk_group_idx,
      int &group_iter_begin, int &group_iter_end) const {
    int region_idx;
    int group_idx_in_region;
    div_mod_sk_groups_per_region.fast_divmod(
        region_idx, group_idx_in_region, sk_group_idx);

    group_iter_begin = (region_idx * sk_iters_per_region)
        + (group_idx_in_region * get_sk_iters_per_normal_group());

    // The first sk_big_groups_per_region groups of each region run one extra
    // iteration to absorb the remainder.
    int group_iters = get_sk_iters_per_normal_group();
    uint32_t current_group_idx_in_region = group_idx_in_region;
    if (current_group_idx_in_region < sk_big_groups_per_region) {
      group_iter_begin += group_idx_in_region;
      group_iters++;
    } else {
      group_iter_begin += sk_big_groups_per_region;
    }
    group_iter_end = group_iter_begin + group_iters;
  }
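  // Worked example (editor's sketch), continuing the numbers above with one
  // region, sk_iters_per_region = 192, sk_iters_per_normal_group = 10 and
  // sk_big_groups_per_region = 12:
  //   group 0  (big):    begin = 0*10 + 0   = 0,   end = 11
  //   group 11 (big):    begin = 11*10 + 11 = 121, end = 132
  //   group 12 (normal): begin = 12*10 + 12 = 132, end = 142
  //   group 17 (normal): begin = 17*10 + 12 = 182, end = 192
  // Together the 18 groups cover the 192 SK iterations without gaps.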
  // Kernel function: index of the first SK group that owns (writes) the given
  // sliced output tile.
  __XETLA_API KERNEL_FUNC int get_first_group_idx(
      int tile_idx, int group_idx) const {
    uint32_t current_tile_idx = tile_idx;
    if (current_tile_idx >= sk_tiles) {
      // Not a stream-k tile: the data-parallel group owns it outright.
      return group_idx;
    }
    int iter = tile_idx * get_iters_per_tile();

    int region_idx, iter_in_region;
    div_mod_sk_iters_per_region.fast_divmod(
        region_idx, iter_in_region, iter);

    // Iterations covered by the "big" (+1 iteration) groups of this region.
    int big_group_iters = sk_big_groups_per_region
        * get_sk_iters_per_normal_group() + sk_big_groups_per_region;
    int normal_group_iters = iter_in_region - big_group_iters;

    uint32_t big_group_idx_in_region
        = div_mod_sk_iters_per_big_group.div(iter_in_region);
    uint32_t normal_group_idx_in_region = sk_big_groups_per_region
        + div_mod_sk_iters_per_normal_group.div(normal_group_iters);

    int group_idx_in_region
        = (big_group_idx_in_region < sk_big_groups_per_region)
        ? big_group_idx_in_region
        : normal_group_idx_in_region;

    int owning_group_idx
        = (get_sk_groups_per_region() * region_idx) + group_idx_in_region;
    return owning_group_idx;
  }
};
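// Host-side usage sketch (editor's illustration; the matrix and tile sizes
// below are assumptions chosen for the example):
//
//   using policy = dispatch_policy_stream_k<gpu_arch::Xe>;
//   policy stream_k(/*matrix_m_=*/4096, /*matrix_k_=*/4096, /*matrix_n_=*/4096,
//       /*wg_tile_m_=*/256, /*wg_tile_k_=*/32, /*wg_tile_n_=*/256,
//       /*sg_tile_m_=*/32, /*sg_tile_n_=*/64);
//   cl::sycl::range<3> group_range = stream_k.get_group_range();
//   // group_range then feeds the nd_range used to launch the GEMM kernel.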