20 #include "cldnn_defs.h" 21 #include "compounds.h" 48 if (dto->
type != PType::type_id())
throw std::invalid_argument(
"type");
49 return reinterpret_cast<typename PType::dto*
>(dto);
56 if (dto->
type != PType::type_id())
throw std::invalid_argument(
"type");
57 return reinterpret_cast<const typename PType::dto*
>(dto);
67 std::vector<primitive_id>& vref;
73 auto size()
const -> decltype(vref.size()) {
return vref.size(); }
74 auto begin()
const -> decltype(vref.begin()) {
return vref.begin(); }
75 auto end()
const -> decltype(vref.end()) {
return vref.end(); }
76 auto cbegin()
const -> decltype(vref.cbegin()) {
return vref.cbegin(); }
77 auto cned()
const -> decltype(vref.cend()) {
return vref.cend(); }
79 primitive_id& operator[](
size_t idx) {
return vref[idx]; }
80 primitive_id const& operator[](
size_t idx)
const {
return vref[idx]; }
83 primitive_id const& at(
size_t idx)
const {
return vref.at(idx); }
88 const std::vector<primitive_id>& ref()
const {
return vref; }
94 const std::vector<primitive_id>&
input,
115 std::vector<std::reference_wrapper<primitive_id>> result;
116 auto&& deps = get_dependencies();
118 result.reserve(_input.size() + deps.size());
119 for (
auto& pid : _input.cpp_ids)
120 result.push_back(std::ref(pid));
121 for (
auto& pid : deps)
122 result.push_back(std::ref(const_cast<primitive_id&>(pid.get())));
130 auto result =
input.ref();
131 auto deps = get_dependencies();
132 result.insert(result.end(), deps.begin(), deps.end());
163 cpp_ids.resize(c_id_arr.
size);
164 for (
size_t i = 0; i < c_id_arr.
size; ++i)
165 cpp_ids[i] = c_id_arr.
data[i];
168 std::vector<primitive_id> cpp_ids;
169 mutable std::vector<cldnn_primitive_id> c_ids;
173 c_ids.resize(cpp_ids.size());
174 for (
size_t i = 0; i < cpp_ids.size(); ++i)
175 c_ids[i] = cpp_ids[i].c_str();
180 size_t size()
const {
return cpp_ids.size(); }
185 virtual std::vector<std::reference_wrapper<const primitive_id>> get_dependencies()
const {
return{}; }
189 template<
class PType,
class DTO>
197 _dto.id =
id.c_str();
199 _dto.input = _input.ref();
204 return reinterpret_cast<const CLDNN_PRIMITIVE_DESC(
primitive)*
>(&_dto);
210 const std::vector<primitive_id>&
input,
215 primitive_base(
const DTO* dto)
216 : primitive(reinterpret_cast<const CLDNN_PRIMITIVE_DESC(primitive)*>(dto))
218 if (dto->type != PType::type_id())
219 throw std::invalid_argument(
"DTO type mismatch");
225 virtual void update_dto(DTO& dto)
const = 0;
228 #define CLDNN_DEFINE_TYPE_ID(PType) static primitive_type_id type_id()\ 230 return check_status<primitive_type_id>( #PType " type id failed", [](status_t* status)\ 232 return cldnn_##PType##_type_id(status);\ 236 #define CLDNN_DECLATE_PRIMITIVE(PType) typedef CLDNN_PRIMITIVE_DESC(PType) dto;\ 237 CLDNN_DEFINE_TYPE_ID(PType) Base class of network primitive description.
virtual const cldnn_primitive_desc * get_dto() const =0
Requested output padding.
const char * cldnn_primitive_id
Unique id of a primitive within a topology.
cldnn_primitive_id primitive_id_ref
C API compatible unique id of a primitive within a topology.
Represents data padding information.
const cldnn_primitive_desc * get_dto() const override
Returns a pointer to the C API primitive descriptor, cast to cldnn_primitive_desc.
cldnn_primitive_type_id primitive_type_id
Globally unique primitive type id.
const cldnn_primitive_id * data
Pointer to ids array.
const PType::dto * as_dto(const cldnn_primitive_desc *dto)
Type-checked cast to the specified primitive description type; throws std::invalid_argument if the descriptor's type id does not match.
size_t size
Number of ids in the array.
Provides input data to topology.
primitive(const cldnn_primitive_desc *dto)
Constructs a copy from a basic C API cldnn_primitive_desc.
const primitive_type_id type
Primitive's type id.
const primitive_id id
Primitive's id.
Represents reference to an array of primitive ids.
std::vector< primitive_id > dependecies() const
Returns copy of all primitive ids on which this primitive depends - inputs, weights, biases, etc.
std::string primitive_id
Unique id of a primitive within a topology.
fixed_size_vector_ref input
List of ids of input primitives.
Base class for all primitive implementations.
padding output_padding
Requested output padding.
Initializes fields common to all primitives.
const struct cldnn_primitive_type * cldnn_primitive_type_id
Globally unique primitive type id.
data(const primitive_id &id, const memory &mem)
Constructs data primitive.
std::vector< std::reference_wrapper< primitive_id > > dependecies()
Returns references to all primitive ids on which this primitive depends - inputs, weights...