XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ > Struct Template Reference

StreamK GEMM implementation. More...

#include <dispatch_policy.hpp>

Collaboration diagram for gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >:

Public Member Functions

cl::sycl::range< 3 > get_group_range () const
 Host helper function to get the expected group range (a cl::sycl::range<3>) of the nd_range under the current GEMM config.
 
void get_sk_workgroups (int &sk_groups, int &savings_iters, int sk_tiles, int iters_per_tile, int avail_xecores, bool allow_partial_wave) const
 Host helper function to compute sk_groups to dispatch for a given number of sk_tiles.
 
void get_groups (int &dp_tiles, int &sk_groups, int output_tiles, int iters_per_tile, int avail_xecores)
 Determine the populations of DP and SK groups to invoke for the given number of output tiles.
 
 dispatch_policy_stream_k ()=default
 Constructor.
 
 dispatch_policy_stream_k (uint32_t matrix_m_, uint32_t matrix_k_, uint32_t matrix_n_, uint32_t wg_tile_m_, uint32_t wg_tile_k_, uint32_t wg_tile_n_, uint32_t sg_tile_m_, uint32_t sg_tile_n_, uint32_t avail_xecores_=arch_attr_t< arch_tag >::max_wg_num)
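 Constructs the policy for the given GEMM problem size (matrix_m_, matrix_k_, matrix_n_), work-group and sub-group tile sizes, and number of available Xe-cores; the resulting object is device copyable.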
 
int get_num_active_groups () const
 Host helper function to return number of groups after stream_k split.
 
__XETLA_API KERNEL_FUNC int get_iters_per_tile () const
 Kernel helper function to return number of K-iters per output tile.
 
__XETLA_API KERNEL_FUNC int get_sk_iters_per_normal_group () const
 Kernel helper function to return number of K-iters for normal sk groups.
 
__XETLA_API KERNEL_FUNC int get_sk_regions () const
 Kernel helper function to return number of SK regions.
 
__XETLA_API KERNEL_FUNC int get_sk_groups_per_region () const
 Kernel helper function to return number of SK groups per region.
 
__XETLA_API KERNEL_FUNC void get_tile_offsets (int tile_idx, int &tile_offset_m, int &tile_offset_n) const
 Kernel function to get tile offset for m and n.
 
__XETLA_API KERNEL_FUNC int get_sk_tile_idx (int iter) const
 Kernel function to return tile idx for current sk iteration.
 
__XETLA_API KERNEL_FUNC void get_iter_extents (int sk_group_idx, int &group_iter_begin, int &group_iter_end) const
 Kernel function to get the iteration extents for the stream_k split.
 
__XETLA_API KERNEL_FUNC int get_first_group_idx (int tile_idx, int group_idx) const
 Kernel function to get the index of the first SK group writing the sliced output tile.
 

Public Attributes

uint32_t matrix_m
 
uint32_t matrix_k
 
uint32_t matrix_n
 
uint32_t wg_tile_m
 
uint32_t wg_tile_k
 
uint32_t wg_tile_n
 
uint32_t sg_tile_m
 
uint32_t sg_tile_n
 
uint32_t avail_xecores
 Number of xecores available for stream_k load balancing.
 
uint32_t num_workgroups
 
uint32_t dp_groups
 Number of data-parallel workgroups.
uint32_t sk_tiles
 Number of output tiles processed with stream-K (SK) splitting.
 
uint32_t sk_waves
 
uint32_t sk_big_groups_per_region
 
uint32_t sk_iters_per_region
 
uint32_t sk_regions
 
uint32_t sk_groups_per_region
 
FastDivMod div_mod_tiles_m
 
FastDivMod div_mod_tiles_n
 
FastDivMod div_mod_iters_per_tile
 
FastDivMod div_mod_sk_regions
 
FastDivMod div_mod_sk_groups_per_region
 
FastDivMod div_mod_sk_iters_per_normal_group
 
FastDivMod div_mod_sk_iters_per_region
 
FastDivMod div_mod_sk_iters_per_big_group
 

Static Public Attributes

static constexpr gpu_arch arch_tag = arch_tag_
 
static int const kMinItersPerSkGroup = 2
 Minimum number of MAC-iterations per streamk group.
 

Detailed Description

template<gpu_arch arch_tag_ = gpu_arch::Xe>
struct gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >

StreamK GEMM implementation.

A special GEMM implementation to avoid tail effects when the GEMM shape does not fit the machine. It implements variable K-slicing for effective load balancing and performs inter-group reduction. The implementation is loosely based on this paper - https://arxiv.org/pdf/2301.03598.pdf

Template Parameters
arch_tag_ Is the HW architecture.

Constructor & Destructor Documentation

◆ dispatch_policy_stream_k() [1/2]

template<gpu_arch arch_tag_ = gpu_arch::Xe>
gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::dispatch_policy_stream_k ( )
inline default

Constructor.

◆ dispatch_policy_stream_k() [2/2]

template<gpu_arch arch_tag_ = gpu_arch::Xe>
gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::dispatch_policy_stream_k ( uint32_t  matrix_m_,
uint32_t  matrix_k_,
uint32_t  matrix_n_,
uint32_t  wg_tile_m_,
uint32_t  wg_tile_k_,
uint32_t  wg_tile_n_,
uint32_t  sg_tile_m_,
uint32_t  sg_tile_n_,
uint32_t  avail_xecores_ = arch_attr_t<arch_tag>::max_wg_num 
)
inline

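Constructs the policy for the given GEMM problem size (matrix_m_, matrix_k_, matrix_n_), work-group and sub-group tile sizes, and number of available Xe-cores; the resulting object is device copyable.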

Member Function Documentation

◆ get_first_group_idx()

template<gpu_arch arch_tag_ = gpu_arch::Xe>
__XETLA_API KERNEL_FUNC int gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::get_first_group_idx ( int  tile_idx,
int  group_idx 
) const
inline

Kernel function to get the index of the first SK group writing the sliced output tile.

◆ get_group_range()

template<gpu_arch arch_tag_ = gpu_arch::Xe>
cl::sycl::range< 3 > gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::get_group_range ( ) const
inline

Host helper function to get the expected group range (a cl::sycl::range<3>) of the nd_range under the current GEMM config.

Returns
Expected group range for the nd_range.

◆ get_groups()

template<gpu_arch arch_tag_ = gpu_arch::Xe>
void gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::get_groups ( int &  dp_tiles,
int &  sk_groups,
int  output_tiles,
int  iters_per_tile,
int  avail_xecores 
)
inline

Determine the populations of DP and SK groups to invoke for the given number of output tiles.

◆ get_iter_extents()

template<gpu_arch arch_tag_ = gpu_arch::Xe>
__XETLA_API KERNEL_FUNC void gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::get_iter_extents ( int  sk_group_idx,
int &  group_iter_begin,
int &  group_iter_end 
) const
inline

Kernel function to get the iteration extents for the stream_k split.

◆ get_iters_per_tile()

template<gpu_arch arch_tag_ = gpu_arch::Xe>
__XETLA_API KERNEL_FUNC int gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::get_iters_per_tile ( ) const
inline

Kernel helper function to return number of K-iters per output tile.

◆ get_num_active_groups()

template<gpu_arch arch_tag_ = gpu_arch::Xe>
int gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::get_num_active_groups ( ) const
inline

Host helper function to return number of groups after stream_k split.

◆ get_sk_groups_per_region()

template<gpu_arch arch_tag_ = gpu_arch::Xe>
__XETLA_API KERNEL_FUNC int gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::get_sk_groups_per_region ( ) const
inline

Kernel helper function to return number of SK groups per region.

◆ get_sk_iters_per_normal_group()

template<gpu_arch arch_tag_ = gpu_arch::Xe>
__XETLA_API KERNEL_FUNC int gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::get_sk_iters_per_normal_group ( ) const
inline

Kernel helper function to return number of K-iters for normal sk groups.

◆ get_sk_regions()

template<gpu_arch arch_tag_ = gpu_arch::Xe>
__XETLA_API KERNEL_FUNC int gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::get_sk_regions ( ) const
inline

Kernel helper function to return number of SK regions.

◆ get_sk_tile_idx()

template<gpu_arch arch_tag_ = gpu_arch::Xe>
__XETLA_API KERNEL_FUNC int gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::get_sk_tile_idx ( int  iter) const
inline

Kernel function to return tile idx for current sk iteration.

◆ get_sk_workgroups()

template<gpu_arch arch_tag_ = gpu_arch::Xe>
void gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::get_sk_workgroups ( int &  sk_groups,
int &  savings_iters,
int  sk_tiles,
int  iters_per_tile,
int  avail_xecores,
bool  allow_partial_wave 
) const
inline

Host helper function to compute sk_groups to dispatch for a given number of sk_tiles.

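Parameters
[out] sk_groups Number of SK groups to dispatch.
[out] savings_iters
[in] sk_tiles Number of SK tiles to distribute across the SK groups.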

◆ get_tile_offsets()

template<gpu_arch arch_tag_ = gpu_arch::Xe>
__XETLA_API KERNEL_FUNC void gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::get_tile_offsets ( int  tile_idx,
int &  tile_offset_m,
int &  tile_offset_n 
) const
inline

Kernel function to get tile offset for m and n.

Member Data Documentation

◆ arch_tag

template<gpu_arch arch_tag_ = gpu_arch::Xe>
constexpr gpu_arch gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::arch_tag = arch_tag_
static constexpr

◆ avail_xecores

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::avail_xecores

Number of xecores available for stream_k load balancing.

◆ div_mod_iters_per_tile

template<gpu_arch arch_tag_ = gpu_arch::Xe>
FastDivMod gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::div_mod_iters_per_tile

◆ div_mod_sk_groups_per_region

template<gpu_arch arch_tag_ = gpu_arch::Xe>
FastDivMod gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::div_mod_sk_groups_per_region

◆ div_mod_sk_iters_per_big_group

template<gpu_arch arch_tag_ = gpu_arch::Xe>
FastDivMod gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::div_mod_sk_iters_per_big_group

◆ div_mod_sk_iters_per_normal_group

template<gpu_arch arch_tag_ = gpu_arch::Xe>
FastDivMod gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::div_mod_sk_iters_per_normal_group

◆ div_mod_sk_iters_per_region

template<gpu_arch arch_tag_ = gpu_arch::Xe>
FastDivMod gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::div_mod_sk_iters_per_region

◆ div_mod_sk_regions

template<gpu_arch arch_tag_ = gpu_arch::Xe>
FastDivMod gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::div_mod_sk_regions

◆ div_mod_tiles_m

template<gpu_arch arch_tag_ = gpu_arch::Xe>
FastDivMod gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::div_mod_tiles_m

◆ div_mod_tiles_n

template<gpu_arch arch_tag_ = gpu_arch::Xe>
FastDivMod gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::div_mod_tiles_n

◆ dp_groups

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::dp_groups

Number of data-parallel workgroups.

◆ kMinItersPerSkGroup

template<gpu_arch arch_tag_ = gpu_arch::Xe>
int const gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::kMinItersPerSkGroup = 2
static

Minimum number of MAC-iterations per streamk group.

◆ matrix_k

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::matrix_k

◆ matrix_m

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::matrix_m

◆ matrix_n

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::matrix_n

◆ num_workgroups

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::num_workgroups

◆ sg_tile_m

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::sg_tile_m

◆ sg_tile_n

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::sg_tile_n

◆ sk_big_groups_per_region

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::sk_big_groups_per_region

◆ sk_groups_per_region

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::sk_groups_per_region

◆ sk_iters_per_region

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::sk_iters_per_region

◆ sk_regions

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::sk_regions

◆ sk_tiles

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::sk_tiles

Number of output tiles processed with stream-K (SK) splitting.

◆ sk_waves

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::sk_waves

◆ wg_tile_k

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::wg_tile_k

◆ wg_tile_m

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::wg_tile_m

◆ wg_tile_n

template<gpu_arch arch_tag_ = gpu_arch::Xe>
uint32_t gpu::xetla::kernel::dispatch_policy_stream_k< arch_tag_ >::wg_tile_n