softmax_regression_function.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_FUNCTION_HPP
14 #define MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_FUNCTION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace regression {
20 
22 {
23  public:
34  SoftmaxRegressionFunction(const arma::mat& data,
35  const arma::Row<size_t>& labels,
36  const size_t numClasses,
37  const double lambda = 0.0001,
38  const bool fitIntercept = false);
39 
41  const arma::mat InitializeWeights();
42 
46  void Shuffle();
47 
57  static const arma::mat InitializeWeights(const size_t featureSize,
58  const size_t numClasses,
59  const bool fitIntercept = false);
60 
70  static void InitializeWeights(arma::mat &weights,
71  const size_t featureSize,
72  const size_t numClasses,
73  const bool fitIntercept = false);
74 
81  void GetGroundTruthMatrix(const arma::Row<size_t>& labels,
82  arma::sp_mat& groundTruth);
83 
95  void GetProbabilitiesMatrix(const arma::mat& parameters,
96  arma::mat& probabilities,
97  const size_t start,
98  const size_t batchSize) const;
99 
109  double Evaluate(const arma::mat& parameters) const;
110 
122  double Evaluate(const arma::mat& parameters,
123  const size_t start,
124  const size_t batchSize = 1) const;
125 
135  void Gradient(const arma::mat& parameters, arma::mat& gradient) const;
136 
148  void Gradient(const arma::mat& parameters,
149  const size_t start,
150  arma::mat& gradient,
151  const size_t batchSize = 1) const;
152 
162  void PartialGradient(const arma::mat& parameters,
163  size_t j,
164  arma::sp_mat& gradient) const;
165 
167  const arma::mat& GetInitialPoint() const { return initialPoint; }
168 
170  size_t NumClasses() const { return numClasses; }
171 
173  size_t NumFeatures() const
174  {
175  return initialPoint.n_cols;
176  }
180  size_t NumFunctions() const { return data.n_cols; }
181 
183  double& Lambda() { return lambda; }
185  double Lambda() const { return lambda; }
186 
188  bool FitIntercept() const { return fitIntercept; }
189 
190  private:
192  arma::mat data;
194  arma::sp_mat groundTruth;
196  arma::mat initialPoint;
198  size_t numClasses;
200  double lambda;
202  bool fitIntercept;
203 };
204 
205 } // namespace regression
206 } // namespace mlpack
207 
208 #endif
double Lambda() const
Gets the regularization parameter.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Gradient(const arma::mat &parameters, arma::mat &gradient) const
Evaluates the gradient values of the objective function given the current set of parameters.
double & Lambda()
Sets the regularization parameter.
size_t NumClasses() const
Gets the number of classes.
size_t NumFeatures() const
Gets the feature size of the training data (the number of columns of the initial point).
SoftmaxRegressionFunction(const arma::mat &data, const arma::Row< size_t > &labels, const size_t numClasses, const double lambda=0.0001, const bool fitIntercept=false)
Construct the Softmax Regression objective function with the given parameters.
void GetProbabilitiesMatrix(const arma::mat &parameters, arma::mat &probabilities, const size_t start, const size_t batchSize) const
Evaluate the probabilities matrix with the passed parameters.
const arma::mat InitializeWeights()
Initializes the parameters of the model to suitable values.
const arma::mat & GetInitialPoint() const
Return the initial point for the optimization.
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
void GetGroundTruthMatrix(const arma::Row< size_t > &labels, arma::sp_mat &groundTruth)
Constructs the ground truth label matrix with the passed labels.
void PartialGradient(const arma::mat &parameters, size_t j, arma::sp_mat &gradient) const
Evaluates the gradient values of the objective function given the current set of parameters for a single feature index.
double Evaluate(const arma::mat &parameters) const
Evaluates the objective function of the softmax regression model using the given parameters.