backward_visitor.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_VISITOR_BACKWARD_VISITOR_HPP
14 #define MLPACK_METHODS_ANN_VISITOR_BACKWARD_VISITOR_HPP
15 
18 
19 #include <boost/variant.hpp>
20 
21 namespace mlpack {
22 namespace ann {
23 
28 class BackwardVisitor : public boost::static_visitor<void>
29 {
30  public:
33  BackwardVisitor(const arma::mat& input,
34  const arma::mat& error,
35  arma::mat& delta);
36 
38  BackwardVisitor(const arma::mat& input,
39  const arma::mat& error,
40  arma::mat& delta,
41  const size_t index);
42 
44  template<typename LayerType>
45  void operator()(LayerType* layer) const;
46 
47  void operator()(MoreTypes layer) const;
48 
49  private:
51  const arma::mat& input;
52 
54  const arma::mat& error;
55 
57  arma::mat& delta;
58 
60  size_t index;
61 
63  bool hasIndex;
64 
67  template<typename T>
68  typename std::enable_if<
69  !HasRunCheck<T, bool&(T::*)(void)>::value, void>::type
70  LayerBackward(T* layer, arma::mat& input) const;
71 
73  template<typename T>
74  typename std::enable_if<
75  HasRunCheck<T, bool&(T::*)(void)>::value, void>::type
76  LayerBackward(T* layer, arma::mat& input) const;
77 };
78 
79 } // namespace ann
80 } // namespace mlpack
81 
82 // Include implementation.
83 #include "backward_visitor_impl.hpp"
84 
85 #endif
BackwardVisitor executes the Backward() function given the input, error and delta parameter...
Linear algebra utility functions, generally performed on matrices or vectors.
BackwardVisitor(const arma::mat &input, const arma::mat &error, arma::mat &delta)
Execute the Backward() function given the input, error and delta parameter.
void operator()(LayerType *layer) const
Execute the Backward() function.
boost::variant< FlexibleReLU< arma::mat, arma::mat > *, Linear3D< arma::mat, arma::mat, NoRegularizer > *, LpPooling< arma::mat, arma::mat > *, PixelShuffle< arma::mat, arma::mat > *, ChannelShuffle< arma::mat, arma::mat > *, Glimpse< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, MultiheadAttention< arma::mat, arma::mat, NoRegularizer > *, Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, ReLU6< arma::mat, arma::mat > *, Reparametrization< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, SpatialDropout< arma::mat, arma::mat > *, Subview< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, VirtualBatchNorm< arma::mat, arma::mat > *, RBF< arma::mat, arma::mat, GaussianFunction > *, BaseLayer< GaussianFunction, arma::mat, arma::mat > *, PositionalEncoding< arma::mat, arma::mat > *, ISRLU< arma::mat, arma::mat > *, BicubicInterpolation< arma::mat, arma::mat > *, NearestInterpolation< arma::mat, arma::mat > *, GroupNorm< arma::mat, arma::mat > *> MoreTypes