clDNN
proposal.hpp
1 /*
2 // Copyright (c) 2017 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16 
18 #pragma once
19 
20 #include <vector>
21 
22 #include "../C/proposal.h"
23 #include "primitive.hpp"
24 
25 namespace cldnn
26 {
33 
34 struct proposal : public primitive_base<proposal, CLDNN_PRIMITIVE_DESC(proposal)>
35 {
36  CLDNN_DECLATE_PRIMITIVE(proposal)
37 
38  proposal(
39  const primitive_id& id,
40  const primitive_id& cls_scores,
41  const primitive_id& bbox_pred,
42  const primitive_id& image_info,
43  int max_proposals,
44  float iou_threshold,
45  int min_bbox_size,
46  int feature_stride,
47  int pre_nms_topn,
48  int post_nms_topn,
49  const std::vector<float>& ratios_param,
50  const std::vector<float>& scales_param,
51  const padding& output_padding = padding()
52  )
53  : primitive_base(id, {cls_scores, bbox_pred, image_info}, output_padding),
54  max_proposals(max_proposals),
55  iou_threshold(iou_threshold),
56  min_bbox_size(min_bbox_size),
57  feature_stride(feature_stride),
58  pre_nms_topn(pre_nms_topn),
59  post_nms_topn(post_nms_topn),
60  ratios(ratios_param),
61  scales(scales_param)
62  {
63  }
64 
65  proposal(const dto* dto) :
67  max_proposals(dto->max_proposals),
68  iou_threshold(dto->iou_threshold),
69  min_bbox_size(dto->min_bbox_size),
70  feature_stride(dto->feature_stride),
71  pre_nms_topn(dto->pre_nms_topn),
72  post_nms_topn(dto->post_nms_topn),
73  ratios(float_arr_to_vector(dto->ratios)),
74  scales(float_arr_to_vector(dto->scales))
75  {
76  }
77 
78  int max_proposals;
79  float iou_threshold;
80  int min_bbox_size;
81  int feature_stride;
82  int pre_nms_topn;
83  int post_nms_topn;
84  std::vector<float> ratios;
85  std::vector<float> scales;
86 
87 protected:
88  void update_dto(dto& dto) const override
89  {
90  dto.max_proposals = max_proposals;
91  dto.iou_threshold = iou_threshold;
92  dto.min_bbox_size = min_bbox_size;
93  dto.feature_stride = feature_stride;
94  dto.pre_nms_topn = pre_nms_topn;
95  dto.post_nms_topn = post_nms_topn;
96  dto.ratios = float_vector_to_arr(ratios);
97  dto.scales = float_vector_to_arr(scales);
98  }
99 };
100 
104 }
Represents data padding information.
Definition: layout.hpp:125
std::string primitive_id
Unique id of a primitive within a topology.
Definition: primitive.hpp:42
Base class for all primitive implementations.
Definition: primitive.hpp:190
padding output_padding
Requested output padding.
Definition: primitive.hpp:149