25 return (n + d - 1) / d;
32 return (n - (((n % d) + d) % d)) / d;
38 return (d + (n % d)) % d;
44 const int CACHELINE_SIZE = 256;
45 return (size + CACHELINE_SIZE - 1) / CACHELINE_SIZE * CACHELINE_SIZE;
72template <
typename Ty,
int N>
81 tmp = sycl::ext::intel::esimd::unpack_mask<N>(mask_val);
85template <
typename dtype_acc, uint32_t N, uint32_t num_flag = 4,
86 typename dtype_mask = uint8_t>
90 constexpr uint32_t unroll_size = num_flag * 16;
93 for (uint32_t i = 0; i < N / unroll_size; i++) {
95 = mask.xetla_select<unroll_size, 1>(i * unroll_size) > 0;
96 out.xetla_select<unroll_size, 1>(i * unroll_size)
99 if constexpr (N % unroll_size != 0) {
100 constexpr uint32_t remain_len = N % unroll_size;
101 constexpr uint32_t remain_start = N / unroll_size * unroll_size;
103 = mask.xetla_select<remain_len, 1>(remain_start) > 0;
104 out.xetla_select<remain_len, 1>(remain_start).
xetla_merge(0, mask_flag);
109template <reduce_op reduce_kind,
typename dtype,
int size>
111 xetla_vector<dtype, size>>
116template <reduce_op reduce_kind,
typename dtype,
int size>
118 xetla_vector<dtype, size>>
123template <reduce_op reduce_kind,
typename dtype,
int size>
125 xetla_vector<dtype, size>>
129 out.xetla_merge(a, b, mask);
133template <reduce_op reduce_kind,
typename dtype,
int size>
135 xetla_vector<dtype, size>>
139 out.xetla_merge(a, b, mask);
143template <reduce_op reduce_kind,
typename dtype,
int N_x,
int N_y>
144__XETLA_API typename std::enable_if_t<N_y == 1, xetla_vector<dtype, N_x>>
148template <reduce_op reduce_kind,
typename dtype,
int N_x,
int N_y>
149__XETLA_API typename std::enable_if_t<(N_y > 1), xetla_vector<dtype, N_x>>
151 static_assert(((N_y) & (N_y - 1)) == 0,
"N_y should be power of 2");
154 in.xetla_select<N_x * N_y / 2, 1>(0),
155 in.xetla_select<N_x * N_y / 2, 1>(N_x * N_y / 2));
160template <reduce_op reduce_kind,
typename dtype,
int N_x,
int N_y>
161__XETLA_API typename std::enable_if_t<N_x == 1, xetla_vector<dtype, N_y>>
165template <reduce_op reduce_kind,
typename dtype,
int N_x,
int N_y>
166__XETLA_API typename std::enable_if_t<(N_x > 1), xetla_vector<dtype, N_y>>
168 static_assert(((N_x) & (N_x - 1)) == 0,
"N_x should be power of 2");
170 auto in_2d = in.xetla_format<dtype, N_y, N_x>();
172 in_2d.xetla_select<N_y, 1, N_x / 2, 1>(0, 0),
173 in_2d.xetla_select<N_y, 1, N_x / 2, 1>(0, N_x / 2));
181 return item.get_group(2) + item.get_group(1) * item.get_group_range(2);
#define SW_BARRIER()
SW_BARRIER: inserts a software scheduling barrier, giving finer control over the generated code.
Definition common.hpp:227
#define __XETLA_API
Definition common.hpp:43
#define xetla_merge
xetla merge: masked element-wise merge of xetla vectors (e.g. out.xetla_merge(a, b, mask)).
Definition base_ops.hpp:60
__ESIMD_NS::simd_mask< N > xetla_mask_int
Type alias for __ESIMD_NS::simd_mask<N>.
Definition base_types.hpp:172
__ESIMD_NS::simd< native_type_t< Ty >, N > xetla_vector
Type alias for __ESIMD_NS::simd<native_type_t<Ty>, N>.
Definition base_types.hpp:149
__ESIMD_NS::simd_mask< N > xetla_mask
Type alias for __ESIMD_NS::simd_mask<N>.
Definition base_types.hpp:165
__XETLA_API std::enable_if_t< N_x==1, xetla_vector< dtype, N_y > > recur_col_reduce(xetla_vector< dtype, N_y > in)
Definition misc.hpp:162
__XETLA_API xetla_vector< dtype_acc, N > drop_out(xetla_vector< dtype_acc, N > in, xetla_vector< dtype_mask, N > mask, dtype_acc scale)
Definition misc.hpp:87
__XETLA_API xetla_vector< uint32_t, 4 > get_time_stamp()
Returns a time stamp as a 4-element uint32_t vector.
Definition misc.hpp:57
__XETLA_API std::enable_if_t< reduce_kind==reduce_op::sum, xetla_vector< dtype, size > > reduce_helper(xetla_vector< dtype, size > a, xetla_vector< dtype, size > b)
Definition misc.hpp:112
__XETLA_API std::enable_if_t< N_y==1, xetla_vector< dtype, N_x > > recur_row_reduce(xetla_vector< dtype, N_x > in)
Definition misc.hpp:145
__XETLA_API xetla_mask_int< N > xetla_mask_int_gen(uint32_t mask_val)
Definition misc.hpp:79
__XETLA_API xetla_vector< Ty, N > xetla_vector_gen(int InitVal, int Step)
Generates an N-element xetla_vector<Ty, N> from an initial value (InitVal) and a step (Step).
Definition misc.hpp:73
__XETLA_API uint32_t get_2d_group_linear_id(sycl::nd_item< 3 > &item)
Returns the linear group ID over the last two dimensions: get_group(2) + get_group(1) * get_group_range(2).
Definition misc.hpp:180
Definition arch_config.hpp:24
__XETLA_API constexpr int div_round_down(int n, int d)
Definition misc.hpp:30
__XETLA_API constexpr int modulo(int n, int d)
Definition misc.hpp:37
__XETLA_API constexpr uint32_t div_round_up(uint32_t n, uint32_t d)
Definition misc.hpp:24
__XETLA_API constexpr uint32_t cacheline_align_up(size_t size)
Definition misc.hpp:42