clDNN
split.hpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16 
18 #pragma once
19 #include "../C/split.h"
20 #include "primitive.hpp"
21 
22 namespace cldnn
23 {
30 
/// @brief Performs split operation on input.
/// @details Stores a list of output primitive ids together with a parallel list of
/// offset tensors (one per output), taken pairwise from the constructor argument.
/// NOTE(review): this text is a scraped source listing - the leading numbers are the
/// original split.hpp line numbers, and several original lines (e.g. 57-59, 65, 75,
/// 83-85) are not shown here.
51 struct split : public primitive_base<split, CLDNN_PRIMITIVE_DESC(split)>
52 {
// NOTE(review): macro name is spelled "DECLATE"; presumably this matches the macro as
// declared in primitive.hpp - confirm before renaming it in only one place.
53  CLDNN_DECLATE_PRIMITIVE(split)
54 
55 
/// @brief Constructs split primitive.
/// @param id This primitive id.
/// @param input Input primitive id.
/// @param output_ids_offsets Pairs of (output id, offset tensor); the ids go to
///        output_ids / _output_ids and the tensors to output_offsets / _output_offsets,
///        preserving the order of the pairs.
/// @param output_padding Requested output padding.
56  split(
60  const primitive_id& id,
61  const primitive_id& input,
62  const std::vector<std::pair<primitive_id, tensor> >& output_ids_offsets,
63  const padding& output_padding = padding()
64  )
// NOTE(review): original line 65 (presumably the primitive_base base-class initializer)
// is missing from this listing.
// output_ids is bound to _output_ids.cpp_ids before _output_ids itself is initialized;
// presumably fixed_size_vector_ref only keeps a reference, so this is fine - verify.
66  , output_ids(_output_ids.cpp_ids)
67  , output_offsets(extract_tensor_vector(output_ids_offsets))
68  , _output_ids(extract_primitive_vector(output_ids_offsets))
// Safe to read output_offsets here: it is declared (line 86) before _output_offsets
// (line 90), and members are initialized in declaration order.
69  , _output_offsets(tensor_vector_to_cldnn_vector(output_offsets))
70  {
71  }
72 
/// @brief Constructs a copy from C API cldnn_split_desc.
/// @param dto C API descriptor; its output_offsets array is copied element by element
///        into both output_offsets and _output_offsets, and its output_ids array is
///        used to construct _output_ids.
74  split(const dto* dto)
// NOTE(review): original line 75 (presumably the base-class initializer) is missing here.
76  , output_ids(_output_ids.cpp_ids)
77  , output_offsets(tensor_arr_to_vector(dto->output_offsets))
78  , _output_ids(dto->output_ids)
79  , _output_offsets(tensor_arr_to_cldnn_vector(dto->output_offsets))
80  {
81  }
82 
// NOTE(review): output_ids (declared at elided line 84) refers to this object's own
// _output_ids.cpp_ids. No user-declared copy constructor is visible in this listing;
// if an implicitly generated copy is ever made, the copy's output_ids may still point
// into the source object - confirm fixed_size_vector_ref copy semantics.
/// Array of tensors with offsets.
86  std::vector<tensor> output_offsets;
87 
88 protected:
// Backing storage handed to the C API descriptor in update_dto().
89  primitive_id_arr _output_ids;
90  std::vector<cldnn_tensor> _output_offsets;
91 
/// @brief Fills the C API descriptor with this primitive's output ids and offsets.
// NOTE(review): tensor_vector_to_arr presumably returns a non-owning view of
// _output_offsets, so the filled dto is only valid while this object is alive - confirm.
92  void update_dto(dto& dto) const override
93  {
94  dto.output_ids = _output_ids.ref();
95  dto.output_offsets = tensor_vector_to_arr(_output_offsets);
96  }
97 
/// @brief Returns the first element (the output id) of each pair, in input order.
98  static std::vector<primitive_id> extract_primitive_vector(const std::vector<std::pair<primitive_id, tensor> >& stor)
99  {
100  std::vector<primitive_id> res;
101  for (auto &stor_pair : stor)
102  res.push_back(stor_pair.first);
103 
104  return res;
105  }
106 
/// @brief Returns the second element (the offset tensor) of each pair, in input order.
107  static std::vector<tensor> extract_tensor_vector(const std::vector<std::pair<primitive_id, tensor> >& stor)
108  {
109  std::vector<tensor> res;
110  for (auto &stor_pair : stor)
111  res.push_back(stor_pair.second);
112 
113  return res;
114  }
115 
/// @brief Copies the arr.size elements of arr.data into a std::vector<tensor>
///        (each cldnn_tensor is converted to tensor by assignment).
116  static std::vector<tensor> tensor_arr_to_vector(const cldnn_tensor_arr& arr)
117  {
118  std::vector<tensor> result(arr.size);
119  for (size_t i = 0; i < arr.size; i++)
120  result[i] = arr.data[i];
121 
122  return result;
123  }
124 
/// @brief Copies the arr.size elements of arr.data into a std::vector<cldnn_tensor>.
125  static std::vector<cldnn_tensor> tensor_arr_to_cldnn_vector(const cldnn_tensor_arr& arr)
126  {
127  std::vector<cldnn_tensor> result(arr.size);
128  for (size_t i = 0; i < arr.size; i++)
129  result[i] = arr.data[i];
130 
131  return result;
132  }
133 
/// @brief Converts each tensor in stor to cldnn_tensor by assignment, preserving order.
134  static std::vector<cldnn_tensor> tensor_vector_to_cldnn_vector(const std::vector<tensor>& stor)
135  {
136  std::vector<cldnn_tensor> res;
137  res.resize(stor.size());
138  for (size_t i = 0; i < stor.size(); ++i)
139  res[i] = stor[i];
140 
141  return res;
142  }
143 
144 };
148 }
Represents data padding information.
Definition: layout.hpp:125
Represents a reference to an array of tensors.
Definition: cldnn.h:322
std::vector< tensor > output_offsets
Array of tensors with offsets.
Definition: split.hpp:86
split(const primitive_id &id, const primitive_id &input, const std::vector< std::pair< primitive_id, tensor > > &output_ids_offsets, const padding &output_padding=padding())
Constructs split primitive.
Definition: split.hpp:59
fixed_size_vector_ref output_ids
List of output_ids.
Definition: split.hpp:84
size_t size
Size of the array (number of tensors).
Definition: cldnn.h:325
const cldnn_tensor * data
Pointer to tensor array.
Definition: cldnn.h:324
Performs split operation on input.
Definition: split.hpp:51
split(const dto *dto)
Constructs a copy from C API cldnn_split_desc.
Definition: split.hpp:74
cldnn_primitive_id_arr output_ids
List of output_ids.
Definition: split.h:56
std::string primitive_id
Unique id of a primitive within a topology.
Definition: primitive.hpp:42
fixed_size_vector_ref input
List of ids of input primitives.
Definition: primitive.hpp:146
Base class for all primitive implementations.
Definition: primitive.hpp:190
cldnn_tensor_arr output_offsets
Array of tensors with offsets.
Definition: split.h:58
padding output_padding
Requested output padding.
Definition: primitive.hpp:149
Performs split operation on input.
Definition: split.h:54
Initializes fields common to all primitives.
Definition: primitive.hpp:64