12 #ifndef MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_HPP 13 #define MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_HPP 16 #include <ensmallen.hpp> 79 template <
typename MatType = arma::mat>
101 template <
typename OptimizerType,
typename... CallbackTypes>
103 const arma::Row<size_t>& labels,
104 const size_t numClasses,
107 const bool fitIntercept,
108 OptimizerType optimizer,
109 CallbackTypes&&... callbacks);
126 template <
typename OptimizerType = ens::L_BFGS>
128 const arma::Row<size_t>& labels,
129 const size_t numClasses = 2,
130 const double lambda = 0.0001,
131 const double delta = 1.0,
132 const bool fitIntercept =
false,
133 OptimizerType optimizer = OptimizerType());
147 const size_t numClasses = 0,
148 const double lambda = 0.0001,
149 const double delta = 1.0,
150 const bool fitIntercept =
false);
162 const double lambda = 0.0001,
163 const double delta = 1.0,
164 const bool fitIntercept =
false);
176 arma::Row<size_t>& labels)
const;
190 arma::Row<size_t>& labels,
191 arma::mat& scores)
const;
200 arma::mat& scores)
const;
/**
 * Classify a single point.
 *
 * NOTE(review): the size_t return is presumably the predicted class label
 * (same element type as the arma::Row<size_t> label vectors used by Train());
 * confirm against linear_svm_impl.hpp.
 *
 * @tparam VecType Type of the point to classify (presumably an Armadillo
 *     vector type; confirm).
 * @param point Point to classify.
 * @return Predicted label for the point (see note above).
 */
210 template<
typename VecType>
211 size_t Classify(
const VecType& point)
const;
223 const arma::Row<size_t>& testLabels)
const;
/**
 * Train the linear SVM on the given data and labels, using the supplied
 * optimizer instance and any number of callbacks.
 *
 * @tparam OptimizerType Optimizer type (the other Train() overload defaults
 *     this to ens::L_BFGS, so an ensmallen optimizer is expected).
 * @tparam CallbackTypes Types of the callbacks, taken as forwarding
 *     references.
 * @param data Training data.
 * @param labels Label for each training point.
 * @param numClasses Number of classes in the labels.
 * @param optimizer Optimizer instance (taken by value).
 * @param callbacks Callbacks to use during optimization.
 * @return NOTE(review): presumably the final objective value reported by the
 *     optimizer; confirm in linear_svm_impl.hpp.
 */
238 template <
typename OptimizerType,
typename... CallbackTypes>
239 double Train(
const MatType& data,
240 const arma::Row<size_t>& labels,
241 const size_t numClasses,
242 OptimizerType optimizer,
243 CallbackTypes&&... callbacks);
/**
 * Train the linear SVM on the given data and labels without callbacks.
 *
 * By default this uses a default-constructed ens::L_BFGS optimizer and
 * assumes two classes.
 *
 * @tparam OptimizerType Optimizer type; defaults to ens::L_BFGS.
 * @param data Training data.
 * @param labels Label for each training point.
 * @param numClasses Number of classes in the labels (default 2).
 * @param optimizer Optimizer instance (taken by value; default-constructed
 *     when omitted).
 * @return NOTE(review): presumably the final objective value reported by the
 *     optimizer; confirm in linear_svm_impl.hpp.
 */
255 template <
typename OptimizerType = ens::L_BFGS>
256 double Train(
const MatType& data,
257 const arma::Row<size_t>& labels,
258 const size_t numClasses = 2,
259 OptimizerType optimizer = OptimizerType());
//! Gets the margin between the correct class and all other classes
//! (read-only accessor for the stored delta value).
275 double Delta()
const {
return delta; }
287 {
return fitIntercept ? parameters.n_rows - 1 :
293 template<
typename Archive>
296 ar(CEREAL_NVP(parameters));
297 ar(CEREAL_NVP(numClasses));
298 ar(CEREAL_NVP(lambda));
299 ar(CEREAL_NVP(fitIntercept));
//! Learned model parameters. When fitIntercept is set, FeatureSize() reports
//! one fewer row than this matrix holds, so one row presumably stores the
//! intercept term (NOTE(review): confirm which row in linear_svm_impl.hpp).
304 arma::mat parameters;
319 #include "linear_svm_impl.hpp" 321 #endif // MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_HPP
bool & FitIntercept()
Sets the intercept term flag.
Linear algebra utility functions, generally performed on matrices or vectors.
double Delta() const
Gets the margin between the correct class and all other classes.
arma::mat & Parameters()
Set the model parameters.
The core includes that mlpack expects: standard C++ includes and Armadillo.
double & Lambda()
Sets the regularization parameter.
size_t & NumClasses()
Sets the number of classes.
double & Delta()
Sets the margin between the correct class and all other classes.
size_t FeatureSize() const
Gets the number of features (data dimensionality) the model was trained on, excluding the intercept term.
LinearSVM(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const double lambda, const double delta, const bool fitIntercept, OptimizerType optimizer, CallbackTypes &&... callbacks)
Construct a LinearSVM model from the provided data and labels, training it with the given optimizer and callbacks.
const arma::mat & Parameters() const
Get the model parameters.
void Classify(const MatType &data, arma::Row< size_t > &labels) const
Classify the given points, storing the predicted label for each point in `labels`.
void serialize(Archive &ar, const uint32_t)
Serialize the LinearSVM model.
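The LinearSVM class implements an L2-regularized support vector machine model, and supports training with multiple optimizers and classification. The observation type is chosen via the MatType template parameter (arma::mat by default).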
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, OptimizerType optimizer, CallbackTypes &&... callbacks)
Train the linear SVM on the given training data and labels.
double ComputeAccuracy(const MatType &testData, const arma::Row< size_t > &testLabels) const
Computes the accuracy of the learned model given the feature data and the labels associated with each data point.
double Lambda() const
Gets the regularization parameter.
size_t NumClasses() const
Gets the number of classes.