13 #ifndef MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP 14 #define MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP 19 #include "../activation_functions/softplus_function.hpp" 53 typename InputDataType = arma::mat,
54 typename OutputDataType = arma::mat
56 class Reparametrization
71 const bool stochastic =
true,
72 const bool includeKl =
true,
73 const double beta = 1);
95 void Forward(
const arma::Mat<eT>& input, arma::Mat<eT>& output);
106 template<
typename eT>
107 void Backward(
const arma::Mat<eT>& input,
108 const arma::Mat<eT>& gy,
117 OutputDataType
const&
Delta()
const {
return delta; }
119 OutputDataType&
Delta() {
return delta; }
132 return -0.5 * beta * arma::accu(2 * arma::log(stdDev) - arma::pow(stdDev, 2)
133 - arma::pow(mean, 2) + 1) / mean.n_cols;
143 double Beta()
const {
return beta; }
147 return 2 * latentSize;
153 template<
typename Archive>
154 void serialize(Archive& ar,
const uint32_t );
170 OutputDataType delta;
173 OutputDataType gaussianSample;
180 OutputDataType preStdDev;
183 OutputDataType stdDev;
186 OutputDataType outputParameter;
193 #include "reparametrization_impl.hpp" Reparametrization & operator=(const Reparametrization &layer)
Copy assignment operator.
OutputDataType & Delta()
Modify the delta.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
bool IncludeKL() const
Get the value of the includeKl parameter.
bool Stochastic() const
Get the value of the stochastic parameter.
double Loss()
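Get the KL divergence between the learned Gaussian and the standard normal distribution, scaled by beta and averaged over the columns (samples).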
OutputDataType & OutputParameter()
Modify the output parameter.
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
size_t const & OutputSize() const
Get the output size.
OutputDataType const & OutputParameter() const
Get the output parameter.
size_t InputShape() const
size_t & OutputSize()
Modify the output size.
OutputDataType const & Delta() const
Get the delta.
Reparametrization()
Create the Reparametrization object.
double Beta() const
Get the value of the beta hyperparameter.
void serialize(Archive &ar, const uint32_t)
Serialize the layer.