set_input_width_visitor.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_VISITOR_SET_INPUT_WIDTH_VISITOR_HPP
14 #define MLPACK_METHODS_ANN_VISITOR_SET_INPUT_WIDTH_VISITOR_HPP
15 
17 
18 #include <boost/variant.hpp>
19 
20 namespace mlpack {
21 namespace ann {
22 
27 class SetInputWidthVisitor : public boost::static_visitor<bool>
28 {
29  public:
31  SetInputWidthVisitor(const size_t inputWidth = 0, const bool reset = false);
32 
34  template<typename LayerType>
35  bool operator()(LayerType* layer) const;
36 
37  bool operator()(MoreTypes layer) const;
38 
39  private:
41  size_t inputWidth;
42 
44  bool reset;
45 
48  template<typename T>
49  typename std::enable_if<
50  !HasInputWidth<T, size_t&(T::*)()>::value &&
51  !HasModelCheck<T>::value, bool>::type
52  LayerInputWidth(T* layer) const;
53 
55  template<typename T>
56  typename std::enable_if<
57  HasInputWidth<T, size_t&(T::*)()>::value &&
58  !HasModelCheck<T>::value, bool>::type
59  LayerInputWidth(T* layer) const;
60 
62  template<typename T>
63  typename std::enable_if<
64  !HasInputWidth<T, size_t&(T::*)()>::value &&
65  HasModelCheck<T>::value, bool>::type
66  LayerInputWidth(T* layer) const;
67 
70  template<typename T>
71  typename std::enable_if<
72  HasInputWidth<T, size_t&(T::*)()>::value &&
73  HasModelCheck<T>::value, bool>::type
74  LayerInputWidth(T* layer) const;
75 };
76 
77 } // namespace ann
78 } // namespace mlpack
79 
80 // Include implementation.
81 #include "set_input_width_visitor_impl.hpp"
82 
83 #endif
SetInputWidthVisitor(const size_t inputWidth = 0, const bool reset = false): update the input width parameter with the given input width.
Linear algebra utility functions, generally performed on matrices or vectors.
SetInputWidthVisitor updates the input width parameter with the given input width.
bool operator()(LayerType* layer) const: update the input width parameter of the given layer.
MoreTypes is a typedef for boost::variant< FlexibleReLU< arma::mat, arma::mat > *, Linear3D< arma::mat, arma::mat, NoRegularizer > *, LpPooling< arma::mat, arma::mat > *, PixelShuffle< arma::mat, arma::mat > *, ChannelShuffle< arma::mat, arma::mat > *, Glimpse< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, MultiheadAttention< arma::mat, arma::mat, NoRegularizer > *, Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, ReLU6< arma::mat, arma::mat > *, Reparametrization< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, SpatialDropout< arma::mat, arma::mat > *, Subview< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, VirtualBatchNorm< arma::mat, arma::mat > *, RBF< arma::mat, arma::mat, GaussianFunction > *, BaseLayer< GaussianFunction, arma::mat, arma::mat > *, PositionalEncoding< arma::mat, arma::mat > *, ISRLU< arma::mat, arma::mat > *, BicubicInterpolation< arma::mat, arma::mat > *, NearestInterpolation< arma::mat, arma::mat > *, GroupNorm< arma::mat, arma::mat > *>.