Class Kernel

namespace jd

Note

The JIT classes in this file may not meet the coding convention and are to be refactored.

template<typename T, typename arg_t = void>
class proxy_base
#include <interface.hpp>

Proxy pattern. The proxy can interface to anything, similar to oneDNN’s “struct handle” (see oneapi/dnnl/dnnl.hpp:136).

Public Functions

inline proxy_base()
inline virtual ~proxy_base()
inline void reset_sp(const std::shared_ptr<const T> &sp)
inline const std::shared_ptr<const T> &get_sp() const

Protected Functions

virtual bool create_proxy_object(std::shared_ptr<const T> &result_ref, const arg_t &arg) = 0

Private Members

std::shared_ptr<const T> data_handle_
class kernel_desc_proxy : public jd::proxy_base<kernel_desc_t, operator_desc>
#include <interface.hpp>

Base proxy class, interfacing to the real/cached kernel_desc_t.

Subclassed by jd::attention_desc, jd::dynamic_quant_desc, jd::dynamic_quant_matmul_desc, jd::eltwiseop_desc, jd::gather_desc, jd::groupnorm_desc, jd::layernorm_ba_desc, jd::layernormalized_spmm_desc, jd::logsoftmax_desc, jd::mha_dense_desc, jd::slice_desc, jd::softmax_desc, jd::sparse_matmul_desc, jd::transpose_matmul_desc, jd::transpose_mha_desc

Public Functions

inline kernel_desc_proxy()
explicit kernel_desc_proxy(const operator_desc &op_desc)
inline virtual ~kernel_desc_proxy()
inline const jd::kernel_kind &kernel_kind() const

Protected Functions

virtual bool create_proxy_object(std::shared_ptr<const kernel_desc_t> &result_ref, const operator_desc &op_desc) override

Protected Attributes

const std::vector<impl_list_item_t> *impl_list_ = nullptr
class kernel_proxy : public jd::proxy_base<kernel_t, std::shared_ptr<const kernel_desc_t>>
#include <interface.hpp>

Base proxy class, interfacing to the real/cached kernel_t.

Subclassed by jd::attention, jd::dynamic_quant, jd::dynamic_quant_matmul, jd::eltwiseop, jd::gather, jd::groupnorm, jd::layernorm_ba, jd::layernormalized_spmm, jd::logsoftmax, jd::mha_dense, jd::slice, jd::softmax, jd::sparse_matmul, jd::transpose_matmul, jd::transpose_mha

Public Functions

inline kernel_proxy()
explicit kernel_proxy(const kernel_desc_proxy &kdp)
inline virtual ~kernel_proxy()
inline const jd::kernel_kind &kernel_kind() const
void execute(const std::vector<const void*> &rt_data) const
void execute(const exec_context_t &ctx) const
size_t get_workspace_size() const

Protected Functions

virtual bool create_proxy_object(std::shared_ptr<const kernel_t> &result_ref, const std::shared_ptr<const kernel_desc_t> &kd) override
class sparse_matmul_desc : public jd::kernel_desc_proxy
#include <interface.hpp>

Derived proxy class, interfacing to the real/cached sparse_matmul_desc_t.

Public Functions

inline sparse_matmul_desc()
inline explicit sparse_matmul_desc(const operator_desc &op_desc)
inline virtual ~sparse_matmul_desc()
class transpose_matmul_desc : public jd::kernel_desc_proxy

Public Functions

inline transpose_matmul_desc()
inline explicit transpose_matmul_desc(const operator_desc &op_desc)
inline virtual ~transpose_matmul_desc()
class dynamic_quant_matmul_desc : public jd::kernel_desc_proxy

Public Functions

inline dynamic_quant_matmul_desc()
inline explicit dynamic_quant_matmul_desc(const operator_desc &op_desc)
inline virtual ~dynamic_quant_matmul_desc()
class dynamic_quant_desc : public jd::kernel_desc_proxy

Public Functions

inline dynamic_quant_desc()
inline explicit dynamic_quant_desc(const operator_desc &op_desc)
inline virtual ~dynamic_quant_desc()
class eltwiseop_desc : public jd::kernel_desc_proxy

Public Functions

inline eltwiseop_desc()
inline explicit eltwiseop_desc(const operator_desc &op_desc)
inline virtual ~eltwiseop_desc()
class groupnorm_desc : public jd::kernel_desc_proxy

Public Functions

inline groupnorm_desc()
inline explicit groupnorm_desc(const operator_desc &op_desc)
inline virtual ~groupnorm_desc()
class layernorm_ba_desc : public jd::kernel_desc_proxy

Public Functions

inline layernorm_ba_desc()
inline explicit layernorm_ba_desc(const operator_desc &op_desc)
inline virtual ~layernorm_ba_desc()
class layernormalized_spmm_desc : public jd::kernel_desc_proxy

