swish_function.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_SWISH_FUNCTION_HPP
14 #define MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_SWISH_FUNCTION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace ann {
20 
31 {
32  public:
/**
 * Evaluates the swish function f(x) = x / (1 + e^{-x}) for a single value.
 *
 * For large negative x, e^{-x} overflows to +inf and the result is -0.0
 * rather than NaN.
 *
 * @param x Input value.
 * @return f(x).
 */
static double Fn(const double x)
{
  const double denominator = std::exp(-x) + 1.0;
  return x / denominator;
}
43 
50  template<typename eT>
51  static void Fn(const arma::Mat<eT>& x, arma::Mat<eT>& y)
52  {
53  y = x / (1.0 + arma::exp(-x));
54  }
55 
62  template<typename InputVecType, typename OutputVecType>
63  static void Fn(const InputVecType& x, OutputVecType& y)
64  {
65  y.set_size(arma::size(x));
66 
67  for (size_t i = 0; i < x.n_elem; ++i)
68  y(i) = Fn(x(i));
69  }
70 
/**
 * Computes the first derivative of the swish function at a single point.
 *
 * With f(t) = t / (1 + e^{-t}), this returns
 * f'(y) = f(y) + (1 - f(y)) / (1 + e^{-y}).
 * Here y is the point at which f' is evaluated, i.e. the argument of f.
 * NOTE(review): the formula equals f' only if callers pass the
 * pre-activation value, not the output of Fn. Confirm against callers.
 *
 * @param y Input data (the point at which to evaluate the derivative).
 * @return f'(y).
 */
static double Deriv(const double y)
{
  // Compute 1 + e^{-y} and f(y) once. The previous expression recomputed
  // std::exp(-y) three times. The arithmetic is otherwise unchanged, so
  // results are bit-identical.
  const double denom = 1.0 + std::exp(-y);
  const double fy = y / denom;
  return fy + (1.0 - fy) / denom;
}
82 
89  template<typename InputVecType, typename OutputVecType>
90  static void Deriv(const InputVecType& y, OutputVecType& x)
91  {
92  x = y / (1 + arma::exp(-y)) + (1 - y / (1 + arma::exp(-y))) /
93  (1 + arma::exp(-y));
94  }
95 }; // class SwishFunction
96 
97 } // namespace ann
98 } // namespace mlpack
99 
100 #endif
static void Fn(const InputVecType &x, OutputVecType &y)
Computes the swish function.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
static double Fn(const double x)
Computes the swish function.
static double Deriv(const double y)
Computes the first derivative of the swish function.
static void Deriv(const InputVecType &y, OutputVecType &x)
Computes the first derivatives of the swish function.
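The swish function, defined by $f(x) = \frac{x}{1 + e^{-x}} = x \cdot \sigma(x)$, where $\sigma(x) = \frac{1}{1 + e^{-x}}$ is the logistic sigmoid.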
static void Fn(const arma::Mat< eT > &x, arma::Mat< eT > &y)
Computes the swish function using a matrix as input.