#include <rand.hpp>

Public Member Functions | |
| __XETLA_API void | init (uint64_t seed, uint64_t subseq, uint64_t offset, uint32_t threshold_, float scale_) |
| template<typename dtype > | |
| __XETLA_API xetla_vector< dtype, SZ > | process (xetla_vector< dtype, SZ > input) |
| __XETLA_API xetla_vector< dtype_mask, SZ > | get_mask () |
Public Attributes | |
| xetla_rand_t< random_simd > | rand_gen |
| xetla_vector< dtype_mask, SZ > | mask |
| uint32_t | threshold |
| float | scale |
Static Public Attributes | |
| static constexpr uint32_t | random_len = 4 * random_simd |
|
inline |
|
inline |
|
inline |
| xetla_vector<dtype_mask, SZ> gpu::xetla::dropout_fwd_t< SZ, dtype_mask, random_simd >::mask |
| xetla_rand_t<random_simd> gpu::xetla::dropout_fwd_t< SZ, dtype_mask, random_simd >::rand_gen |
|
static constexpr
| float gpu::xetla::dropout_fwd_t< SZ, dtype_mask, random_simd >::scale |
| uint32_t gpu::xetla::dropout_fwd_t< SZ, dtype_mask, random_simd >::threshold |