linear_svm_function.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_FUNCTION_HPP
15 #define MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_FUNCTION_HPP
16 
17 #include <mlpack/prereqs.hpp>
18 
19 namespace mlpack {
20 namespace svm {
21 
27 template <typename MatType = arma::mat>
29 {
30  public:
41  LinearSVMFunction(const MatType& dataset,
42  const arma::Row<size_t>& labels,
43  const size_t numClasses,
44  const double lambda = 0.0001,
45  const double delta = 1.0,
46  const bool fitIntercept = false);
47 
51  void Shuffle();
52 
62  static void InitializeWeights(arma::mat& weights,
63  const size_t featureSize,
64  const size_t numClasses,
65  const bool fitIntercept = false);
66 
73  void GetGroundTruthMatrix(const arma::Row<size_t>& labels,
74  arma::sp_mat& groundTruth);
75 
82  double Evaluate(const arma::mat& parameters);
83 
93  double Evaluate(const arma::mat& parameters,
94  const size_t firstId,
95  const size_t batchSize = 1);
96 
105  template <typename GradType>
106  void Gradient(const arma::mat& parameters,
107  GradType& gradient);
108 
119  template <typename GradType>
120  void Gradient(const arma::mat& parameters,
121  const size_t firstId,
122  GradType& gradient,
123  const size_t batchSize = 1);
124 
136  template <typename GradType>
137  double EvaluateWithGradient(const arma::mat& parameters,
138  GradType& gradient) const;
139 
154  template <typename GradType>
155  double EvaluateWithGradient(const arma::mat& parameters,
156  const size_t firstId,
157  GradType& gradient,
158  const size_t batchSize = 1) const;
159 
161  const arma::mat& InitialPoint() const { return initialPoint; }
163  arma::mat& InitialPoint() { return initialPoint; }
164 
166  const arma::sp_mat& Dataset() const { return dataset; }
168  arma::sp_mat& Dataset() { return dataset; }
169 
171  double& Lambda() { return lambda; }
173  double Lambda() const { return lambda; }
174 
176  bool FitIntercept() const { return fitIntercept; }
177 
179  size_t NumFunctions() const;
180 
181  private:
183  arma::mat initialPoint;
184 
186  arma::sp_mat groundTruth;
187 
189  MatType dataset;
190 
192  size_t numClasses;
193 
195  double lambda;
196 
198  double delta;
199 
201  bool fitIntercept;
202 };
203 
204 } // namespace svm
205 } // namespace mlpack
206 
207 // Include implementation
208 #include "linear_svm_function_impl.hpp"
209 
210 #endif // MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_FUNCTION_HPP
The hinge loss function for the linear SVM objective function.
LinearSVMFunction(const MatType &dataset, const arma::Row< size_t > &labels, const size_t numClasses, const double lambda=0.0001, const double delta=1.0, const bool fitIntercept=false)
Construct the Linear SVM objective function with given parameters.
double Lambda() const
Gets the regularization parameter.
static void InitializeWeights(arma::mat &weights, const size_t featureSize, const size_t numClasses, const bool fitIntercept=false)
Initialize Linear SVM weights (trainable parameters) with the given parameters.
Linear algebra utility functions, generally performed on matrices or vectors.
bool FitIntercept() const
Gets the intercept flag.
The core includes that mlpack expects; standard C++ includes and Armadillo.
double EvaluateWithGradient(const arma::mat &parameters, GradType &gradient) const
Evaluate the hinge loss function and its gradient (returning the objective value), following the LinearFunctionType requirements on th...
arma::mat & InitialPoint()
Modify the initial point for the optimization.
arma::sp_mat & Dataset()
Modify the dataset.
void Shuffle()
Shuffle the dataset.
size_t NumFunctions() const
Return the number of functions.
void GetGroundTruthMatrix(const arma::Row< size_t > &labels, arma::sp_mat &groundTruth)
Constructs the ground truth label matrix with the passed labels.
const arma::mat & InitialPoint() const
Return the initial point for the optimization.
double & Lambda()
Modify the regularization parameter.
void Gradient(const arma::mat &parameters, GradType &gradient)
Evaluate the gradient of the hinge loss function following the LinearFunctionType requirements on the...
double Evaluate(const arma::mat &parameters)
Evaluate the hinge loss function for all the datapoints.
const arma::sp_mat & Dataset() const
Get the dataset.