base_layer.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
14 #define MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
15 
16 #include <mlpack/prereqs.hpp>
33 
34 namespace mlpack {
35 namespace ann {
36 
66 template <
67  class ActivationFunction = LogisticFunction,
68  typename InputDataType = arma::mat,
69  typename OutputDataType = arma::mat
70 >
71 class BaseLayer
72 {
73  public:
78  {
79  // Nothing to do here.
80  }
81 
89  template<typename InputType, typename OutputType>
90  void Forward(const InputType& input, OutputType& output)
91  {
92  ActivationFunction::Fn(input, output);
93  }
94 
104  template<typename eT>
105  void Backward(const arma::Mat<eT>& input,
106  const arma::Mat<eT>& gy,
107  arma::Mat<eT>& g)
108  {
109  arma::Mat<eT> derivative;
110  ActivationFunction::Deriv(input, derivative);
111  g = gy % derivative;
112  }
113 
115  OutputDataType const& OutputParameter() const { return outputParameter; }
117  OutputDataType& OutputParameter() { return outputParameter; }
118 
120  OutputDataType const& Delta() const { return delta; }
122  OutputDataType& Delta() { return delta; }
123 
127  template<typename Archive>
128  void serialize(Archive& /* ar */, const uint32_t /* version */)
129  {
130  /* Nothing to do here */
131  }
132 
133  private:
135  OutputDataType delta;
136 
138  OutputDataType outputParameter;
139 }; // class BaseLayer
140 
141 // Convenience typedefs.
142 
146 template <
147  class ActivationFunction = LogisticFunction,
148  typename InputDataType = arma::mat,
149  typename OutputDataType = arma::mat
150 >
151 using SigmoidLayer = BaseLayer<
152  ActivationFunction, InputDataType, OutputDataType>;
153 
157 template <
158  class ActivationFunction = IdentityFunction,
159  typename InputDataType = arma::mat,
160  typename OutputDataType = arma::mat
161 >
162 using IdentityLayer = BaseLayer<
163  ActivationFunction, InputDataType, OutputDataType>;
164 
168 template <
169  class ActivationFunction = RectifierFunction,
170  typename InputDataType = arma::mat,
171  typename OutputDataType = arma::mat
172 >
173 using ReLULayer = BaseLayer<
174  ActivationFunction, InputDataType, OutputDataType>;
175 
179 template <
180  class ActivationFunction = TanhFunction,
181  typename InputDataType = arma::mat,
182  typename OutputDataType = arma::mat
183 >
184 using TanHLayer = BaseLayer<
185  ActivationFunction, InputDataType, OutputDataType>;
186 
190 template <
191  class ActivationFunction = SoftplusFunction,
192  typename InputDataType = arma::mat,
193  typename OutputDataType = arma::mat
194 >
195 using SoftPlusLayer = BaseLayer<
196  ActivationFunction, InputDataType, OutputDataType>;
197 
// BaseLayer alias whose default ActivationFunction is HardSigmoidFunction.
// NOTE(review): listing line 206 -- the `using <Name> = BaseLayer<` line that
// names this alias -- is missing, so the declaration is incomplete as shown.
// Presumably it is `using HardSigmoidLayer = BaseLayer<` as in upstream
// mlpack; verify against the original header before restoring.
201 template <
202  class ActivationFunction = HardSigmoidFunction,
203  typename InputDataType = arma::mat,
204  typename OutputDataType = arma::mat
205 >
207  ActivationFunction, InputDataType, OutputDataType>;
208 
// BaseLayer alias whose default ActivationFunction is SwishFunction.
// NOTE(review): listing line 217 -- the `using <Name> = BaseLayer<` line that
// names this alias -- is missing, so the declaration is incomplete as shown.
// Presumably it is `using SwishFunctionLayer = BaseLayer<` as in upstream
// mlpack; verify against the original header before restoring.
212 template <
213  class ActivationFunction = SwishFunction,
214  typename InputDataType = arma::mat,
215  typename OutputDataType = arma::mat
216 >
218  ActivationFunction, InputDataType, OutputDataType>;
219 
// BaseLayer alias whose default ActivationFunction is MishFunction.
// NOTE(review): listing line 228 -- the `using <Name> = BaseLayer<` line that
// names this alias -- is missing, so the declaration is incomplete as shown.
// Presumably it is `using MishFunctionLayer = BaseLayer<` as in upstream
// mlpack; verify against the original header before restoring.
223 template <
224  class ActivationFunction = MishFunction,
225  typename InputDataType = arma::mat,
226  typename OutputDataType = arma::mat
227 >
229  ActivationFunction, InputDataType, OutputDataType>;
230 
// BaseLayer alias whose default ActivationFunction is LiSHTFunction.
// NOTE(review): listing line 239 -- the `using <Name> = BaseLayer<` line that
// names this alias -- is missing, so the declaration is incomplete as shown.
// Presumably it is `using LiSHTFunctionLayer = BaseLayer<` as in upstream
// mlpack; verify against the original header before restoring.
234 template <
235  class ActivationFunction = LiSHTFunction,
236  typename InputDataType = arma::mat,
237  typename OutputDataType = arma::mat
238 >
240  ActivationFunction, InputDataType, OutputDataType>;
241 
// BaseLayer alias whose default ActivationFunction is GELUFunction.
// NOTE(review): listing line 250 -- the `using <Name> = BaseLayer<` line that
// names this alias -- is missing, so the declaration is incomplete as shown.
// Presumably it is `using GELUFunctionLayer = BaseLayer<` as in upstream
// mlpack; verify against the original header before restoring.
245 template <
246  class ActivationFunction = GELUFunction,
247  typename InputDataType = arma::mat,
248  typename OutputDataType = arma::mat
249 >
251  ActivationFunction, InputDataType, OutputDataType>;
252 
// BaseLayer alias whose default ActivationFunction is ElliotFunction.
// NOTE(review): listing line 261 -- the `using <Name> = BaseLayer<` line that
// names this alias -- is missing, so the declaration is incomplete as shown.
// Presumably it is `using ElliotFunctionLayer = BaseLayer<` as in upstream
// mlpack; verify against the original header before restoring.
256 template <
257  class ActivationFunction = ElliotFunction,
258  typename InputDataType = arma::mat,
259  typename OutputDataType = arma::mat
260 >
262  ActivationFunction, InputDataType, OutputDataType>;
263 
// BaseLayer alias whose default ActivationFunction is ElishFunction.
// NOTE(review): listing line 272 -- the `using <Name> = BaseLayer<` line that
// names this alias -- is missing, so the declaration is incomplete as shown.
// Presumably it is `using ElishFunctionLayer = BaseLayer<` as in upstream
// mlpack; verify against the original header before restoring.
267 template <
268  class ActivationFunction = ElishFunction,
269  typename InputDataType = arma::mat,
270  typename OutputDataType = arma::mat
271 >
273  ActivationFunction, InputDataType, OutputDataType>;
274 
// BaseLayer alias whose default ActivationFunction is GaussianFunction.
// NOTE(review): listing line 283 -- the `using <Name> = BaseLayer<` line that
// names this alias -- is missing, so the declaration is incomplete as shown.
// Presumably it is `using GaussianFunctionLayer = BaseLayer<` as in upstream
// mlpack; verify against the original header before restoring.
278 template <
279  class ActivationFunction = GaussianFunction,
280  typename InputDataType = arma::mat,
281  typename OutputDataType = arma::mat
282 >
284  ActivationFunction, InputDataType, OutputDataType>;
285 
// BaseLayer alias whose default ActivationFunction is HardSwishFunction.
// NOTE(review): listing line 294 -- the `using <Name> = BaseLayer<` line that
// names this alias -- is missing, so the declaration is incomplete as shown.
// Presumably it is `using HardSwishFunctionLayer = BaseLayer<` as in upstream
// mlpack; verify against the original header before restoring.
289 template <
290  class ActivationFunction = HardSwishFunction,
291  typename InputDataType = arma::mat,
292  typename OutputDataType = arma::mat
293 >
295  ActivationFunction, InputDataType, OutputDataType>;
296 
// BaseLayer alias whose default ActivationFunction is TanhExpFunction.
// NOTE(review): listing line 305 -- the `using <Name> = BaseLayer<` line that
// names this alias -- is missing, so the declaration is incomplete as shown.
// Presumably it is `using TanhExpFunctionLayer = BaseLayer<` as in upstream
// mlpack; verify against the original header before restoring.
300 template <
301  class ActivationFunction = TanhExpFunction,
302  typename InputDataType = arma::mat,
303  typename OutputDataType = arma::mat
304 >
306  ActivationFunction, InputDataType, OutputDataType>;
307 
// BaseLayer alias whose default ActivationFunction is SILUFunction.
// NOTE(review): listing line 316 -- the `using <Name> = BaseLayer<` line that
// names this alias -- is missing, so the declaration is incomplete as shown.
// Presumably it is `using SILUFunctionLayer = BaseLayer<` as in upstream
// mlpack; verify against the original header before restoring.
311 template <
312  class ActivationFunction = SILUFunction,
313  typename InputDataType = arma::mat,
314  typename OutputDataType = arma::mat
315 >
317  ActivationFunction, InputDataType, OutputDataType
318 >;
319 
320 } // namespace ann
321 } // namespace mlpack
322 
323 #endif
The identity function, defined by $f(x) = x$.
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function $f(x)$ by propagating the activity forward through $f$.
Definition: base_layer.hpp:90
The Hard Swish function, defined by $f(x) = x \cdot \frac{\min(\max(x + 3, 0), 6)}{6}$.
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: base_layer.hpp:117
BaseLayer()
Create the BaseLayer object.
Definition: base_layer.hpp:77
The LiSHT function, defined by $f(x) = x \tanh(x)$.
OutputDataType & Delta()
Modify the delta.
Definition: base_layer.hpp:122
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function $f(x)$ by propagating $x$ backwards through $f$, using the results from the feed forward pass.
Definition: base_layer.hpp:105
The tanh function, defined by $f(x) = \tanh(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void serialize(Archive &, const uint32_t)
Serialize the layer.
Definition: base_layer.hpp:128
The ELiSH function, defined by $f(x) = \frac{x}{1 + e^{-x}}$ for $x \ge 0$ and $f(x) = \frac{e^{x} - 1}{1 + e^{-x}}$ for $x < 0$.
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: base_layer.hpp:115
OutputDataType const & Delta() const
Get the delta.
Definition: base_layer.hpp:120
Implementation of the base layer.
Definition: base_layer.hpp:71
The SILU function, defined by $f(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$.
The Mish function, defined by $f(x) = x \tanh(\ln(1 + e^{x}))$.
The TanhExp function, defined by $f(x) = x \tanh(e^{x})$.
The logistic function, defined by $f(x) = \frac{1}{1 + e^{-x}}$.
The Gaussian function, defined by $f(x) = e^{-x^{2}}$.
The Elliot function, defined by $f(x) = \frac{x}{1 + |x|}$.
The swish function, defined by $f(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$.
The softplus function, defined by $f(x) = \ln(1 + e^{x})$.
The hard sigmoid function, defined by $f(x) = \max(0, \min(1, 0.2x + 0.5))$.
The GELU function, defined by $f(x) = 0.5\,x\left(1 + \tanh\left[\sqrt{2/\pi}\,\left(x + 0.044715\,x^{3}\right)\right]\right)$.
The rectifier function, defined by $f(x) = \max(0, x)$.