batch_norm.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_BATCHNORM_HPP
14 #define MLPACK_METHODS_ANN_LAYER_BATCHNORM_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace ann {
20 
52 template <
53  typename InputDataType = arma::mat,
54  typename OutputDataType = arma::mat
55 >
56 class BatchNorm
57 {
58  public:
60  BatchNorm();
61 
71  BatchNorm(const size_t size,
72  const double eps = 1e-8,
73  const bool average = true,
74  const double momentum = 0.1);
75 
79  void Reset();
80 
89  template<typename eT>
90  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
91 
99  template<typename eT>
100  void Backward(const arma::Mat<eT>& input,
101  const arma::Mat<eT>& gy,
102  arma::Mat<eT>& g);
103 
111  template<typename eT>
112  void Gradient(const arma::Mat<eT>& input,
113  const arma::Mat<eT>& error,
114  arma::Mat<eT>& gradient);
115 
117  OutputDataType const& Parameters() const { return weights; }
119  OutputDataType& Parameters() { return weights; }
120 
122  OutputDataType const& OutputParameter() const { return outputParameter; }
124  OutputDataType& OutputParameter() { return outputParameter; }
125 
127  OutputDataType const& Delta() const { return delta; }
129  OutputDataType& Delta() { return delta; }
130 
132  OutputDataType const& Gradient() const { return gradient; }
134  OutputDataType& Gradient() { return gradient; }
135 
137  bool Deterministic() const { return deterministic; }
139  bool& Deterministic() { return deterministic; }
140 
142  OutputDataType const& TrainingMean() const { return runningMean; }
144  OutputDataType& TrainingMean() { return runningMean; }
145 
147  OutputDataType const& TrainingVariance() const { return runningVariance; }
149  OutputDataType& TrainingVariance() { return runningVariance; }
150 
152  size_t InputSize() const { return size; }
153 
155  double Epsilon() const { return eps; }
156 
158  double Momentum() const { return momentum; }
159 
161  bool Average() const { return average; }
162 
164  size_t WeightSize() const { return 2 * size; }
165 
169  template<typename Archive>
170  void serialize(Archive& ar, const uint32_t /* version */);
171 
172  private:
174  size_t size;
175 
177  double eps;
178 
181  bool average;
182 
184  double momentum;
185 
187  bool loading;
188 
190  OutputDataType gamma;
191 
193  OutputDataType beta;
194 
196  OutputDataType mean;
197 
199  OutputDataType variance;
200 
202  OutputDataType weights;
203 
208  bool deterministic;
209 
211  size_t count;
212 
215  double averageFactor;
216 
218  OutputDataType runningMean;
219 
221  OutputDataType runningVariance;
222 
224  OutputDataType gradient;
225 
227  OutputDataType delta;
228 
230  OutputDataType outputParameter;
231 
233  arma::cube normalized;
234 
236  arma::cube inputMean;
237 }; // class BatchNorm
238 
239 } // namespace ann
240 } // namespace mlpack
241 
242 // Include the implementation.
243 #include "batch_norm_impl.hpp"
244 
245 #endif
OutputDataType & Gradient()
Modify the gradient.
Definition: batch_norm.hpp:134
OutputDataType const & TrainingMean() const
Get the mean over the training data.
Definition: batch_norm.hpp:142
Linear algebra utility functions, generally performed on matrices or vectors.
OutputDataType & TrainingVariance()
Modify the variance over the training data.
Definition: batch_norm.hpp:149
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputDataType & Delta()
Modify the delta.
Definition: batch_norm.hpp:129
OutputDataType const & TrainingVariance() const
Get the variance over the training data.
Definition: batch_norm.hpp:147
bool Deterministic() const
Get the value of the deterministic parameter.
Definition: batch_norm.hpp:137
bool Average() const
Get the average parameter.
Definition: batch_norm.hpp:161
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: batch_norm.hpp:122
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Forward pass of the Batch Normalization layer.
void Reset()
Reset the layer parameters.
OutputDataType & Parameters()
Modify the parameters.
Definition: batch_norm.hpp:119
bool & Deterministic()
Modify the value of the deterministic parameter.
Definition: batch_norm.hpp:139
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
BatchNorm()
Create the BatchNorm object.
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Backward pass through the layer.
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: batch_norm.hpp:124
OutputDataType const & Parameters() const
Get the parameters.
Definition: batch_norm.hpp:117
size_t InputSize() const
Get the number of input units / channels.
Definition: batch_norm.hpp:152
OutputDataType const & Gradient() const
Get the gradient.
Definition: batch_norm.hpp:132
Declaration of the Batch Normalization layer class.
Definition: batch_norm.hpp:56
double Momentum() const
Get the momentum value.
Definition: batch_norm.hpp:158
OutputDataType & TrainingMean()
Modify the mean over the training data.
Definition: batch_norm.hpp:144
double Epsilon() const
Get the epsilon value.
Definition: batch_norm.hpp:155
size_t WeightSize() const
Get the size of the weights.
Definition: batch_norm.hpp:164
OutputDataType const & Delta() const
Get the delta.
Definition: batch_norm.hpp:127