StreamK GEMM implementation. More...
#include <dispatch_policy.hpp>

Public Member Functions | |
| cl::sycl::range< 3 > | get_group_range () const |
| Host helper function to get the expected nd_range under the current GEMM config. | |
| void | get_sk_workgroups (int &sk_groups, int &savings_iters, int sk_tiles, int iters_per_tile, int avail_xecores, bool allow_partial_wave) const |
| Host helper function to compute sk_groups to dispatch for a given number of sk_tiles. | |
| void | get_groups (int &dp_tiles, int &sk_groups, int output_tiles, int iters_per_tile, int avail_xecores) |
| Determine the populations of DP and SK groups to invoke for the given number of output tiles. | |
| dispatch_policy_stream_k ()=default | |
| Constructor. | |
| dispatch_policy_stream_k (uint32_t matrix_m_, uint32_t matrix_k_, uint32_t matrix_n_, uint32_t wg_tile_m_, uint32_t wg_tile_k_, uint32_t wg_tile_n_, uint32_t sg_tile_m_, uint32_t sg_tile_n_, uint32_t avail_xecores_=arch_attr_t< arch_tag >::max_wg_num) | |
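| Constructor from the GEMM problem sizes (matrix_m, matrix_k, matrix_n), workgroup/subgroup tile sizes, and number of available xecores. (NOTE(review): the original brief "Set for device copyable." appears to be misattached to this constructor — confirm.) | |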
| int | get_num_active_groups () const |
| Host helper function to return number of groups after stream_k split. | |
| __XETLA_API KERNEL_FUNC int | get_iters_per_tile () const |
| Kernel helper function to return number of K-iters per output tile. | |
| __XETLA_API KERNEL_FUNC int | get_sk_iters_per_normal_group () const |
| Kernel helper function to return number of K-iters for normal sk groups. | |
| __XETLA_API KERNEL_FUNC int | get_sk_regions () const |
| Kernel helper function to return number of SK regions. | |
| __XETLA_API KERNEL_FUNC int | get_sk_groups_per_region () const |
| Kernel helper function to return number of SK groups per region. | |
| __XETLA_API KERNEL_FUNC void | get_tile_offsets (int tile_idx, int &tile_offset_m, int &tile_offset_n) const |
| Kernel function to get tile offset for m and n. | |
| __XETLA_API KERNEL_FUNC int | get_sk_tile_idx (int iter) const |
| Kernel function to return tile idx for current sk iteration. | |
| __XETLA_API KERNEL_FUNC void | get_iter_extents (int sk_group_idx, int &group_iter_begin, int &group_iter_end) const |
| Kernel function to get the iteration extents for the stream_k split. | |
| __XETLA_API KERNEL_FUNC int | get_first_group_idx (int tile_idx, int group_idx) const |
| Kernel function to get the index of the first SK group that writes the sliced output tile. | |
Public Attributes | |
| uint32_t | matrix_m |
| uint32_t | matrix_k |
| uint32_t | matrix_n |
| uint32_t | wg_tile_m |
| uint32_t | wg_tile_k |
| uint32_t | wg_tile_n |
| uint32_t | sg_tile_m |
| uint32_t | sg_tile_n |
| uint32_t | avail_xecores |
| Number of xecores available for stream_k load balancing. | |
| uint32_t | num_workgroups |
| uint32_t | dp_groups |
| uint32_t | sk_tiles |
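| Number of data-parallel workgroups. (NOTE(review): this description appears to belong to `dp_groups` rather than `sk_tiles`; likely a misplaced Doxygen member comment — confirm.) | |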
| uint32_t | sk_waves |
| uint32_t | sk_big_groups_per_region |
| uint32_t | sk_iters_per_region |
| uint32_t | sk_regions |
| uint32_t | sk_groups_per_region |
| FastDivMod | div_mod_tiles_m |
| FastDivMod | div_mod_tiles_n |
| FastDivMod | div_mod_iters_per_tile |
| FastDivMod | div_mod_sk_regions |
| FastDivMod | div_mod_sk_groups_per_region |
| FastDivMod | div_mod_sk_iters_per_normal_group |
| FastDivMod | div_mod_sk_iters_per_region |
| FastDivMod | div_mod_sk_iters_per_big_group |
Static Public Attributes | |
| static constexpr gpu_arch | arch_tag = arch_tag_ |
| static int const | kMinItersPerSkGroup = 2 |
| Minimum number of MAC iterations per stream_k group. | |
StreamK GEMM implementation.
A special GEMM implementation that avoids tail effects when the GEMM shape does not fit the machine. It implements variable K-slicing for effective load balancing and performs inter-group reduction. The implementation is loosely based on this paper: https://arxiv.org/pdf/2301.03598.pdf
| arch_tag_ | The target hardware (HW) architecture. |
|
inline default |
Constructor.
|
inline |
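Constructor from the GEMM problem sizes, workgroup/subgroup tile sizes, and number of available xecores. (NOTE(review): the original brief "Set for device copyable." appears to be misattached to this constructor — confirm.)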
|
inline |
Kernel function to get the index of the first SK group that writes the sliced output tile.
|
inline |
Host helper function to get the expected nd_range under the current GEMM config.
|
inline |
Determine the populations of DP and SK groups to invoke for the given number of output tiles.
|
inline |
Kernel function to get the iteration extents for the stream_k split.
|
inline |
Kernel helper function to return number of K-iters per output tile.
|
inline |
Host helper function to return number of groups after stream_k split.
|
inline |
Kernel helper function to return number of SK groups per region.
|
inline |
Kernel helper function to return number of K-iters for normal sk groups.
|
inline |
Kernel helper function to return number of SK regions.
|
inline |
Kernel function to return tile idx for current sk iteration.
|
inline |
Host helper function to compute sk_groups to dispatch for a given number of sk_tiles.
| [out] | savings_iters | |
| [in] | sk_tiles |
|
inline |
Kernel function to get tile offset for m and n.
|
static constexpr |
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::avail_xecores |
Number of xecores available for stream_k load balancing.
| FastDivMod gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::div_mod_iters_per_tile |
| FastDivMod gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::div_mod_sk_groups_per_region |
| FastDivMod gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::div_mod_sk_iters_per_big_group |
| FastDivMod gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::div_mod_sk_iters_per_normal_group |
| FastDivMod gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::div_mod_sk_iters_per_region |
| FastDivMod gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::div_mod_sk_regions |
| FastDivMod gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::div_mod_tiles_m |
| FastDivMod gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::div_mod_tiles_n |
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::dp_groups |
|
static |
Minimum number of MAC iterations per stream_k group.
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::matrix_k |
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::matrix_m |
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::matrix_n |
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::num_workgroups |
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::sg_tile_m |
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::sg_tile_n |
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::sk_big_groups_per_region |
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::sk_groups_per_region |
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::sk_iters_per_region |
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::sk_regions |
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::sk_tiles |
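Number of data-parallel workgroups. (NOTE(review): this description appears to belong to `dp_groups` rather than `sk_tiles`; likely a misplaced Doxygen member comment — confirm.)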
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::sk_waves |
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::wg_tile_k |
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::wg_tile_m |
| uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::wg_tile_n |