flatten_t_swish.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_ANN_LAYER_FLATTEN_T_SWISH_HPP
15 #define MLPACK_METHODS_ANN_LAYER_FLATTEN_T_SWISH_HPP
16 
17 #include <mlpack/prereqs.hpp>
18 
19 namespace mlpack {
20 namespace ann {
21 
45 template <
46  typename InputDataType = arma::mat,
47  typename OutputDataType = arma::mat
48 >
50 {
51  public:
59  FlattenTSwish(const double T = -0.20);
60 
68  template<typename InputType, typename OutputType>
69  void Forward(const InputType& input, OutputType& output);
70 
80  template<typename DataType>
81  void Backward(const DataType& input, const DataType& gy, DataType& g);
82 
84  OutputDataType const& OutputParameter() const { return outputParameter; }
86  OutputDataType& OutputParameter() { return outputParameter; }
87 
89  OutputDataType const& Delta() const { return delta; }
91  OutputDataType& Delta() { return delta; }
92 
94  double const& T() const { return t; }
96  double& T() { return t; }
97 
99  size_t WeightSize() const { return 0; }
100 
104  template<typename Archive>
105  void serialize(Archive& ar, const uint32_t /* version */);
106 
107  private:
109  OutputDataType delta;
110 
112  OutputDataType outputParameter;
113 
115  double t;
116 }; // class FlattenTSwish
117 
118 } // namespace ann
119 } // namespace mlpack
120 
121 // Include implementation.
122 #include "flatten_t_swish_impl.hpp"
123 
124 #endif
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
The Flatten T Swish activation function, defined by $f(x) = \frac{x}{1 + e^{-x}} + T$ for $x \ge 0$ and $f(x) = T$ for $x < 0$.
Linear algebra utility functions, generally performed on matrices or vectors.
double const & T() const
Get the T parameter.
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t WeightSize() const
Get size of weights.
OutputDataType & Delta()
Modify the delta.
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
OutputDataType const & Delta() const
Get the delta.
OutputDataType const & OutputParameter() const
Get the output parameter.
FlattenTSwish(const double T=-0.20)
Create the Flatten T Swish object using the specified parameters.
OutputDataType & OutputParameter()
Modify the output parameter.
void Backward(const DataType &input, const DataType &gy, DataType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
double & T()
Modify the T parameter.