XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
scaled_dot_product_attention.cpp File Reference
#include "softmax.hpp"
#include "tests/utils/utils.hpp"
Include dependency graph for scaled_dot_product_attention.cpp:

Macros

#define SIMD   32
 
#define batch_num   16
 
#define head_num   16
 
#define sequence_len   512
 
#define head_size   64
 

Functions

template<typename dtype_in , typename dtype_out , typename data_type_acc = float>
int sdp_fwd_result_validate (dtype_in *q_device, dtype_in *k_device, dtype_in *v_device, dtype_in *mask_device, dtype_out *c_device, uint32_t qk_m, uint32_t qk_k, uint32_t qk_n, uint32_t sv_m, uint32_t sv_k, uint32_t sv_n, uint32_t batch_cnt, sycl::queue &queue, mem_layout mem_layout_qk_a_=mem_layout::row_major, mem_layout mem_layout_qk_b_=mem_layout::row_major, mem_layout mem_layout_sv_a_=mem_layout::row_major, mem_layout mem_layout_sv_b_=mem_layout::row_major)
 
void sdp_fwd_run (uint32_t iter)
 
int main ()
 

Macro Definition Documentation

◆ batch_num

#define batch_num   16

◆ head_num

#define head_num   16

◆ head_size

#define head_size   64

◆ sequence_len

#define sequence_len   512

◆ SIMD

#define SIMD   32

Function Documentation

◆ main()

int main ( )

◆ sdp_fwd_result_validate()

template<typename dtype_in , typename dtype_out , typename data_type_acc = float>
int sdp_fwd_result_validate(dtype_in *q_device,
                            dtype_in *k_device,
                            dtype_in *v_device,
                            dtype_in *mask_device,
                            dtype_out *c_device,
                            uint32_t qk_m,
                            uint32_t qk_k,
                            uint32_t qk_n,
                            uint32_t sv_m,
                            uint32_t sv_k,
                            uint32_t sv_n,
                            uint32_t batch_cnt,
                            sycl::queue &queue,
                            mem_layout mem_layout_qk_a_ = mem_layout::row_major,
                            mem_layout mem_layout_qk_b_ = mem_layout::row_major,
                            mem_layout mem_layout_sv_a_ = mem_layout::row_major,
                            mem_layout mem_layout_sv_b_ = mem_layout::row_major)

◆ sdp_fwd_run()

void sdp_fwd_run(uint32_t iter)