21#include <ext/intel/esimd.hpp>
27#if defined DEBUG && defined LOG_PRINT
// Upper end of the register window reserved for debug bookkeeping; the
// snapshot below is laid out immediately beneath it.
// NOTE(review): 128 * 64 presumably means 128 registers x 64 bytes — confirm
// against the target GRF size.
static constexpr size_t reg_start = 128 * 64;

// Lane map of the nd_item snapshot (one element_type lane per field):
//   lane dims_pos                       : number of dimensions
//   lanes dims_global_start .. +max_dims : group id per dimension
//   lanes dims_local_start  .. +max_dims : local id per dimension
// With the values below that is lanes 0 | 1..3 | 4..6; lane 7 is unused.
using element_type = uint16_t;
static constexpr size_t element_num = 8;
static constexpr size_t max_dims = 3;
static constexpr size_t dims_pos = 0;
static constexpr size_t dims_global_start = dims_pos + 1;
static constexpr size_t dims_local_start = dims_global_start + max_dims;

// Start of the snapshot: element_num lanes ending exactly at reg_start.
static constexpr size_t nd_item_offset
        = reg_start - sizeof(element_type) * element_num;
42static inline ESIMD_PRIVATE ESIMD_REGISTER(nd_item_offset)
43 __ESIMD_NS::simd<element_type, element_num> saved_nd_item;
46static inline void set(sycl::nd_item<dims> item) {
47 static_assert(dims <= max_dims);
49 saved_nd_item[dims_pos] = dims;
52 for (
auto i = 0; i < dims; i++) {
53 saved_nd_item[dims_global_start + i] = item.get_group(i);
57 for (
auto i = 0; i < dims; i++) {
58 saved_nd_item[dims_local_start + i] = item.get_local_id(i);
62static inline uint16_t get_dims() {
63 return saved_nd_item[dims_pos];
66static inline int16_t get_group_id(
size_t dim) {
67 return saved_nd_item[dims_global_start + dim];
70static inline int16_t get_local_id(
size_t dim) {
71 return saved_nd_item[dims_local_start + dim];
80static constexpr size_t exit_offset = reg_start - 8 *
sizeof(int);
81ESIMD_PRIVATE ESIMD_REGISTER(exit_offset) __ESIMD_NS::simd<int, 8> reg_exit;
82ESIMD_INLINE
void xetla_thread_exit() {
83 constexpr uint32_t exDesc = 0x0;
84 constexpr uint32_t desc = 0x02000010;
85 constexpr uint8_t execSize = 0x83;
86 constexpr uint8_t sfid = 0x3;
87 constexpr uint8_t numSrc0 = 0x1;
88 constexpr uint8_t numSrc1 = 0x0;
89 constexpr uint8_t isEOT = 0x1;
90 return sycl::ext::intel::experimental::esimd::raw_send(
91 reg_exit, exDesc, desc, execSize, sfid, numSrc0, isEOT);
// Joins three string literals by juxtaposition (compile-time literal
// concatenation). The result stays a single string literal, so it can
// initialize a char array. It is deliberately not parenthesized, because
// a parenthesized literal cannot initialize an array.
#define STR_APPEND(prefix, body, suffix) prefix body suffix
100#ifdef __SYCL_DEVICE_ONLY__
103#define XETLA_PRINTF(s, ...) \
105 const __attribute__((opencl_constant)) char f[] = STR_APPEND( \
106 "[XeTLA] [KERNEL] [group(%d, %d, %d), local(%d, " \
109 sycl::ext::oneapi::experimental::printf(f, \
110 debug_ctx::nd_item::get_group_id(0), \
111 debug_ctx::nd_item::get_group_id(1), \
112 debug_ctx::nd_item::get_group_id(2), \
113 debug_ctx::nd_item::get_local_id(0), \
114 debug_ctx::nd_item::get_local_id(1), \
115 debug_ctx::nd_item::get_local_id(2), ##__VA_ARGS__); \
118#define XETLA_PRINTF(s, ...) \
120 const __attribute__((opencl_constant)) char f[] \
121 = STR_APPEND("[XeTLA] [KERNEL] : ", s, "\n"); \
122 sycl::ext::oneapi::experimental::printf(f, ##__VA_ARGS__); \
127#define XETLA_PRINTF(s, ...) \
129 const char *f = STR_APPEND("[XeTLA] [HOST] : ", s, "\n"); \
130 printf(f, ##__VA_ARGS__); \
136#define XETLA_PRINTF(s, ...) \
143#ifdef __SYCL_DEVICE_ONLY__
145#define XETLA_ASSERT(c, s, ...) \
152#define XETLA_ASSERT(c, s, ...) \
154 if (!(c)) { XETLA_PRINTF(s, ##__VA_ARGS__); } \
158#define XETLA_ASSERT(c, s, ...) \
167enum class dbg_level : uint8_t {
173#define DEBUG_INVOKE(level, ...) \
175 if constexpr (DEBUG >= static_cast<uint8_t>(level)) { \
176 if (!(__VA_ARGS__)) { XETLA_PRINTF("L%d: " #__VA_ARGS__, level); } \
180#define DEBUG_INVOKE(level, ...) \
set(TARGET gemm_universal) add_executable($
Definition CMakeLists.txt:1
Definition arch_config.hpp:24