binary_cross_entropy_loss.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTIONS_CROSS_ENTROPY_ERROR_HPP
13 #define MLPACK_METHODS_ANN_LOSS_FUNCTIONS_CROSS_ENTROPY_ERROR_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace ann {
19 
29 template <
30  typename InputDataType = arma::mat,
31  typename OutputDataType = arma::mat
32 >
33 class BCELoss
34 {
35  public:
44  BCELoss(const double eps = 1e-10, const bool reduction = true);
45 
53  template<typename PredictionType, typename TargetType>
54  typename PredictionType::elem_type Forward(const PredictionType& prediction,
55  const TargetType& target);
56 
65  template<typename PredictionType, typename TargetType, typename LossType>
66  void Backward(const PredictionType& prediction,
67  const TargetType& target,
68  LossType& loss);
69 
71  OutputDataType& OutputParameter() const { return outputParameter; }
73  OutputDataType& OutputParameter() { return outputParameter; }
74 
76  double Eps() const { return eps; }
78  double& Eps() { return eps; }
79 
81  bool Reduction() const { return reduction; }
83  bool& Reduction() { return reduction; }
84 
88  template<typename Archive>
89  void serialize(Archive& ar, const uint32_t /* version */);
90 
91  private:
93  OutputDataType outputParameter;
94 
96  double eps;
97 
99  bool reduction;
100 }; // class BCELoss
101 
105 template <
106  typename InputDataType = arma::mat,
107  typename OutputDataType = arma::mat
108 >
109 using CrossEntropyError = BCELoss<
110  InputDataType, OutputDataType>;
111 
112 } // namespace ann
113 } // namespace mlpack
114 
115 // Include implementation.
116 #include "binary_cross_entropy_loss_impl.hpp"
117 
118 #endif
void Backward(const PredictionType &prediction, const TargetType &target, LossType &loss)
Ordinary feed backward pass of a neural network.
BCELoss(const double eps=1e-10, const bool reduction=true)
Create the BCELoss (binary cross-entropy loss) object.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
bool Reduction() const
Get the reduction.
bool & Reduction()
Modify the reduction.
OutputDataType & OutputParameter() const
Get the output parameter.
PredictionType::elem_type Forward(const PredictionType &prediction, const TargetType &target)
Computes the cross-entropy function.
OutputDataType & OutputParameter()
Modify the output parameter.
double Eps() const
Get the epsilon.
The binary-cross-entropy performance function measures the Binary Cross Entropy between the target and the output.
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
double & Eps()
Modify the epsilon.