reparametrization.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
14 #define MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "layer_types.hpp"
19 #include "../activation_functions/softplus_function.hpp"
20 
21 namespace mlpack {
22 namespace ann {
23 
52 template <
53  typename InputDataType = arma::mat,
54  typename OutputDataType = arma::mat
55 >
56 class Reparametrization
57 {
58  public:
61 
70  Reparametrization(const size_t latentSize,
71  const bool stochastic = true,
72  const bool includeKl = true,
73  const double beta = 1);
74 
77 
80 
83 
86 
94  template<typename eT>
95  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
96 
106  template<typename eT>
107  void Backward(const arma::Mat<eT>& input,
108  const arma::Mat<eT>& gy,
109  arma::Mat<eT>& g);
110 
112  OutputDataType const& OutputParameter() const { return outputParameter; }
114  OutputDataType& OutputParameter() { return outputParameter; }
115 
117  OutputDataType const& Delta() const { return delta; }
119  OutputDataType& Delta() { return delta; }
120 
122  size_t const& OutputSize() const { return latentSize; }
124  size_t& OutputSize() { return latentSize; }
125 
127  double Loss()
128  {
129  if (!includeKl)
130  return 0;
131 
132  return -0.5 * beta * arma::accu(2 * arma::log(stdDev) - arma::pow(stdDev, 2)
133  - arma::pow(mean, 2) + 1) / mean.n_cols;
134  }
135 
137  bool Stochastic() const { return stochastic; }
138 
140  bool IncludeKL() const { return includeKl; }
141 
143  double Beta() const { return beta; }
144 
145  size_t InputShape() const
146  {
147  return 2 * latentSize;
148  }
149 
153  template<typename Archive>
154  void serialize(Archive& ar, const uint32_t /* version */);
155 
156  private:
158  size_t latentSize;
159 
161  bool stochastic;
162 
164  bool includeKl;
165 
167  double beta;
168 
170  OutputDataType delta;
171 
173  OutputDataType gaussianSample;
174 
176  OutputDataType mean;
177 
180  OutputDataType preStdDev;
181 
183  OutputDataType stdDev;
184 
186  OutputDataType outputParameter;
187 }; // class Reparametrization
188 
189 } // namespace ann
190 } // namespace mlpack
191 
192 // Include implementation.
193 #include "reparametrization_impl.hpp"
194 
195 #endif
Reparametrization & operator=(const Reparametrization &layer)
Copy assignment operator.
OutputDataType & Delta()
Modify the delta.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
bool IncludeKL() const
Get the value of the includeKl parameter.
bool Stochastic() const
Get the value of the stochastic parameter.
double Loss()
Get the KL divergence from the standard normal distribution.
OutputDataType & OutputParameter()
Modify the output parameter.
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
size_t const & OutputSize() const
Get the output size.
OutputDataType const & OutputParameter() const
Get the output parameter.
size_t & OutputSize()
Modify the output size.
OutputDataType const & Delta() const
Get the delta.
Reparametrization()
Create the Reparametrization object.
double Beta() const
Get the value of the beta hyperparameter.
void serialize(Archive &ar, const uint32_t)
Serialize the layer.