clDNN
primitive.hpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16 
18 #pragma once
19 
20 #include "cldnn_defs.h"
21 #include "compounds.h"
22 #include "layout.hpp"
23 
#include <algorithm>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
28 
29 namespace cldnn
30 {
33 
36 
42 using primitive_id = std::string;
43 
45 template<class PType>
46 typename PType::dto* as_dto(CLDNN_PRIMITIVE_DESC(primitive)* dto)
47 {
48  if (dto->type != PType::type_id()) throw std::invalid_argument("type");
49  return reinterpret_cast<typename PType::dto*>(dto);
50 }
51 
53 template<class PType>
54 const typename PType::dto* as_dto(const CLDNN_PRIMITIVE_DESC(primitive)* dto)
55 {
56  if (dto->type != PType::type_id()) throw std::invalid_argument("type");
57  return reinterpret_cast<const typename PType::dto*>(dto);
58 }
59 
61 struct primitive
62 {
65  {
66  private:
67  std::vector<primitive_id>& vref;
68 
69  public:
70  fixed_size_vector_ref(std::vector<primitive_id>& ref) : vref(ref)
71  {}
72 
73  auto size() const -> decltype(vref.size()) { return vref.size(); }
74  auto begin() const -> decltype(vref.begin()) { return vref.begin(); }
75  auto end() const -> decltype(vref.end()) { return vref.end(); }
76  auto cbegin() const -> decltype(vref.cbegin()) { return vref.cbegin(); }
77  auto cned() const -> decltype(vref.cend()) { return vref.cend(); }
78 
79  primitive_id& operator[](size_t idx) { return vref[idx]; }
80  primitive_id const& operator[](size_t idx) const { return vref[idx]; }
81 
82  primitive_id& at(size_t idx) { return vref.at(idx); }
83  primitive_id const& at(size_t idx) const { return vref.at(idx); }
84 
85  primitive_id* data() { return vref.data(); }
86  const primitive_id* data() const { return vref.data(); }
87 
88  const std::vector<primitive_id>& ref() const { return vref; }
89  };
90 public:
91  primitive(
92  const primitive_type_id& type,
93  const primitive_id& id,
94  const std::vector<primitive_id>& input,
95  const padding& output_padding = padding()
96  )
97  :type(type), id(id), input(_input.cpp_ids), output_padding(output_padding), _input(input)
98  {}
99 
101  primitive(const CLDNN_PRIMITIVE_DESC(primitive)* dto)
102  :type(dto->type), id(dto->id), input(_input.cpp_ids), output_padding(dto->output_padding), _input(dto->input)
103  {}
104 
105  virtual ~primitive() = default;
106 
110  virtual const CLDNN_PRIMITIVE_DESC(primitive)* get_dto() const = 0;
111 
113  std::vector<std::reference_wrapper<primitive_id>> dependecies()
114  {
115  std::vector<std::reference_wrapper<primitive_id>> result;
116  auto&& deps = get_dependencies();
117 
118  result.reserve(_input.size() + deps.size());
119  for (auto& pid : _input.cpp_ids)
120  result.push_back(std::ref(pid));
121  for (auto& pid : deps)
122  result.push_back(std::ref(const_cast<primitive_id&>(pid.get())));
123 
124  return result;
125  }
126 
128  std::vector<primitive_id> dependecies() const
129  {
130  auto result = input.ref();
131  auto deps = get_dependencies();
132  result.insert(result.end(), deps.begin(), deps.end());
133  return result;
134  }
135 
137  operator primitive_id() const { return id; }
138 
141 
144 
147 
150 
151 protected:
153  {
154  primitive_id_arr(std::vector<primitive_id> const& vec) : cpp_ids(vec)
155  {}
156 
157  primitive_id_arr(std::vector<primitive_id>&& vec) : cpp_ids(std::move(vec))
158  {}
159 
160  //create from C API id array
162  {
163  cpp_ids.resize(c_id_arr.size);
164  for (size_t i = 0; i < c_id_arr.size; ++i)
165  cpp_ids[i] = c_id_arr.data[i];
166  }
167 
168  std::vector<primitive_id> cpp_ids;
169  mutable std::vector<cldnn_primitive_id> c_ids;
170  //get C API id array
171  auto ref() const -> decltype(cldnn_primitive_id_arr{c_ids.data(), c_ids.size()})
172  {
173  c_ids.resize(cpp_ids.size());
174  for (size_t i = 0; i < cpp_ids.size(); ++i)
175  c_ids[i] = cpp_ids[i].c_str();
176 
177  return cldnn_primitive_id_arr{ c_ids.data(), c_ids.size() };
178  }
179 
180  size_t size() const { return cpp_ids.size(); }
181  };
182 
183  primitive_id_arr _input;
184 
185  virtual std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const { return{}; }
186 };
187 
189 template<class PType, class DTO>
190 class primitive_base : public primitive
191 {
192 public:
194  const CLDNN_PRIMITIVE_DESC(primitive)* get_dto() const override
195  {
196  //update common dto fields
197  _dto.id = id.c_str();
198  _dto.type = type;
199  _dto.input = _input.ref();
200  _dto.output_padding = output_padding;
201 
202  //call abstract method to update primitive-specific fields
203  update_dto(_dto);
204  return reinterpret_cast<const CLDNN_PRIMITIVE_DESC(primitive)*>(&_dto);
205  }
206 
207 protected:
208  explicit primitive_base(
209  const primitive_id& id,
210  const std::vector<primitive_id>& input,
211  const padding& output_padding = padding())
212  : primitive(PType::type_id(), id, input, output_padding)
213  {}
214 
215  primitive_base(const DTO* dto)
216  : primitive(reinterpret_cast<const CLDNN_PRIMITIVE_DESC(primitive)*>(dto))
217  {
218  if (dto->type != PType::type_id())
219  throw std::invalid_argument("DTO type mismatch");
220  }
221 
222 private:
223  mutable DTO _dto;
224 
225  virtual void update_dto(DTO& dto) const = 0;
226 };
227 
/// @brief Defines the static @c type_id() member of a primitive class.
/// @details The generated function calls @c cldnn_<PType>_type_id() through @c check_status,
///          passing the message "<PType> type id failed" for error reporting.
#define CLDNN_DEFINE_TYPE_ID(PType) static primitive_type_id type_id()\
    {\
        return check_status<primitive_type_id>( #PType " type id failed", [](status_t* status)\
        {\
            return cldnn_##PType##_type_id(status);\
        });\
    }

/// @brief Declares the nested @c dto typedef and the static @c type_id() of a primitive class.
#define CLDNN_DECLARE_PRIMITIVE(PType) typedef CLDNN_PRIMITIVE_DESC(PType) dto;\
    CLDNN_DEFINE_TYPE_ID(PType)

/// @deprecated Misspelled name kept for source compatibility; use CLDNN_DECLARE_PRIMITIVE.
#define CLDNN_DECLATE_PRIMITIVE(PType) CLDNN_DECLARE_PRIMITIVE(PType)
238 }
Base class of network primitive description.
Definition: primitive.hpp:61
virtual const cldnn_primitive_desc * get_dto() const =0
Requested output padding.
const char * cldnn_primitive_id
Unique id of a primitive within a topology.
Definition: cldnn.h:332
cldnn_primitive_id primitive_id_ref
C API compatible unique id of a primitive within a topology.
Definition: primitive.hpp:40
Represents data padding information.
Definition: layout.hpp:125
const cldnn_primitive_desc * get_dto() const override
Returns a pointer to a C API primitive descriptor, cast to cldnn_primitive_desc.
Definition: primitive.hpp:194
cldnn_primitive_type_id primitive_type_id
Globally unique primitive type id.
Definition: primitive.hpp:38
const cldnn_primitive_id * data
Pointer to ids array.
Definition: cldnn.h:337
const PType::dto * as_dto(const cldnn_primitive_desc *dto)
Checked cast to the specified primitive description type; throws std::invalid_argument on a type mismatch.
Definition: primitive.hpp:54
size_t size
Number of ids in the array.
Definition: cldnn.h:338
Provides input data to topology.
Definition: data.hpp:36
primitive(const cldnn_primitive_desc *dto)
Constructs a copy from basic C API cldnn_primitive_desc.
Definition: primitive.hpp:101
const primitive_type_id type
Primitive's type id.
Definition: primitive.hpp:140
const primitive_id id
Primitive&#39;s id.
Definition: primitive.hpp:143
Represents reference to an array of primitive ids.
Definition: cldnn.h:335
std::vector< primitive_id > dependecies() const
Returns a copy of all primitive ids on which this primitive depends - inputs, weights, biases, etc.
Definition: primitive.hpp:128
std::string primitive_id
Unique id of a primitive within a topology.
Definition: primitive.hpp:42
fixed_size_vector_ref input
List of ids of input primitives.
Definition: primitive.hpp:146
Base class for all primitive implementations.
Definition: primitive.hpp:190
padding output_padding
Requested output padding.
Definition: primitive.hpp:149
Initialize fields common for all primitives.
Definition: primitive.hpp:64
const struct cldnn_primitive_type * cldnn_primitive_type_id
Globally unique primitive type id.
Definition: cldnn.h:329
data(const primitive_id &id, const memory &mem)
Constructs data primitive.
Definition: data.hpp:44
std::vector< std::reference_wrapper< primitive_id > > dependecies()
Returns references to all primitive ids on which this primitive depends - inputs, weights...
Definition: primitive.hpp:113