XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
gpu::xetla::kernel::batch_gemm_t< gemm_t_, epilogue_t_, arch_tag_ > Class Template Reference

#include <batch_gemm.hpp>

Classes

struct  arguments_t
 BATCH_GEMM arguments.
 

Public Member Functions

__XETLA_API KERNEL_FUNC void operator() (sycl::nd_item< 3 > &item, const arguments_t &args, uint32_t slm_base=0, uint32_t nbarrier_base=0)
 Main execution function for BATCH_GEMM.
 

Static Public Member Functions

static __XETLA_API constexpr uint32_t get_barrier_count ()
 Gets the number of named barrier IDs consumed.
 
static __XETLA_API constexpr uint32_t get_slm_size ()
 Gets the shared local memory (SLM) size consumption.
 
static cl::sycl::range< 3 > get_local_range ()
 Host helper function to get the expected local range under the current BATCH_GEMM config.
 
static cl::sycl::range< 3 > get_group_range (uint32_t batch_size, uint32_t matrix_m, uint32_t matrix_n)
 Host helper function to get the expected group range under the current BATCH_GEMM config.
 
static cl::sycl::nd_range< 3 > get_nd_range (arguments_t &args)
 Host helper function to get the expected nd_range under the current BATCH_GEMM config.
 
static bool can_implement (arguments_t &args)
 Checks whether the arguments can be implemented.
 

Member Function Documentation

◆ can_implement()

template<typename gemm_t_ , typename epilogue_t_ , gpu_arch arch_tag_>
static bool gpu::xetla::kernel::batch_gemm_t< gemm_t_, epilogue_t_, arch_tag_ >::can_implement ( arguments_t & args)
inline static

Checks whether the arguments can be implemented.

Parameters
args: The BATCH_GEMM arguments, which carry the application-related runtime variables.
Returns
true if the arguments can be implemented; false otherwise.
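The actual checks performed by can_implement() are defined in batch_gemm.hpp and are not listed on this page. As an illustration only, the sketch below shows the kind of host-side validation such a function might do before launch. The struct example_batch_args, its fields, the row-major layout, and the elem_bytes / pitch_align_bytes parameters are all hypothetical and are not part of the XeTLA API.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for BATCH_GEMM arguments; the real arguments_t is
// declared in batch_gemm.hpp and its fields may differ.
struct example_batch_args {
    uint32_t batch_size;
    uint32_t matrix_m;
    uint32_t matrix_k;
    uint32_t matrix_n;
    uint32_t lda; // leading dimension of A in elements (row-major assumed)
    uint32_t ldb; // leading dimension of B in elements
    uint32_t ldc; // leading dimension of C in elements
};

// Hypothetical checks only: illustrates runtime validation in the spirit of
// can_implement(). elem_bytes and pitch_align_bytes are assumptions.
bool example_can_implement(const example_batch_args &a, uint32_t elem_bytes,
                           uint32_t pitch_align_bytes) {
    assert(elem_bytes > 0 && pitch_align_bytes > 0);
    // Empty problems are rejected.
    if (a.batch_size == 0 || a.matrix_m == 0 || a.matrix_k == 0 || a.matrix_n == 0)
        return false;
    // Row-major: each leading dimension must cover the row width.
    if (a.lda < a.matrix_k || a.ldb < a.matrix_n || a.ldc < a.matrix_n)
        return false;
    // Every row pitch, in bytes, must satisfy the assumed alignment.
    auto aligned = [&](uint32_t ld) {
        return (static_cast<uint64_t>(ld) * elem_bytes) % pitch_align_bytes == 0;
    };
    return aligned(a.lda) && aligned(a.ldb) && aligned(a.ldc);
}
```

A host would typically call a check like this before building the nd_range and submitting the kernel, and fall back to another path when it returns false.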

◆ get_barrier_count()

template<typename gemm_t_ , typename epilogue_t_ , gpu_arch arch_tag_>
static __XETLA_API constexpr uint32_t gpu::xetla::kernel::batch_gemm_t< gemm_t_, epilogue_t_, arch_tag_ >::get_barrier_count ( )
inline static constexpr

Gets the number of named barrier IDs consumed.

Users can query the named barrier ID consumption count at compile time.

Returns
The count of named barriers required.
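Because BATCH_GEMM runs a gemm stage followed by an epilogue stage, a natural reading is that its barrier count covers both. The sketch below assumes the count is the sum of the two components' counts; example_gemm, example_epilogue, and the barrier_count member name are hypothetical stand-ins, not the XeTLA types.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical component types; the member name barrier_count is an assumption.
struct example_gemm { static constexpr uint32_t barrier_count = 2; };
struct example_epilogue { static constexpr uint32_t barrier_count = 1; };

// Sketch: a kernel that runs gemm_t and then epilogue_t is assumed to need
// the named barriers of both, so its count is their sum.
template <typename gemm_t, typename epilogue_t>
struct example_batch_kernel {
    static constexpr uint32_t get_barrier_count() {
        return gemm_t::barrier_count + epilogue_t::barrier_count;
    }
};

using example_cfg = example_batch_kernel<example_gemm, example_epilogue>;

// The value is a constant expression, so it can be checked at compile time.
static_assert(example_cfg::get_barrier_count() == 3, "2 + 1 named barriers");
```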

◆ get_group_range()

template<typename gemm_t_ , typename epilogue_t_ , gpu_arch arch_tag_>
static cl::sycl::range< 3 > gpu::xetla::kernel::batch_gemm_t< gemm_t_, epilogue_t_, arch_tag_ >::get_group_range ( uint32_t  batch_size,
uint32_t  matrix_m,
uint32_t  matrix_n 
)
inline static

Host helper function to get the expected group range under the current BATCH_GEMM config.

Parameters
batch_size: The number of independent matrix multiplications in the batch.
matrix_m: The size of the m dimension of the matrix multiplication (m x k x n).
matrix_n: The size of the n dimension of the matrix multiplication (m x k x n).
Returns
Expected group range.
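A group range for a tiled GEMM is usually the number of work-group tiles needed to cover the output, rounded up. The sketch below assumes dimension 0 indexes the batch and dimensions 1 and 2 tile M and N; wg_tile_m and wg_tile_n are hypothetical work-group tile sizes that, in XeTLA, would come from the compile-time gemm configuration. std::array stands in for cl::sycl::range<3> so the sketch builds without SYCL.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Ceiling division: number of tiles of size `tile` needed to cover `size`.
constexpr uint32_t example_ceil_div(uint32_t size, uint32_t tile) {
    return (size + tile - 1) / tile;
}

// Sketch of a BATCH_GEMM-style group range (assumed layout: batch, M tiles,
// N tiles). wg_tile_m / wg_tile_n are hypothetical parameters.
std::array<uint32_t, 3> example_group_range(uint32_t batch_size, uint32_t matrix_m,
                                            uint32_t matrix_n, uint32_t wg_tile_m,
                                            uint32_t wg_tile_n) {
    assert(wg_tile_m > 0 && wg_tile_n > 0);
    return {batch_size, example_ceil_div(matrix_m, wg_tile_m),
            example_ceil_div(matrix_n, wg_tile_n)};
}
```

Rounding up means edge tiles may extend past the matrix; the kernel then relies on its per-group boundary to avoid out-of-range accesses.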

◆ get_local_range()

template<typename gemm_t_ , typename epilogue_t_ , gpu_arch arch_tag_>
static cl::sycl::range< 3 > gpu::xetla::kernel::batch_gemm_t< gemm_t_, epilogue_t_, arch_tag_ >::get_local_range ( )
inline static

Host helper function to get the expected local range under the current BATCH_GEMM config.

Returns
Expected local range.
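A local range is typically derived from how a work-group tile is split into sub-group tiles. The sketch assumes each work-item handles one sub-group tile, so the local range counts sub-group tiles per work-group in the layout {1, M, N}. The tile-size parameters are hypothetical stand-ins for values taken from the gemm configuration.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Sketch: local range = sub-group tiles per work-group tile, assuming one
// work-item per sub-group tile and tile sizes that divide evenly.
std::array<uint32_t, 3> example_local_range(uint32_t wg_tile_m, uint32_t wg_tile_n,
                                            uint32_t sg_tile_m, uint32_t sg_tile_n) {
    assert(sg_tile_m > 0 && sg_tile_n > 0);
    assert(wg_tile_m % sg_tile_m == 0 && wg_tile_n % sg_tile_n == 0);
    return {1, wg_tile_m / sg_tile_m, wg_tile_n / sg_tile_n};
}
```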

◆ get_nd_range()

template<typename gemm_t_ , typename epilogue_t_ , gpu_arch arch_tag_>
static cl::sycl::nd_range< 3 > gpu::xetla::kernel::batch_gemm_t< gemm_t_, epilogue_t_, arch_tag_ >::get_nd_range ( arguments_t & args)
inline static

Host helper function to get the expected nd_range under the current BATCH_GEMM config.

Parameters
args: The BATCH_GEMM arguments, which carry the application-related runtime variables.
Returns
Expected nd_range.
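In SYCL, an nd_range is built from a global size and a local size, and the global size must be a multiple of the local size in each dimension. Multiplying the group range by the local range element-wise satisfies this by construction; with SYCL types it corresponds to sycl::nd_range<3>{group_range * local_range, local_range}. The sketch below assumes get_nd_range() combines the two helpers this way, and uses the hypothetical example_nd_range struct in place of sycl::nd_range<3> so it builds without SYCL.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

using range3 = std::array<uint32_t, 3>;

// Minimal stand-in for sycl::nd_range<3>: global and local sizes per dimension.
struct example_nd_range {
    range3 global;
    range3 local;
};

// global[d] = group[d] * local[d], so global is always a multiple of local.
example_nd_range example_make_nd_range(const range3 &group, const range3 &local) {
    range3 global{};
    for (std::size_t d = 0; d < 3; ++d) global[d] = group[d] * local[d];
    return {global, local};
}
```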

◆ get_slm_size()

template<typename gemm_t_ , typename epilogue_t_ , gpu_arch arch_tag_>
static __XETLA_API constexpr uint32_t gpu::xetla::kernel::batch_gemm_t< gemm_t_, epilogue_t_, arch_tag_ >::get_slm_size ( )
inline static constexpr

Gets the shared local memory (SLM) size consumption.

Users can query the SLM consumption size at compile time.

Returns
The size of local memory required.
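Since the value is a compile-time constant, a configuration can be rejected before it ever reaches the device. The sketch assumes the total is the gemm stage's SLM plus the epilogue's SLM, in bytes. The component types, the slm_size member name, and the 64 KiB budget are hypothetical; a real limit depends on the device and should be queried rather than hard-coded.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical component types; the member name slm_size is an assumption.
struct example_gemm_slm { static constexpr uint32_t slm_size = 32 * 1024; };
struct example_epilogue_slm { static constexpr uint32_t slm_size = 4 * 1024; };

template <typename gemm_t, typename epilogue_t>
struct example_slm_kernel {
    // Assumed: total SLM is the gemm's share plus the epilogue's share.
    static constexpr uint32_t get_slm_size() {
        return gemm_t::slm_size + epilogue_t::slm_size;
    }
};

// Hypothetical budget used only to illustrate a compile-time guard.
constexpr uint32_t example_slm_budget = 64 * 1024;
using example_slm_cfg = example_slm_kernel<example_gemm_slm, example_epilogue_slm>;
static_assert(example_slm_cfg::get_slm_size() <= example_slm_budget,
              "configuration exceeds the assumed SLM budget");
```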

◆ operator()()

template<typename gemm_t_ , typename epilogue_t_ , gpu_arch arch_tag_>
__XETLA_API KERNEL_FUNC void gpu::xetla::kernel::batch_gemm_t< gemm_t_, epilogue_t_, arch_tag_ >::operator() ( sycl::nd_item< 3 > &  item,
const arguments_t & args,
uint32_t  slm_base = 0,
uint32_t  nbarrier_base = 0 
)
inline

Main execution function for BATCH_GEMM.

The processing order is: 1) set the group-level base and boundary, 2) run the GEMM, 3) run the epilogue.
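Step 1 can be sketched as computing, for each work-group, which batch it serves, where its output tile starts, and where that tile must stop so edge tiles do not run past the matrix. The sketch assumes group IDs are ordered (batch, M tile, N tile), that the C matrices of a batch are stored contiguously with a stride of matrix_m * ldc elements, and that wg_tile_m / wg_tile_n are hypothetical tile sizes; none of these details are stated on this page.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical per-work-group coordinates for one batched GEMM tile.
struct example_group_tile {
    uint32_t batch_id;
    uint32_t start_m, start_n;       // top-left corner of the tile in C
    uint32_t boundary_m, boundary_n; // exclusive end, clamped to the matrix
    uint64_t batch_offset_c;         // element offset of this batch's C matrix
};

// Sketch of "set group-level base and boundary" under the assumptions above.
example_group_tile example_set_group_tile(uint32_t group_b, uint32_t group_m,
                                          uint32_t group_n, uint32_t matrix_m,
                                          uint32_t matrix_n, uint32_t ldc,
                                          uint32_t wg_tile_m, uint32_t wg_tile_n) {
    assert(wg_tile_m > 0 && wg_tile_n > 0);
    example_group_tile t{};
    t.batch_id = group_b;
    t.start_m = group_m * wg_tile_m;
    t.start_n = group_n * wg_tile_n;
    // Clamp so the last tile in each dimension stops at the matrix edge.
    t.boundary_m = std::min(t.start_m + wg_tile_m, matrix_m);
    t.boundary_n = std::min(t.start_n + wg_tile_n, matrix_n);
    t.batch_offset_c = static_cast<uint64_t>(group_b) * matrix_m * ldc;
    return t;
}
```

The gemm stage would then read A and B and the epilogue would write C only inside [start, boundary) for this group's batch.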

Parameters
item: The sycl::nd_item, which provides execution-related information such as the work-group ID and sub-group ID.
args: The BATCH_GEMM arguments, which carry the application-related runtime variables.
slm_base: The shared local memory (SLM) base address.
nbarrier_base: The base ID of the named barriers.
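The slm_base and nbarrier_base parameters presumably let a caller place this kernel's resources after SLM or named barriers already used by surrounding code. A plausible split, assumed here rather than taken from the source, gives the gemm stage the region starting at each base and places the epilogue immediately after it. The parameters gemm_slm_size and gemm_barrier_count are hypothetical stand-ins for the gemm type's compile-time consumption.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical resource bases for the two stages run by operator().
struct example_stage_bases {
    uint32_t gemm_slm_base, gemm_nbarrier_base;
    uint32_t epilogue_slm_base, epilogue_nbarrier_base;
};

// Sketch: gemm starts at the caller's bases; the epilogue starts right after
// the gemm's SLM bytes and named barriers.
constexpr example_stage_bases example_split_bases(uint32_t slm_base,
                                                  uint32_t nbarrier_base,
                                                  uint32_t gemm_slm_size,
                                                  uint32_t gemm_barrier_count) {
    return {slm_base, nbarrier_base, slm_base + gemm_slm_size,
            nbarrier_base + gemm_barrier_count};
}
```

With the default bases of 0, the epilogue's region begins exactly where the gemm's ends; a nonzero base shifts both stages by the same amount.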