XeTLA v0.3.6
Intel® Xe Templates for Linear Algebra - API Definition Document
 
Loading...
Searching...
No Matches
gpu::xetla::dropout_fwd_t< SZ, dtype_mask, random_simd > Struct Template Reference

#include <rand.hpp>

Collaboration diagram for gpu::xetla::dropout_fwd_t< SZ, dtype_mask, random_simd >:

Public Member Functions

__XETLA_API void init (uint64_t seed, uint64_t subseq, uint64_t offset, uint32_t threshold_, float scale_)
 
template<typename dtype >
__XETLA_API xetla_vector< dtype, SZ > process (xetla_vector< dtype, SZ > input)
 
__XETLA_API xetla_vector< dtype_mask, SZ > get_mask ()
 

Public Attributes

xetla_rand_t< random_simd > rand_gen
 
xetla_vector< dtype_mask, SZ > mask
 
uint32_t threshold
 
float scale
 

Static Public Attributes

static constexpr uint32_t random_len = 4 * random_simd
 

Member Function Documentation

◆ get_mask()

template<uint32_t SZ, typename dtype_mask = uint8_t, uint32_t random_simd = 16>
__XETLA_API xetla_vector< dtype_mask, SZ > gpu::xetla::dropout_fwd_t< SZ, dtype_mask, random_simd >::get_mask ( )
inline

◆ init()

template<uint32_t SZ, typename dtype_mask = uint8_t, uint32_t random_simd = 16>
__XETLA_API void gpu::xetla::dropout_fwd_t< SZ, dtype_mask, random_simd >::init ( uint64_t  seed,
uint64_t  subseq,
uint64_t  offset,
uint32_t  threshold_,
float  scale_ 
)
inline

◆ process()

template<uint32_t SZ, typename dtype_mask = uint8_t, uint32_t random_simd = 16>
template<typename dtype >
__XETLA_API xetla_vector< dtype, SZ > gpu::xetla::dropout_fwd_t< SZ, dtype_mask, random_simd >::process ( xetla_vector< dtype, SZ >  input)
inline

Member Data Documentation

◆ mask

template<uint32_t SZ, typename dtype_mask = uint8_t, uint32_t random_simd = 16>
xetla_vector<dtype_mask, SZ> gpu::xetla::dropout_fwd_t< SZ, dtype_mask, random_simd >::mask

◆ rand_gen

template<uint32_t SZ, typename dtype_mask = uint8_t, uint32_t random_simd = 16>
xetla_rand_t<random_simd> gpu::xetla::dropout_fwd_t< SZ, dtype_mask, random_simd >::rand_gen

◆ random_len

template<uint32_t SZ, typename dtype_mask = uint8_t, uint32_t random_simd = 16>
constexpr uint32_t gpu::xetla::dropout_fwd_t< SZ, dtype_mask, random_simd >::random_len = 4 * random_simd
static constexpr

◆ scale

template<uint32_t SZ, typename dtype_mask = uint8_t, uint32_t random_simd = 16>
float gpu::xetla::dropout_fwd_t< SZ, dtype_mask, random_simd >::scale

◆ threshold

template<uint32_t SZ, typename dtype_mask = uint8_t, uint32_t random_simd = 16>
uint32_t gpu::xetla::dropout_fwd_t< SZ, dtype_mask, random_simd >::threshold