celu.hpp
Go to the documentation of this file.
1 
23 #ifndef MLPACK_METHODS_ANN_LAYER_CELU_HPP
24 #define MLPACK_METHODS_ANN_LAYER_CELU_HPP
25 
26 #include <mlpack/prereqs.hpp>
27 
28 namespace mlpack {
29 namespace ann {
30 
56 template <
57  typename InputDataType = arma::mat,
58  typename OutputDataType = arma::mat
59 >
60 class CELU
61 {
62  public:
70  CELU(const double alpha = 1.0);
71 
79  template<typename InputType, typename OutputType>
80  void Forward(const InputType& input, OutputType& output);
81 
91  template<typename DataType>
92  void Backward(const DataType& input, const DataType& gy, DataType& g);
93 
95  OutputDataType const& OutputParameter() const { return outputParameter; }
97  OutputDataType& OutputParameter() { return outputParameter; }
98 
100  OutputDataType const& Delta() const { return delta; }
102  OutputDataType& Delta() { return delta; }
103 
105  double const& Alpha() const { return alpha; }
107  double& Alpha() { return alpha; }
108 
110  bool Deterministic() const { return deterministic; }
112  bool& Deterministic() { return deterministic; }
113 
115  size_t WeightSize() { return 0; }
116 
120  template<typename Archive>
121  void serialize(Archive& ar, const uint32_t /* version */);
122 
123  private:
125  OutputDataType delta;
126 
128  OutputDataType outputParameter;
129 
131  arma::mat derivative;
132 
134  double alpha;
135 
137  bool deterministic;
138 }; // class CELU
139 
140 } // namespace ann
141 } // namespace mlpack
142 
143 // Include implementation.
144 #include "celu_impl.hpp"
145 
146 #endif
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: celu.hpp:97
OutputDataType const & Delta() const
Get the delta.
Definition: celu.hpp:100
double & Alpha()
Modify the alpha (scale) parameter.
Definition: celu.hpp:107
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: celu.hpp:95
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Backward(const DataType &input, const DataType &gy, DataType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
OutputDataType & Delta()
Modify the delta.
Definition: celu.hpp:102
bool & Deterministic()
Modify the value of the deterministic parameter.
Definition: celu.hpp:112
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
The CELU activation function, defined by $f(x) = \max(0, x) + \min\left(0, \alpha\left(e^{x/\alpha} - 1\right)\right)$.
Definition: celu.hpp:60
size_t WeightSize()
Get the number of trainable weights (always 0 for this layer).
Definition: celu.hpp:115
CELU(const double alpha=1.0)
Create the CELU object using the specified alpha (scale) parameter.
double const & Alpha() const
Get the alpha (scale) parameter.
Definition: celu.hpp:105
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
bool Deterministic() const
Get the value of the deterministic parameter.
Definition: celu.hpp:110