clDNN
batch_norm.hpp
1 /*
2 // Copyright (c) 2016 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 */
16 
18 #pragma once
19 #include "../C/batch_norm.h"
20 #include "primitive.hpp"
21 
22 namespace cldnn
23 {
30 
39 
40 struct batch_norm : public primitive_base<batch_norm, CLDNN_PRIMITIVE_DESC(batch_norm)>
41 {
42  CLDNN_DECLATE_PRIMITIVE(batch_norm)
43 
44 
45  batch_norm(
51  const primitive_id& id,
52  const primitive_id& input,
53  const primitive_id& mean,
54  const primitive_id& variance,
55  float epsilon,
56  const padding& output_padding = padding()
57  )
59  , mean(mean)
61  , epsilon(epsilon)
62  {
63  }
64 
66  batch_norm(const dto* dto)
68  , mean(dto->mean)
70  , epsilon(dto->epsilon)
71  {
72  }
73 
79  float epsilon;
80 
81 protected:
82  std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override { return{ mean, variance }; }
83 
84  void update_dto(dto& dto) const override
85  {
86  dto.mean = mean.c_str();
87  dto.variance = variance.c_str();
88  dto.epsilon = epsilon;
89  }
90 };
94 }
primitive_id variance
Primitive id containing variance.
Definition: batch_norm.hpp:77
Represents data padding information.
Definition: layout.hpp:125
Batch normalization primitive.
Definition: batch_norm.hpp:40
batch_norm(const dto *dto)
Constructs a copy from C API cldnn_batch_norm_desc.
Definition: batch_norm.hpp:66
float epsilon
Epsilon.
Definition: batch_norm.hpp:79
batch_norm(const primitive_id &id, const primitive_id &input, const primitive_id &mean, const primitive_id &variance, float epsilon, const padding &output_padding=padding())
Constructs batch normalization primitive.
Definition: batch_norm.hpp:50
primitive_id mean
Primitive id containing mean data.
Definition: batch_norm.hpp:75
std::string primitive_id
Unique id of a primitive within a topology.
Definition: primitive.hpp:42
fixed_size_vector_ref input
List of ids of input primitives.
Definition: primitive.hpp:146
Base class for all primitive implementations.
Definition: primitive.hpp:190
padding output_padding
Requested output padding.
Definition: primitive.hpp:149
Batch normalization primitive.
Definition: batch_norm.h:42