reward_set_visitor.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_VISITOR_REWARD_SET_VISITOR_HPP
14 #define MLPACK_METHODS_ANN_VISITOR_REWARD_SET_VISITOR_HPP
15 
17 
18 #include <boost/variant.hpp>
19 
20 namespace mlpack {
21 namespace ann {
22 
//! RewardSetVisitor sets the reward parameter of a layer given a reward value.
//!
//! It derives from boost::static_visitor<void>, so it is meant to be applied
//! to a layer variant (e.g. via boost::apply_visitor); every call operator
//! returns void. The member definitions live in the included
//! reward_set_visitor_impl.hpp, not in this header.
//!
//! NOTE(review): HasRewardCheck, HasModelCheck and MoreTypes are not declared
//! anywhere in this listing (source line 16 is absent from it); presumably an
//! mlpack layer-traits / layer-types include provides them -- verify. The
//! header also relies on <type_traits> (for std::enable_if) arriving
//! transitively rather than including it directly.
26 class RewardSetVisitor : public boost::static_visitor<void>
27 {
28  public:
//! Set the reward parameter given the reward value.
//!
//! @param reward Reward value to apply to visited layers; presumably copied
//!     into the private `reward` member (definition not in this listing).
30  RewardSetVisitor(const double reward);
31 
//! Set the reward parameter on the visited layer.
//!
//! @tparam LayerType Concrete layer type held by the visited variant.
//! @param layer Pointer to the layer to update.
33  template<typename LayerType>
34  void operator()(LayerType* layer) const;
35 
//! Overload for the MoreTypes variant (a boost::variant of additional layer
//! pointer types); presumably re-dispatches into that nested variant --
//! confirm in reward_set_visitor_impl.hpp.
36  void operator()(MoreTypes layer) const;
37 
38  private:
//! The reward value applied to every visited layer.
40  const double reward;
41 
//! The four LayerReward overloads below are mutually exclusive: for any T,
//! exactly one of them is enabled, covering each combination of
//! HasRewardCheck<T, double&(T::*)()> and HasModelCheck<T>. The reward check
//! is probed with the member-function signature double&(T::*)(), i.e.
//! presumably a `double& Reward()` accessor; HasModelCheck presumably detects
//! a Model() member -- verify against the trait definitions.
//!
//! Enabled when T passes both HasRewardCheck and HasModelCheck.
44  template<typename T>
45  typename std::enable_if<
46  HasRewardCheck<T, double&(T::*)()>::value &&
47  HasModelCheck<T>::value, void>::type
48  LayerReward(T* layer) const;
49 
//! Enabled when T fails HasRewardCheck but passes HasModelCheck.
52  template<typename T>
53  typename std::enable_if<
54  !HasRewardCheck<T, double&(T::*)()>::value &&
55  HasModelCheck<T>::value, void>::type
56  LayerReward(T* layer) const;
57 
//! Enabled when T passes HasRewardCheck but fails HasModelCheck.
60  template<typename T>
61  typename std::enable_if<
62  HasRewardCheck<T, double&(T::*)()>::value &&
63  !HasModelCheck<T>::value, void>::type
64  LayerReward(T* layer) const;
65 
//! Enabled when T fails both HasRewardCheck and HasModelCheck.
68  template<typename T>
69  typename std::enable_if<
70  !HasRewardCheck<T, double&(T::*)()>::value &&
71  !HasModelCheck<T>::value, void>::type
72  LayerReward(T* layer) const;
73 };
74 
75 } // namespace ann
76 } // namespace mlpack
77 
78 // Include implementation.
79 #include "reward_set_visitor_impl.hpp"
80 
81 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
RewardSetVisitor sets the reward parameter given the reward value.
RewardSetVisitor(const double reward)
Set the reward parameter given the reward value.
void operator()(LayerType *layer) const
Set the reward parameter.
boost::variant< FlexibleReLU< arma::mat, arma::mat > *, Linear3D< arma::mat, arma::mat, NoRegularizer > *, LpPooling< arma::mat, arma::mat > *, PixelShuffle< arma::mat, arma::mat > *, ChannelShuffle< arma::mat, arma::mat > *, Glimpse< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, MultiheadAttention< arma::mat, arma::mat, NoRegularizer > *, Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, ReLU6< arma::mat, arma::mat > *, Reparametrization< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, SpatialDropout< arma::mat, arma::mat > *, Subview< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, VirtualBatchNorm< arma::mat, arma::mat > *, RBF< arma::mat, arma::mat, GaussianFunction > *, BaseLayer< GaussianFunction, arma::mat, arma::mat > *, PositionalEncoding< arma::mat, arma::mat > *, ISRLU< arma::mat, arma::mat > *, BicubicInterpolation< arma::mat, arma::mat > *, NearestInterpolation< arma::mat, arma::mat > *, GroupNorm< arma::mat, arma::mat > *> MoreTypes