adaboost_model.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ADABOOST_ADABOOST_MODEL_HPP
13 #define MLPACK_METHODS_ADABOOST_ADABOOST_MODEL_HPP
14 
15 #include <mlpack/core.hpp>
16 
17 // Use forward declaration instead of include to accelerate compilation.
18 class AdaBoost;
19 
20 namespace mlpack {
21 namespace adaboost {
22 
27 {
28  public:
30  {
33  };
34 
35  private:
//! The mappings (per the Mappings() accessors). NOTE(review): presumably maps
//! internal class indices back to the original labels -- confirm in the .cpp.
37  arma::Col<size_t> mappings;
//! The weak learner type; serialize() compares it against
//! WeakLearnerTypes::DECISION_STUMP and WeakLearnerTypes::PERCEPTRON to decide
//! which ensemble pointer to (de)serialize.
39  size_t weakLearnerType;
//! The dimensionality of the model.
//! NOTE(review): the declarations of dsBoost / pBoost (used by serialize())
//! are not present in this listing.
45  size_t dimensionality;
46 
47  public:
//! Create an empty AdaBoost model.
49  AdaBoostModel();
50
//! Create a model from the given mappings and weak learner type.
//! NOTE(review): the definition is not visible here; presumably the arguments
//! initialize the `mappings` and `weakLearnerType` members -- confirm.
52  AdaBoostModel(const arma::Col<size_t>& mappings,
53  const size_t weakLearnerType);
54
//! Copy constructor.
56  AdaBoostModel(const AdaBoostModel& other);
57
60
//! Copy assignment operator.
62  AdaBoostModel& operator=(const AdaBoostModel& other);
63 
66 
69 
71  const arma::Col<size_t>& Mappings() const { return mappings; }
73  arma::Col<size_t>& Mappings() { return mappings; }
74 
76  size_t WeakLearnerType() const { return weakLearnerType; }
78  size_t& WeakLearnerType() { return weakLearnerType; }
79 
81  size_t Dimensionality() const { return dimensionality; }
83  size_t& Dimensionality() { return dimensionality; }
84 
//! Train the model, treating all of the data as numeric.
//! @param data Training points.
//! @param labels Labels of the training points.
//! @param numClasses Number of classes.
//! @param iterations NOTE(review): presumably the maximum number of boosting
//!     iterations -- confirm in the definition.
//! @param tolerance NOTE(review): presumably a convergence tolerance for
//!     boosting -- confirm in the definition.
86  void Train(const arma::mat& data,
87  const arma::Row<size_t>& labels,
88  const size_t numClasses,
89  const size_t iterations,
90  const double tolerance);
91
//! Classify test points, writing the predicted labels into `predictions`.
93  void Classify(const arma::mat& testData,
94  arma::Row<size_t>& predictions);
95
//! Classify test points, writing the predicted labels into `predictions`
//! and per-point class probabilities into `probabilities`.
//! NOTE(review): the layout of `probabilities` (presumably one column per
//! test point) is not visible here -- confirm.
97  void Classify(const arma::mat& testData,
98  arma::Row<size_t>& predictions,
99  arma::mat& probabilities);
100 
102  template<typename Archive>
103  void serialize(Archive& ar, const uint32_t /* version */)
104  {
105  if (cereal::is_loading<Archive>())
106  {
107  if (dsBoost)
108  delete dsBoost;
109  if (pBoost)
110  delete pBoost;
111 
112  dsBoost = NULL;
113  pBoost = NULL;
114  }
115 
116  ar(CEREAL_NVP(mappings));
117  ar(CEREAL_NVP(weakLearnerType));
118  if (weakLearnerType == WeakLearnerTypes::DECISION_STUMP)
119  ar(CEREAL_POINTER(dsBoost));
120  else if (weakLearnerType == WeakLearnerTypes::PERCEPTRON)
121  ar(CEREAL_POINTER(pBoost));
122  ar(CEREAL_NVP(dimensionality));
123  }
124 };
125 
126 } // namespace adaboost
127 } // namespace mlpack
128 
129 #endif
~AdaBoostModel()
Clean up memory.
void Classify(const arma::mat &testData, arma::Row< size_t > &predictions)
Classify test points.
Linear algebra utility functions, generally performed on matrices or vectors.
void serialize(Archive &ar, const uint32_t)
Serialize the model.
The AdaBoost class.
Definition: adaboost.hpp:81
size_t & Dimensionality()
Modify the dimensionality of the model.
The model to save to disk.
arma::Col< size_t > & Mappings()
Modify the mappings.
void Train(const arma::mat &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t iterations, const double tolerance)
Train the model, treating all of the data as numeric.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
size_t WeakLearnerType() const
Get the weak learner type.
AdaBoostModel()
Create an empty AdaBoost model.
#define CEREAL_POINTER(T)
Cereal does not support serialization of raw pointers.
const arma::Col< size_t > & Mappings() const
Get the mappings.
AdaBoostModel & operator=(const AdaBoostModel &other)
Copy assignment operator.
size_t Dimensionality() const
Get the dimensionality of the model.
size_t & WeakLearnerType()
Modify the weak learner type.