BatchNorm< InputDataType, OutputDataType > Class Template Reference

Declaration of the Batch Normalization layer class.

Public Member Functions

 BatchNorm ()
 Create the BatchNorm object.

 BatchNorm (const size_t size, const double eps=1e-8, const bool average=true, const double momentum=0.1)
 Create the BatchNorm layer object for a specified number of input units.

bool Average () const
 Get the average parameter.

template<typename eT>
void Backward (const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
 Backward pass through the layer.

OutputDataType const & Delta () const
 Get the delta.

OutputDataType & Delta ()
 Modify the delta.

bool Deterministic () const
 Get the value of the deterministic parameter.

bool & Deterministic ()
 Modify the value of the deterministic parameter.

double Epsilon () const
 Get the epsilon value.

template<typename eT>
void Forward (const arma::Mat< eT > &input, arma::Mat< eT > &output)
 Forward pass of the Batch Normalization layer.

template<typename eT>
void Gradient (const arma::Mat< eT > &input, const arma::Mat< eT > &error, arma::Mat< eT > &gradient)
 Calculate the gradient using the output delta and the input activations.

OutputDataType const & Gradient () const
 Get the gradient.

OutputDataType & Gradient ()
 Modify the gradient.

size_t InputSize () const
 Get the number of input units / channels.

double Momentum () const
 Get the momentum value.

OutputDataType const & OutputParameter () const
 Get the output parameter.

OutputDataType & OutputParameter ()
 Modify the output parameter.

OutputDataType const & Parameters () const
 Get the parameters.

OutputDataType & Parameters ()
 Modify the parameters.

void Reset ()
 Reset the layer parameters.

template<typename Archive>
void serialize (Archive &ar, const uint32_t)
 Serialize the layer.

OutputDataType const & TrainingMean () const
 Get the mean over the training data.

OutputDataType & TrainingMean ()
 Modify the mean over the training data.

OutputDataType const & TrainingVariance () const
 Get the variance over the training data.

OutputDataType & TrainingVariance ()
 Modify the variance over the training data.

size_t WeightSize () const
 Get the size of the weights.

Detailed Description


template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>
class mlpack::ann::BatchNorm< InputDataType, OutputDataType >

Declaration of the Batch Normalization layer class.

The layer normalizes the input data to zero mean and unit variance, then scales it by gamma and shifts it by beta. Both gamma and beta are parameters learned by the network.
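For a mini-batch $x_1, \dots, x_m$ of a single feature, the transform in the preceding paragraph is the standard formulation from the Ioffe and Szegedy paper cited below, where $\epsilon$ corresponds to the constructor's `eps` argument. This is the textbook form, not a transcription of the mlpack source.

```latex
\begin{aligned}
\mu_B &= \frac{1}{m}\sum_{i=1}^{m} x_i, &
\sigma_B^2 &= \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_B\right)^2, \\
\hat{x}_i &= \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, &
y_i &= \gamma\,\hat{x}_i + \beta.
\end{aligned}
```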

If deterministic is false (training), the mean and variance of the current batch are computed and used to normalize the data. If it is true (testing), the mean and variance accumulated over the training set are used instead.
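The two modes described above can be sketched for a single feature as follows. This is a hypothetical, dependency-free illustration in plain C++ that does not use mlpack or Armadillo; `FeatureStats`, `BatchStats`, and `BatchNormForward` are illustrative names, and the running statistics are passed in rather than accumulated.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical standalone sketch (not mlpack code): batch normalization of a
// single feature across the samples of one batch.
struct FeatureStats
{
  double mean;
  double variance;
};

// Mean and biased (1/m) variance of one feature over the batch.
FeatureStats BatchStats(const std::vector<double>& x)
{
  assert(!x.empty());
  const double m = static_cast<double>(x.size());
  double mean = 0.0;
  for (double v : x)
    mean += v;
  mean /= m;

  double variance = 0.0;
  for (double v : x)
    variance += (v - mean) * (v - mean);
  variance /= m;
  return {mean, variance};
}

// Normalize x, then scale by gamma and shift by beta.
//   deterministic == false (training): use the current batch statistics.
//   deterministic == true (testing): use the supplied running statistics.
std::vector<double> BatchNormForward(const std::vector<double>& x,
                                     double gamma, double beta, double eps,
                                     bool deterministic,
                                     const FeatureStats& running)
{
  const FeatureStats s = deterministic ? running : BatchStats(x);
  std::vector<double> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
  {
    const double xHat = (x[i] - s.mean) / std::sqrt(s.variance + eps);
    y[i] = gamma * xHat + beta;
  }
  return y;
}
```

In training mode a batch is mapped (up to `eps`) to zero mean and unit variance before scaling; in deterministic mode the result depends only on the stored statistics, so a test sample is transformed the same way regardless of what else is in the batch.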

For more information, refer to the following paper:

@article{Ioffe15,
author = {Sergey Ioffe and
Christian Szegedy},
title = {Batch Normalization: Accelerating Deep Network Training by
Reducing Internal Covariate Shift},
journal = {CoRR},
volume = {abs/1502.03167},
year = {2015},
url = {http://arxiv.org/abs/1502.03167},
eprint = {1502.03167},
}
Template Parameters
InputDataType: Type of the input data (arma::colvec, arma::mat, arma::sp_mat or arma::cube).
OutputDataType: Type of the output data (arma::colvec, arma::mat, arma::sp_mat or arma::cube).

Definition at line 56 of file batch_norm.hpp.

Constructor & Destructor Documentation

◆ BatchNorm() [1/2]

BatchNorm ( )

Create the BatchNorm object.

◆ BatchNorm() [2/2]

BatchNorm ( const size_t  size,
const double  eps = 1e-8,
const bool  average = true,
const double  momentum = 0.1 
)

Create the BatchNorm layer object for a specified number of input units.

Parameters
size: The number of input units / channels.
eps: The epsilon added to the variance to ensure numerical stability.
average: Whether a cumulative average (true) or momentum (false) is used to update the running mean and variance.
momentum: Parameter used to update the running mean and variance.
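The choice between a cumulative average and momentum for the running statistics can be sketched as below. This assumes the common convention in which `average == true` gives the n-th batch a weight of 1/n (so the running value is the plain mean of all batch statistics) and `average == false` uses an exponential moving average with weight `momentum`; `RunningStat` is a hypothetical name, and mlpack's exact update (for example, any unbiased-variance correction or initial values) may differ.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical sketch (not mlpack code): update of one running statistic
// (for example, the running mean of a single channel) after each batch.
struct RunningStat
{
  double value = 0.0;     // Current running estimate.
  std::size_t count = 0;  // Number of batches seen so far.

  void Update(double batchValue, bool average, double momentum)
  {
    assert(momentum >= 0.0 && momentum <= 1.0);
    ++count;
    // average == true: weight 1/n for the n-th batch, i.e. a cumulative mean.
    // average == false: fixed weight `momentum` for the newest batch.
    const double factor =
        average ? 1.0 / static_cast<double>(count) : momentum;
    value = (1.0 - factor) * value + factor * batchValue;
  }
};
```

With `average == true`, three batches whose statistics are 1, 2 and 3 leave a running value of 2, their plain mean. With `average == false` and `momentum = 0.1`, a single batch statistic of 10 moves a zero-initialized estimate only to 1, so recent batches dominate slowly.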

Member Function Documentation

◆ Average()

bool Average ( ) const
inline

Get the average parameter.

Definition at line 161 of file batch_norm.hpp.

◆ Backward()

void Backward ( const arma::Mat< eT > &  input,
const arma::Mat< eT > &  gy,
arma::Mat< eT > &  g 
)

Backward pass through the layer.

Parameters
input: The input activations.
gy: The backpropagated error.
g: The calculated gradient.
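For a single feature, the backward pass can be sketched with a closed form that is equivalent to the chain-rule expressions in Ioffe and Szegedy. The function below is a hypothetical standalone sketch, not mlpack's `Backward()`: it recomputes the batch statistics from the input instead of caching them, treats `gy` as the derivative of the loss with respect to the layer output, and covers only the training (non-deterministic) mode.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical standalone sketch (not mlpack code): gradient of the loss with
// respect to the inputs of one feature, given gy = dL/dy for each sample.
//   g_i = gamma / (m * sqrt(var + eps))
//         * (m * gy_i - sum_j gy_j - xHat_i * sum_j gy_j * xHat_j)
std::vector<double> BatchNormBackward(const std::vector<double>& x,
                                      const std::vector<double>& gy,
                                      double gamma, double eps)
{
  assert(!x.empty() && x.size() == gy.size());
  const double m = static_cast<double>(x.size());

  // Recompute the batch mean and biased variance used in the forward pass.
  double mean = 0.0;
  for (double v : x)
    mean += v;
  mean /= m;
  double var = 0.0;
  for (double v : x)
    var += (v - mean) * (v - mean);
  var /= m;
  const double invStd = 1.0 / std::sqrt(var + eps);

  // Sum of the upstream error, and of the error weighted by xHat.
  double sumGy = 0.0;
  double sumGyXHat = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i)
  {
    sumGy += gy[i];
    sumGyXHat += gy[i] * (x[i] - mean) * invStd;
  }

  std::vector<double> g(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
  {
    const double xHat = (x[i] - mean) * invStd;
    g[i] = gamma * invStd / m * (m * gy[i] - sumGy - xHat * sumGyXHat);
  }
  return g;
}
```

Because adding the same constant to every input leaves the normalized output unchanged, the gradient components of a feature sum to zero, which makes a convenient sanity check alongside a finite-difference comparison.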

◆ Delta() [1/2]

OutputDataType const& Delta ( ) const
inline

Get the delta.

Definition at line 127 of file batch_norm.hpp.

◆ Delta() [2/2]

OutputDataType& Delta ( )
inline

Modify the delta.

Definition at line 129 of file batch_norm.hpp.

◆ Deterministic() [1/2]

bool Deterministic ( ) const
inline

Get the value of the deterministic parameter.

Definition at line 137 of file batch_norm.hpp.

◆ Deterministic() [2/2]

bool& Deterministic ( )
inline

Modify the value of the deterministic parameter.

Definition at line 139 of file batch_norm.hpp.

◆ Epsilon()

double Epsilon ( ) const
inline

Get the epsilon value.

Definition at line 155 of file batch_norm.hpp.

◆ Forward()

void Forward ( const arma::Mat< eT > &  input,
arma::Mat< eT > &  output 
)

Forward pass of the Batch Normalization layer.

Transforms the input data into zero mean and unit variance, scales the data by a factor gamma and shifts it by beta.

Parameters
input: Input data for the layer.
output: Resulting output activations.

◆ Gradient() [1/3]

void Gradient ( const arma::Mat< eT > &  input,
const arma::Mat< eT > &  error,
arma::Mat< eT > &  gradient 
)

Calculate the gradient using the output delta and the input activations.

Parameters
input: The input activations.
error: The calculated error.
gradient: The calculated gradient.
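The parameter gradients have a simple per-feature form: the gradient with respect to gamma is the error weighted by the normalized input, and the gradient with respect to beta is the summed error. A hypothetical sketch for one feature, assuming the normalized inputs from the forward pass are available and treating `error` as the derivative of the loss with respect to the layer output; the function name is illustrative and not part of mlpack.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical sketch (not mlpack code): gradients of the loss with respect
// to gamma and beta for one feature.
//   dGamma = sum_i error_i * xHat_i
//   dBeta  = sum_i error_i
// xHat holds the normalized inputs computed during the forward pass.
std::pair<double, double> GammaBetaGradient(const std::vector<double>& xHat,
                                            const std::vector<double>& error)
{
  assert(xHat.size() == error.size());
  double dGamma = 0.0;
  double dBeta = 0.0;
  for (std::size_t i = 0; i < xHat.size(); ++i)
  {
    dGamma += error[i] * xHat[i];
    dBeta += error[i];
  }
  return {dGamma, dBeta};
}
```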

◆ Gradient() [2/3]

OutputDataType const& Gradient ( ) const
inline

Get the gradient.

Definition at line 132 of file batch_norm.hpp.

◆ Gradient() [3/3]

OutputDataType& Gradient ( )
inline

Modify the gradient.

Definition at line 134 of file batch_norm.hpp.

◆ InputSize()

size_t InputSize ( ) const
inline

Get the number of input units / channels.

Definition at line 152 of file batch_norm.hpp.
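The size counts channels, not rows: for input from a convolutional layer, each sample column can hold several values per channel, and statistics are gathered per channel. The sketch below assumes a column-major layout in which every sample stores `channels` contiguous blocks of `rows / channels` values, one block per channel, and computes each channel's mean over all of its values in the batch. The function name and flat-vector representation are hypothetical; the layout is an assumption used for illustration, which is why the sketch requires the row count to be a multiple of the channel count.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch (not mlpack code): per-channel means of a column-major
// matrix with `rows` rows and `cols` columns (one column per sample).
// Assumed layout: within each column, channel c occupies rows
// [c * rows / channels, (c + 1) * rows / channels).
std::vector<double> ChannelMeans(const std::vector<double>& data,
                                 std::size_t rows, std::size_t cols,
                                 std::size_t channels)
{
  assert(channels > 0 && rows % channels == 0);
  assert(data.size() == rows * cols);
  const std::size_t perChannel = rows / channels;

  // Accumulate every value into the sum of the channel its row belongs to.
  std::vector<double> means(channels, 0.0);
  for (std::size_t col = 0; col < cols; ++col)
    for (std::size_t row = 0; row < rows; ++row)
      means[row / perChannel] += data[col * rows + row];

  // Each channel has perChannel values in each of the cols samples.
  for (double& m : means)
    m /= static_cast<double>(perChannel * cols);
  return means;
}
```

For example, with two samples stored as columns {1, 2, 10, 20} and {3, 4, 30, 40} and two channels, the first channel averages 1, 2, 3, 4 to 2.5 and the second averages 10, 20, 30, 40 to 25.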

◆ Momentum()

double Momentum ( ) const
inline

Get the momentum value.

Definition at line 158 of file batch_norm.hpp.

◆ OutputParameter() [1/2]

OutputDataType const& OutputParameter ( ) const
inline

Get the output parameter.

Definition at line 122 of file batch_norm.hpp.

◆ OutputParameter() [2/2]

OutputDataType& OutputParameter ( )
inline

Modify the output parameter.

Definition at line 124 of file batch_norm.hpp.

◆ Parameters() [1/2]

OutputDataType const& Parameters ( ) const
inline

Get the parameters.

Definition at line 117 of file batch_norm.hpp.

◆ Parameters() [2/2]

OutputDataType& Parameters ( )
inline

Modify the parameters.

Definition at line 119 of file batch_norm.hpp.

◆ Reset()

void Reset ( )

Reset the layer parameters.

◆ serialize()

void serialize ( Archive &  ar,
const uint32_t   
)

Serialize the layer.

Referenced by BatchNorm< InputDataType, OutputDataType >::WeightSize().

◆ TrainingMean() [1/2]

OutputDataType const& TrainingMean ( ) const
inline

Get the mean over the training data.

Definition at line 142 of file batch_norm.hpp.

◆ TrainingMean() [2/2]

OutputDataType& TrainingMean ( )
inline

Modify the mean over the training data.

Definition at line 144 of file batch_norm.hpp.

◆ TrainingVariance() [1/2]

OutputDataType const& TrainingVariance ( ) const
inline

Get the variance over the training data.

Definition at line 147 of file batch_norm.hpp.

◆ TrainingVariance() [2/2]

OutputDataType& TrainingVariance ( )
inline

Modify the variance over the training data.

Definition at line 149 of file batch_norm.hpp.

◆ WeightSize()

size_t WeightSize ( ) const
inline

Get the size of the weights.

Definition at line 164 of file batch_norm.hpp.

References BatchNorm< InputDataType, OutputDataType >::serialize().


The documentation for this class was generated from the following file:
  • /home/ryan/src/mlpack.org/_src/mlpack-git/src/mlpack/methods/ann/layer/batch_norm.hpp