Public Functions

inline layernormalized_spmm_desc()
inline explicit layernormalized_spmm_desc(const operator_desc &op_desc)
inline virtual ~layernormalized_spmm_desc()
class gather_desc : public jd::kernel_desc_proxy

Public Functions

inline gather_desc()
inline explicit gather_desc(const operator_desc &op_desc)
inline virtual ~gather_desc()
class softmax_desc : public jd::kernel_desc_proxy

Public Functions

inline softmax_desc()
inline explicit softmax_desc(const operator_desc &op_desc)
inline virtual ~softmax_desc()
class logsoftmax_desc : public jd::kernel_desc_proxy

Public Functions

inline logsoftmax_desc()
inline explicit logsoftmax_desc(const operator_desc &op_desc)
inline virtual ~logsoftmax_desc()
class attention_desc : public jd::kernel_desc_proxy

Public Functions

inline attention_desc()
inline explicit attention_desc(const operator_desc &op_desc)
inline virtual ~attention_desc()
class transpose_mha_desc : public jd::kernel_desc_proxy

Public Functions

inline transpose_mha_desc()
inline explicit transpose_mha_desc(const operator_desc &op_desc)
inline virtual ~transpose_mha_desc()
class mha_dense_desc : public jd::kernel_desc_proxy

Public Functions

inline mha_dense_desc()
inline explicit mha_dense_desc(const operator_desc &op_desc)
inline virtual ~mha_dense_desc()
class slice_desc : public jd::kernel_desc_proxy

Public Functions

inline slice_desc()
inline explicit slice_desc(const operator_desc &op_desc)
inline virtual ~slice_desc()
class sparse_matmul : public jd::kernel_proxy
#include <interface.hpp>

Derived proxy class, interfacing to the real/cached sparse_matmul_t.

Public Functions

inline sparse_matmul()
inline explicit sparse_matmul(const kernel_desc_proxy &kdp)
inline virtual ~sparse_matmul()
class transpose_matmul : public jd::kernel_proxy

Public Functions

inline transpose_matmul()
inline explicit transpose_matmul(const kernel_desc_proxy &kdp)
inline virtual ~transpose_matmul()
class dynamic_quant_matmul : public jd::kernel_proxy

Public Functions

inline dynamic_quant_matmul()
inline explicit dynamic_quant_matmul(const kernel_desc_proxy &kdp)
inline virtual ~dynamic_quant_matmul()
class dynamic_quant : public jd::kernel_proxy

Public Functions

inline dynamic_quant()
inline explicit dynamic_quant(const kernel_desc_proxy &kdp)
inline virtual ~dynamic_quant()
class eltwiseop : public jd::kernel_proxy

Public Functions

inline eltwiseop()
inline explicit eltwiseop(const kernel_desc_proxy &kdp)
inline virtual ~eltwiseop()
class groupnorm : public jd::kernel_proxy

Public Functions

inline groupnorm()
inline explicit groupnorm(const kernel_desc_proxy &kdp)
inline virtual ~groupnorm()
class layernorm_ba : public jd::kernel_proxy

Public Functions

inline layernorm_ba()
inline explicit layernorm_ba(const kernel_desc_proxy &kdp)
inline virtual ~layernorm_ba()
class layernormalized_spmm : public jd::kernel_proxy

Public Functions

inline layernormalized_spmm()
inline explicit layernormalized_spmm(const kernel_desc_proxy &kdp)
inline virtual ~layernormalized_spmm()
class gather : public jd::kernel_proxy

Public Functions

inline gather()
inline explicit gather(const kernel_desc_proxy &kdp)
inline virtual ~gather()
class softmax : public jd::kernel_proxy

Public Functions

inline softmax()
inline explicit softmax(const kernel_desc_proxy &kdp)
inline virtual ~softmax()
class logsoftmax : public jd::kernel_proxy

Public Functions

inline logsoftmax()
inline explicit logsoftmax(const kernel_desc_proxy &kdp)
inline virtual ~logsoftmax()
class attention : public jd::kernel_proxy

Public Functions

inline attention()
inline explicit attention(const kernel_desc_proxy &kdp)
inline virtual ~attention()
class transpose_mha : public jd::kernel_proxy

Public Functions

inline transpose_mha()
inline explicit transpose_mha(const kernel_desc_proxy &kdp)
inline virtual ~transpose_mha()
class mha_dense : public jd::kernel_proxy

Public Functions

inline mha_dense()
inline explicit mha_dense(const kernel_desc_proxy &kdp)
inline virtual ~mha_dense()
class slice : public jd::kernel_proxy

Public Functions

inline slice()
inline explicit slice(const kernel_desc_proxy &kdp)
inline virtual ~slice()