#include <batch_gemm.hpp>
|
| static __XETLA_API constexpr uint32_t | get_barrier_count () |
| | Gets named_barrier id consumption count.
|
| |
| static __XETLA_API constexpr uint32_t | get_slm_size () |
| | Gets local memory size consumption.
|
| |
| static cl::sycl::range< 3 > | get_local_range () |
| | Host helper function to get the expected local range under the current BATCH_GEMM config.
|
| |
| static cl::sycl::range< 3 > | get_group_range (uint32_t batch_size, uint32_t matrix_m, uint32_t matrix_n) |
| | Host helper function to get the expected group range under the current BATCH_GEMM config.
|
| |
| static cl::sycl::nd_range< 3 > | get_nd_range (arguments_t &args) |
| | Host helper function to get the expected nd_range under the current BATCH_GEMM config.
|
| |
| static bool | can_implement (arguments_t &args) |
| | Check if the arguments can be implemented.
|
| |
◆ can_implement()
template<typename gemm_t_ , typename epilogue_t_ ,
gpu_arch arch_tag_>
Check if the arguments can be implemented.
- Parameters
-
| args | Are the BATCH_GEMM arguments for application-related runtime variables. |
- Returns
- Check result: true if the arguments can be implemented, false otherwise.
◆ get_barrier_count()
template<typename gemm_t_ , typename epilogue_t_ ,
gpu_arch arch_tag_>
Gets named_barrier id consumption count.
Users can query the named_barrier id consumption count at compile time.
- Returns
- The count of named barriers required.
◆ get_group_range()
template<typename gemm_t_ , typename epilogue_t_ ,
gpu_arch arch_tag_>
| static cl::sycl::range< 3 > gpu::xetla::kernel::batch_gemm_t< gemm_t_, epilogue_t_, arch_tag_ >::get_group_range (uint32_t batch_size, uint32_t matrix_m, uint32_t matrix_n) |
inline static
Host helper function to get the expected group range under the current BATCH_GEMM config.
- Parameters
-
| batch_size | Is the batch size, i.e. the number of matrix multiplications in the batch. |
| matrix_m | Is the size of the m dimension of the matrix multiplication (m x k x n). |
| matrix_n | Is the size of the n dimension of the matrix multiplication (m x k x n). |
- Returns
- Expected group range.
◆ get_local_range()
template<typename gemm_t_ , typename epilogue_t_ ,
gpu_arch arch_tag_>
Host helper function to get the expected local range under the current BATCH_GEMM config.
- Returns
- Expected local range.
◆ get_nd_range()
template<typename gemm_t_ , typename epilogue_t_ ,
gpu_arch arch_tag_>
Host helper function to get the expected nd_range under the current BATCH_GEMM config.
- Parameters
-
| args | Are the BATCH_GEMM arguments for application-related runtime variables. |
- Returns
- Expected nd_range.
◆ get_slm_size()
template<typename gemm_t_ , typename epilogue_t_ ,
gpu_arch arch_tag_>
Gets local memory size consumption.
Users can query the local memory consumption size at compile time.
- Returns
- The size of local memory required.
◆ operator()()
template<typename gemm_t_ , typename epilogue_t_ ,
gpu_arch arch_tag_>
Main execution function for BATCH_GEMM.
The processing order is 1) set group-level base and boundary -> 2) gemm -> 3) epilogue.
- Parameters
-
| item | Is the sycl::nd_item, which provides execution-related information such as the workgroup id and subgroup id. |
| args | Are the BATCH_GEMM arguments for application-related runtime variables. |
| slm_base | Is the slm base address. |
| nbarrier_base | Is the named barrier base. |