reinforce_normal.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_REINFORCE_NORMAL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_REINFORCE_NORMAL_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace ann {
20 
30 template <
31  typename InputDataType = arma::mat,
32  typename OutputDataType = arma::mat
33 >
35 {
36  public:
42  ReinforceNormal(const double stdev = 1.0);
43 
51  template<typename eT>
52  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
53 
63  template<typename DataType>
64  void Backward(const DataType& input, const DataType& /* gy */, DataType& g);
65 
67  OutputDataType& OutputParameter() const { return outputParameter; }
69  OutputDataType& OutputParameter() { return outputParameter; }
70 
72  OutputDataType& Delta() const { return delta; }
74  OutputDataType& Delta() { return delta; }
75 
77  bool Deterministic() const { return deterministic; }
79  bool& Deterministic() { return deterministic; }
80 
82  double Reward() const { return reward; }
84  double& Reward() { return reward; }
85 
87  double StandardDeviation() const { return stdev; }
88 
92  template<typename Archive>
93  void serialize(Archive& ar, const uint32_t /* version */);
94 
95  private:
97  double stdev;
98 
100  double reward;
101 
103  OutputDataType delta;
104 
106  OutputDataType outputParameter;
107 
109  std::vector<arma::mat> moduleInputParameter;
110 
112  bool deterministic;
113 }; // class ReinforceNormal
114 
115 } // namespace ann
116 } // namespace mlpack
117 
118 // Include implementation.
119 #include "reinforce_normal_impl.hpp"
120 
121 #endif
ReinforceNormal(const double stdev=1.0)
Create the ReinforceNormal object.
OutputDataType & Delta() const
Get the delta.
Linear algebra utility functions, generally performed on matrices or vectors.
Implementation of the reinforce normal layer.
The core includes that mlpack expects; standard C++ includes and Armadillo.
double StandardDeviation() const
Get the standard deviation used during forward and backward pass.
void Backward(const DataType &input, const DataType &, DataType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backw...
OutputDataType & OutputParameter()
Modify the output parameter.
OutputDataType & Delta()
Modify the delta.
double & Reward()
Modify the value of the reward parameter.
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
double Reward() const
Get the value of the reward parameter.
bool & Deterministic()
Modify the value of the deterministic parameter.
bool Deterministic() const
Get the value of the deterministic parameter.
OutputDataType & OutputParameter() const
Get the output parameter.
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activ...