loss_visitor.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_VISITOR_LOSS_VISITOR_HPP
14 #define MLPACK_METHODS_ANN_VISITOR_LOSS_VISITOR_HPP
15 
17 
#include <type_traits>

#include <boost/variant.hpp>
19 
20 namespace mlpack {
21 namespace ann {
22 
26 class LossVisitor : public boost::static_visitor<double>
27 {
28  public:
30  template<typename LayerType>
31  double operator()(LayerType* layer) const;
32 
33  double operator()(MoreTypes layer) const;
34 
35  private:
37  template<typename T>
38  typename std::enable_if<
39  !HasLoss<T, double(T::*)()>::value &&
40  !HasModelCheck<T>::value, double>::type
41  LayerLoss(T* layer) const;
42 
44  template<typename T>
45  typename std::enable_if<
46  HasLoss<T, double(T::*)()>::value &&
47  !HasModelCheck<T>::value, double>::type
48  LayerLoss(T* layer) const;
49 
51  template<typename T>
52  typename std::enable_if<
53  !HasLoss<T, double(T::*)()>::value &&
54  HasModelCheck<T>::value, double>::type
55  LayerLoss(T* layer) const;
56 
58  template<typename T>
59  typename std::enable_if<
60  HasLoss<T, double(T::*)()>::value &&
61  HasModelCheck<T>::value, double>::type
62  LayerLoss(T* layer) const;
63 };
64 
65 } // namespace ann
66 } // namespace mlpack
67 
68 // Include implementation.
69 #include "loss_visitor_impl.hpp"
70 
71 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
LossVisitor exposes the Loss() method of the given module.
double operator()(LayerType *layer) const
Return the loss of the given layer.
boost::variant< FlexibleReLU< arma::mat, arma::mat > *, Linear3D< arma::mat, arma::mat, NoRegularizer > *, LpPooling< arma::mat, arma::mat > *, PixelShuffle< arma::mat, arma::mat > *, ChannelShuffle< arma::mat, arma::mat > *, Glimpse< arma::mat, arma::mat > *, Highway< arma::mat, arma::mat > *, MultiheadAttention< arma::mat, arma::mat, NoRegularizer > *, Recurrent< arma::mat, arma::mat > *, RecurrentAttention< arma::mat, arma::mat > *, ReinforceNormal< arma::mat, arma::mat > *, ReLU6< arma::mat, arma::mat > *, Reparametrization< arma::mat, arma::mat > *, Select< arma::mat, arma::mat > *, SpatialDropout< arma::mat, arma::mat > *, Subview< arma::mat, arma::mat > *, VRClassReward< arma::mat, arma::mat > *, VirtualBatchNorm< arma::mat, arma::mat > *, RBF< arma::mat, arma::mat, GaussianFunction > *, BaseLayer< GaussianFunction, arma::mat, arma::mat > *, PositionalEncoding< arma::mat, arma::mat > *, ISRLU< arma::mat, arma::mat > *, BicubicInterpolation< arma::mat, arma::mat > *, NearestInterpolation< arma::mat, arma::mat > *, GroupNorm< arma::mat, arma::mat > *> MoreTypes