multilabel_softmargin_loss.hpp
Go to the documentation of this file.
1 
16 #ifndef MLPACK_ANN_LOSS_FUNCTION_MULTILABEL_SOFTMARGIN_LOSS_HPP
17 #define MLPACK_ANN_LOSS_FUNCTION_MULTILABEL_SOFTMARGIN_LOSS_HPP
18 
19 #include <mlpack/prereqs.hpp>
20 
21 namespace mlpack {
22 namespace ann {
23 
30 template <
31  typename InputDataType = arma::mat,
32  typename OutputDataType = arma::mat
33 >
35 {
36  public:
48  MultiLabelSoftMarginLoss(const bool reduction = true,
49  const arma::rowvec& weights = arma::rowvec());
50 
57  template<typename InputType, typename TargetType>
58  typename InputType::elem_type Forward(const InputType& input,
59  const TargetType& target);
60 
68  template<typename InputType, typename TargetType, typename OutputType>
69  void Backward(const InputType& input,
70  const TargetType& target,
71  OutputType& output);
72 
74  OutputDataType& OutputParameter() const { return outputParameter; }
76  OutputDataType& OutputParameter() { return outputParameter; }
77 
79  const arma::rowvec& ClassWeights() const { return classWeights; }
81  arma::rowvec& ClassWeights() { return classWeights; }
82 
84  bool Reduction() const { return reduction; }
86  bool& Reduction() { return reduction; }
87 
91  template<typename Archive>
92  void serialize(Archive& ar, const unsigned int /* version */);
93 
94  private:
96  OutputDataType outputParameter;
97 
99  bool reduction;
100 
102  arma::rowvec classWeights;
103 
104  // An internal parameter used during initialisation of class weights.
105  bool weighted;
106 }; // class MultiLabelSoftMarginLoss
107 
108 } // namespace ann
109 } // namespace mlpack
110 
111 // include implementation.
112 #include "multilabel_softmargin_loss_impl.hpp"
113 
114 #endif
bool & Reduction()
Modify the type of reduction used.
void Backward(const InputType &input, const TargetType &target, OutputType &output)
Ordinary feed backward pass of a neural network.
OutputDataType & OutputParameter() const
Get the output parameter.
mlpack
Linear algebra utility functions, generally performed on matrices or vectors.
prereqs.hpp
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputDataType & OutputParameter()
Modify the output parameter.
InputType::elem_type Forward(const InputType &input, const TargetType &target)
Computes the Multi Label Soft Margin Loss function.
bool Reduction() const
Get the type of reduction used.
arma::rowvec & ClassWeights()
Modify the weights assigned to each class.
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
MultiLabelSoftMarginLoss(const bool reduction=true, const arma::rowvec &weights=arma::rowvec())
Create the MultiLabelSoftMarginLoss object.
const arma::rowvec & ClassWeights() const
Get the weights assigned to each class.