earth_mover_distance.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTIONS_EARTH_MOVER_DISTANCE_HPP
13 #define MLPACK_METHODS_ANN_LOSS_FUNCTIONS_EARTH_MOVER_DISTANCE_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace ann {
19 
29 template <
30  typename InputDataType = arma::mat,
31  typename OutputDataType = arma::mat
32 >
34 {
35  public:
40 
48  template<typename PredictionType, typename TargetType>
49  typename PredictionType::elem_type Forward(const PredictionType& prediction,
50  const TargetType& target);
51 
60  template<typename PredictionType, typename TargetType, typename LossType>
61  void Backward(const PredictionType& prediction,
62  const TargetType& target,
63  LossType& loss);
64 
66  OutputDataType& OutputParameter() const { return outputParameter; }
68  OutputDataType& OutputParameter() { return outputParameter; }
69 
73  template<typename Archive>
74  void serialize(Archive& ar, const uint32_t /* version */);
75 
76  private:
78  OutputDataType outputParameter;
79 }; // class EarthMoverDistance
80 
81 } // namespace ann
82 } // namespace mlpack
83 
84 // Include implementation.
85 #include "earth_mover_distance_impl.hpp"
86 
87 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
void Backward(const PredictionType &prediction, const TargetType &target, LossType &loss)
Ordinary feed backward pass of a neural network.
The core includes that mlpack expects; standard C++ includes and Armadillo.
The earth mover distance function measures the network's performance according to the Kantorovich-Rubinstein duality approximation.
EarthMoverDistance()
Create the EarthMoverDistance object.
PredictionType::elem_type Forward(const PredictionType &prediction, const TargetType &target)
Ordinary feed forward pass of a neural network.
OutputDataType & OutputParameter()
Modify the output parameter.
OutputDataType & OutputParameter() const
Get the output parameter.