clDNN
cldnn::batch_norm Struct Reference

Batch normalization primitive.

#include <batch_norm.hpp>

Public Types

typedef cldnn_batch_norm_desc dto
 

Public Member Functions

 batch_norm (const primitive_id &id, const primitive_id &input, const primitive_id &mean, const primitive_id &variance, float epsilon, const padding &output_padding=padding())
 Constructs a batch normalization primitive.
 
 batch_norm (const dto *dto)
 Constructs a copy from C API cldnn_batch_norm_desc.
 
- Public Member Functions inherited from cldnn::primitive_base< batch_norm, cldnn_batch_norm_desc >
const cldnn_primitive_desc * get_dto () const override
 Returns a pointer to the C API primitive descriptor, cast to cldnn_primitive_desc.
 
- Public Member Functions inherited from cldnn::primitive
 primitive (const primitive_type_id &type, const primitive_id &id, const std::vector< primitive_id > &input, const padding &output_padding=padding())
 
 primitive (const cldnn_primitive_desc *dto)
 Constructs a copy from basic C API cldnn_primitive_desc.
 
std::vector< std::reference_wrapper< primitive_id > > dependecies ()
 Returns references to all primitive ids on which this primitive depends - inputs, weights, biases, etc.
 
std::vector< primitive_id > dependecies () const
 Returns copy of all primitive ids on which this primitive depends - inputs, weights, biases, etc.
 
 operator primitive_id () const
 Implicit conversion to primitive id.
 

Static Public Member Functions

static primitive_type_id type_id ()
 

Public Attributes

primitive_id mean
 Primitive id containing mean data.
 
primitive_id variance
 Primitive id containing variance.
 
float epsilon
 Epsilon added to the variance before the square root is taken.
 
- Public Attributes inherited from cldnn::primitive
const primitive_type_id type
 Primitive's type id.
 
const primitive_id id
 Primitive's id.
 
fixed_size_vector_ref input
 List of ids of input primitives.
 
padding output_padding
 Requested output padding.
 

Protected Member Functions

std::vector< std::reference_wrapper< const primitive_id > > get_dependencies () const override
 
void update_dto (dto &dto) const override
 
- Protected Member Functions inherited from cldnn::primitive_base< batch_norm, cldnn_batch_norm_desc >
 primitive_base (const primitive_id &id, const std::vector< primitive_id > &input, const padding &output_padding=padding())
 
 primitive_base (const cldnn_batch_norm_desc *dto)
 

Additional Inherited Members

- Protected Attributes inherited from cldnn::primitive
primitive_id_arr _input
 

Detailed Description

Batch normalization primitive.

Performs batch normalization as described in "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" by Ioffe and Szegedy.
See: http://arxiv.org/abs/1502.03167

Algorithm:
When global statistics are used, the output is computed as:
out[i] = (in[i] - mean[b]) / sqrt(variance[b] + epsilon)
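
The normalization step can be sketched in plain C++ as follows. This is a minimal reference sketch, not clDNN's implementation: the helper names batch_norm_value and batch_norm_reference are hypothetical, and the mapping from element index i to statistics index b (contiguous, equally sized blocks) is an assumption made only for illustration. The mean is subtracted before the division, as in the Ioffe and Szegedy paper.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Normalizes one value with precomputed statistics:
// (x - mean) / sqrt(variance + epsilon)
inline float batch_norm_value(float x, float mean, float variance, float epsilon) {
    return (x - mean) / std::sqrt(variance + epsilon);
}

// Hypothetical reference loop, not clDNN code.
// Assumption: `in` holds mean.size() contiguous blocks of equal length,
// and every element of block b shares mean[b] and variance[b].
inline std::vector<float> batch_norm_reference(const std::vector<float>& in,
                                               const std::vector<float>& mean,
                                               const std::vector<float>& variance,
                                               float epsilon) {
    if (mean.empty() || mean.size() != variance.size() ||
        in.size() % mean.size() != 0) {
        throw std::invalid_argument("mismatched input and statistics sizes");
    }
    const std::size_t block = in.size() / mean.size();
    std::vector<float> out(in.size());
    for (std::size_t i = 0; i < in.size(); ++i) {
        const std::size_t b = i / block;  // statistics index for element i
        out[i] = batch_norm_value(in[i], mean[b], variance[b], epsilon);
    }
    return out;
}
```

In this sketch a small positive epsilon keeps the result finite even when variance[b] is zero, because the square root is then taken of epsilon alone.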

Definition at line 40 of file batch_norm.hpp.

Constructor & Destructor Documentation

◆ batch_norm()

cldnn::batch_norm::batch_norm ( const primitive_id & id,
const primitive_id & input,
const primitive_id & mean,
const primitive_id & variance,
float  epsilon,
const padding & output_padding = padding() 
)
inline

Constructs a batch normalization primitive.

Parameters
id        This primitive id.
input     Input primitive id.
mean      Primitive id containing mean data.
variance  Primitive id containing variance.
epsilon   Epsilon added to the variance before the square root is taken.

Definition at line 50 of file batch_norm.hpp.


The documentation for this struct was generated from the following file: