gelu_function.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_GELU_FUNCTION_HPP
14 #define MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_GELU_FUNCTION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace ann {
20 
32 {
33  public:
/**
 * Computes the GELU function for a single value, using the tanh-based
 * expression 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
 *
 * @param x Input value.
 * @return The value of the expression above evaluated at x.
 */
static double Fn(const double x)
{
  // Cubic polynomial that is fed (after scaling) into tanh.
  const double polynomial = x + 0.044715 * std::pow(x, 3);
  const double tanhArgument = std::sqrt(2 / M_PI) * polynomial;

  return 0.5 * x * (1 + std::tanh(tanhArgument));
}
45 
52  template<typename InputVecType, typename OutputVecType>
53  static void Fn(const InputVecType& x, OutputVecType& y)
54  {
55  y = 0.5 * x % (1 + arma::tanh(std::sqrt(2 / M_PI) *
56  (x + 0.044715 * arma::pow(x, 3))));
57  }
58 
/**
 * Computes the first derivative of the GELU function for a single value.
 *
 * With u(y) = 0.0356774 * y^3 + 0.797885 * y, the returned value is
 *   0.5 * tanh(u) + (0.0535161 * y^3 + 0.398942 * y) * sech(u)^2 + 0.5.
 *
 * NOTE(review): this expression is the derivative of the tanh-based GELU
 * with respect to its argument, evaluated at `y`; the parameter name suggests
 * an activation output -- confirm what callers actually pass in.
 *
 * @param y Value at which the expression above is evaluated.
 * @return The value of the derivative expression.
 */
static double Deriv(const double y)
{
  const double yCubed = std::pow(y, 3);

  // Shared argument of tanh and cosh.
  const double u = 0.0356774 * yCubed + 0.797885 * y;

  // sech(u) written as 1 / cosh(u).
  const double sech = 1 / std::cosh(u);

  return 0.5 * std::tanh(u) +
      (0.0535161 * yCubed + 0.398942 * y) * std::pow(sech, 2) + 0.5;
}
72 
79  template<typename InputVecType, typename OutputVecType>
80  static void Deriv(const InputVecType& y, OutputVecType& x)
81  {
82  x = 0.5 * arma::tanh(0.0356774 * arma::pow(y, 3) + 0.797885 * y) +
83  (0.0535161 * arma::pow(y, 3) + 0.398942 * y) %
84  arma::pow(1 / arma::cosh(0.0356774 * arma::pow(y, 3) +
85  0.797885 * y), 2) + 0.5;
86  }
87 }; // class GELUFunction
88 
89 } // namespace ann
90 } // namespace mlpack
91 
92 #endif
static void Deriv(const InputVecType &y, OutputVecType &x)
Computes the first derivatives of the GELU function.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
#define M_PI
Definition: prereqs.hpp:39
static void Fn(const InputVecType &x, OutputVecType &y)
Computes the GELU function.
static double Deriv(const double y)
Computes the first derivative of the GELU function.
static double Fn(const double x)
Computes the GELU function.
The GELU function, defined by $f(x) = 0.5\,x\left(1 + \tanh\left(\sqrt{2/\pi}\,\left(x + 0.044715\,x^{3}\right)\right)\right)$.