gradient_visitor.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_VISITOR_GRADIENT_VISITOR_HPP
14 #define MLPACK_METHODS_ANN_VISITOR_GRADIENT_VISITOR_HPP
15 
18 
19 #include <boost/variant.hpp>
20 
21 namespace mlpack {
22 namespace ann {
23 
//! GradientVisitor executes the Gradient() method of the given module using
//! the input and delta parameters. It derives from
//! boost::static_visitor<void>, so it can be passed to boost::apply_visitor
//! together with a boost::variant of layer pointers; each visit returns void.
28 class GradientVisitor : public boost::static_visitor<void>
29 {
30  public:
    //! Executes the Gradient() method of the given module using the input and
    //! delta parameters.
    //! @param input Input matrix; held by reference, so it must outlive the
    //!     visitor.
    //! @param delta Delta matrix; held by reference, so it must outlive the
    //!     visitor.
33  GradientVisitor(const arma::mat& input, const arma::mat& delta);
34 
    //! Same as above, but also records a layer index in the `index` member.
    //! NOTE(review): presumably the index selects which layer's Gradient()
    //! is run and `hasIndex` records that this constructor was used --
    //! confirm in gradient_visitor_impl.hpp.
    //! @param input Input matrix (held by reference).
    //! @param delta Delta matrix (held by reference).
    //! @param index Layer index stored in the visitor.
36  GradientVisitor(const arma::mat& input,
37  const arma::mat& delta,
38  const size_t index);
39 
    //! Executes the Gradient() method on the given layer pointer.
    //! @tparam LayerType Concrete layer type held by the visited variant.
41  template<typename LayerType>
42  void operator()(LayerType* layer) const;
43 
    //! Overload for MoreTypes, a nested boost::variant of additional layer
    //! pointer types (FlexibleReLU, Linear3D, LpPooling, ...).
    //! NOTE(review): presumably re-dispatches into the contained layer --
    //! confirm in the implementation.
44  void operator()(MoreTypes layer) const;
45 
46  private:
    //! The input set. It is stored as a reference (not a copy), so the
    //! caller's matrix must stay alive while the visitor is in use.
48  const arma::mat& input;
49 
    //! The delta parameter. It is also stored by reference, with the same
    //! lifetime requirement as `input`.
51  const arma::mat& delta;
52 
    //! Index of the layer, as passed to the three-argument constructor.
54  size_t index;
55 
    //! Flag associated with `index`; presumably true only when the index
    //! constructor was used -- TODO confirm.
57  bool hasIndex;
58 
    //! Overload enabled when HasGradientCheck<T, arma::mat&(T::*)()> holds
    //! and HasRunCheck<T, bool&(T::*)(void)> does NOT hold. In other words,
    //! the layer exposes a Gradient()-style member (per the trait name) but
    //! no Run() member returning bool&.
    //! Note: the `input` parameter shadows the `input` data member.
61  template<typename T>
62  typename std::enable_if<
63  HasGradientCheck<T, arma::mat&(T::*)()>::value &&
64  !HasRunCheck<T, bool&(T::*)(void)>::value, void>::type
65  LayerGradients(T* layer, arma::mat& input) const;
66 
    //! Overload enabled when both HasGradientCheck<T, arma::mat&(T::*)()> and
    //! HasRunCheck<T, bool&(T::*)(void)> hold, i.e. the layer also exposes a
    //! Run() member returning bool&. The two enable_if conditions are
    //! mutually exclusive with the previous overload, so at most one of them
    //! is viable for a given T.
69  template<typename T>
70  typename std::enable_if<
71  HasGradientCheck<T, arma::mat&(T::*)()>::value &&
72  HasRunCheck<T, bool&(T::*)(void)>::value, void>::type
73  LayerGradients(T* layer, arma::mat& input) const;
74 
    //! Fallback overload, enabled when HasGradientCheck<T, P&(T::*)()> does
    //! not hold for the argument type P.
    //! NOTE(review): presumably a no-op, so layers without Gradient() are
    //! skipped -- confirm in gradient_visitor_impl.hpp.
77  template<typename T, typename P>
78  typename std::enable_if<
79  !HasGradientCheck<T, P&(T::*)()>::value, void>::type
80  LayerGradients(T* layer, P& input) const;
81 };
82 
83 } // namespace ann
84 } // namespace mlpack
85 
86 // Include implementation.
87 #include "gradient_visitor_impl.hpp"
88 
89 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
GradientVisitor(const arma::mat &input, const arma::mat &delta)
Executes the Gradient() method of the given module using the input and delta parameters.
GradientVisitor executes the Gradient() method of the given module using the input and delta parameters.
void operator()(LayerType *layer) const
Executes the Gradient() method.
boost::variant< FlexibleReLU< arma::mat, arma::mat > *, Linear3D< arma::mat, arma::mat, NoRegularizer > *, LpPooling< arma::mat, arma::mat > *, PixelShuffle< arma::mat, arma::mat > *, ChannelShuffle< arma::mat, arma::mat > *, Glimpse< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, MultiheadAttention< arma::mat, arma::mat, NoRegularizer > *, Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, ReLU6< arma::mat, arma::mat > *, Reparametrization< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, SpatialDropout< arma::mat, arma::mat > *, Subview< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, VirtualBatchNorm< arma::mat, arma::mat > *, RBF< arma::mat, arma::mat, GaussianFunction > *, BaseLayer< GaussianFunction, arma::mat, arma::mat > *, PositionalEncoding< arma::mat, arma::mat > *, ISRLU< arma::mat, arma::mat > *, BicubicInterpolation< arma::mat, arma::mat > *, NearestInterpolation< arma::mat, arma::mat > *, GroupNorm< arma::mat, arma::mat > *> MoreTypes