hoeffding_tree_model.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_HOEFFDING_TREE_HOEFFDING_TREE_MODEL_HPP
13 #define MLPACK_METHODS_HOEFFDING_TREE_HOEFFDING_TREE_MODEL_HPP
14 
15 #include "hoeffding_tree.hpp"
16 #include "binary_numeric_split.hpp"
17 #include "information_gain.hpp"
18 
19 namespace mlpack {
20 namespace tree {
21 
28 {
29  public:
31  enum TreeType
32  {
37  };
38 
43  typedef HoeffdingTree<GiniImpurity, BinaryDoubleNumericSplit,
49  typedef HoeffdingTree<HoeffdingInformationGain, BinaryDoubleNumericSplit,
51 
60 
67 
74 
81 
88 
93 
// Train the model on the given dataset with the given labels.
//
// Only the declaration is visible here; the parameter notes below are read off
// the names and types and should be confirmed against the definition.
//   dataset / labels: training points and their class labels.
//   datasetInfo: auxiliary information for the dataset (data::DatasetInfo).
//   numClasses: presumably the number of distinct classes in `labels`.
//   batchTraining: presumably selects batch rather than streaming training.
//   successProbability, maxSamples, checkInterval, minSamples, bins,
//   observationsBeforeBinning: presumably hyperparameters handed on to the
//   underlying HoeffdingTree -- TODO confirm meanings and valid ranges.
114  void BuildModel(const arma::mat& dataset,
115  const data::DatasetInfo& datasetInfo,
116  const arma::Row<size_t>& labels,
117  const size_t numClasses,
118  const bool batchTraining,
119  const double successProbability,
120  const size_t maxSamples,
121  const size_t checkInterval,
122  const size_t minSamples,
123  const size_t bins,
124  const size_t observationsBeforeBinning);
125 
// Train in streaming mode on the given dataset.
//
// NOTE(review): the summary says "streaming mode", yet a batchTraining flag is
// accepted; how the flag changes training is not visible from this
// declaration -- confirm against the definition.
134  void Train(const arma::mat& dataset,
135  const arma::Row<size_t>& labels,
136  const bool batchTraining);
137 
// Using the model, classify the given test points.
//
// Results are written into `predictions` (an output parameter; the method is
// const and returns void).  Presumably one prediction per point in `dataset`
// -- TODO confirm layout against the definition.
145  void Classify(const arma::mat& dataset,
146  arma::Row<size_t>& predictions) const;
147 
// Overload of Classify() that fills a second output, `probabilities`, in
// addition to `predictions`.
//
// NOTE(review): presumably `probabilities` holds the probability of each
// predicted class, one entry per test point -- not visible from this
// declaration; confirm against the definition.
156  void Classify(const arma::mat& dataset,
157  arma::Row<size_t>& predictions,
158  arma::rowvec& probabilities) const;
159 
// Get the number of nodes in the tree.
163  size_t NumNodes() const;
164 
168  template<typename Archive>
169  void serialize(Archive& ar, const uint32_t /* version */)
170  {
171  // Clear memory if needed.
172  if (cereal::is_loading<Archive>())
173  {
174  delete giniHoeffdingTree;
175  delete giniBinaryTree;
176  delete infoHoeffdingTree;
177  delete infoBinaryTree;
178 
179  giniHoeffdingTree = NULL;
180  giniBinaryTree = NULL;
181  infoHoeffdingTree = NULL;
182  infoBinaryTree = NULL;
183  }
184 
185  ar(CEREAL_NVP(type));
186 
187  // Fake dataset info may be needed to create fake trees.
188  data::DatasetInfo info;
189  if (type == GINI_HOEFFDING)
190  ar(CEREAL_POINTER(giniHoeffdingTree));
191  else if (type == GINI_BINARY)
192  ar(CEREAL_POINTER(giniBinaryTree));
193  else if (type == INFO_HOEFFDING)
194  ar(CEREAL_POINTER(infoHoeffdingTree));
195  else if (type == INFO_BINARY)
196  ar(CEREAL_POINTER(infoBinaryTree));
197  }
198 
199  private:
// Which of the four tree types this model holds; serialize() uses it to decide
// which of the pointers below is archived.
201  TreeType type;
202 
// The four possible trees.  serialize() deletes all of them before loading, so
// this class owns whatever they point to.  Presumably at most one is non-null
// at a time (the one matching `type`) -- TODO confirm against the
// constructor and BuildModel().
//
// Tree used when type == GINI_HOEFFDING.
205  GiniHoeffdingTreeType* giniHoeffdingTree;
206 
// Tree used when type == GINI_BINARY.
209  GiniBinaryTreeType* giniBinaryTree;
210 
// Tree used when type == INFO_HOEFFDING.
213  InfoHoeffdingTreeType* infoHoeffdingTree;
214 
// Tree used when type == INFO_BINARY.
217  InfoBinaryTreeType* infoBinaryTree;
218 };
219 
220 } // namespace tree
221 } // namespace mlpack
222 
223 #endif
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
void serialize(Archive &ar, const uint32_t)
Serialize the model.
~HoeffdingTreeModel()
Clean up the given model.
HoeffdingTreeModel & operator=(const HoeffdingTreeModel &other)
Copy the Hoeffding tree model from the given other model.
HoeffdingTree< GiniImpurity, BinaryDoubleNumericSplit, HoeffdingCategoricalSplit > GiniBinaryTreeType
Convenience typedef for GINI_BINARY tree type.
void Classify(const arma::mat &dataset, arma::Row< size_t > &predictions) const
Using the model, classify the given test points.
HoeffdingTree< HoeffdingInformationGain, BinaryDoubleNumericSplit, HoeffdingCategoricalSplit > InfoBinaryTreeType
Convenience typedef for INFO_BINARY tree type.
The HoeffdingTree object represents all of the necessary information for a Hoeffding-bound-based deci...
Linear algebra utility functions, generally performed on matrices or vectors.
HoeffdingTreeModel(const TreeType &type=GINI_HOEFFDING)
Construct the Hoeffding tree model, but don't initialize any tree.
HoeffdingTree< HoeffdingInformationGain, HoeffdingDoubleNumericSplit, HoeffdingCategoricalSplit > InfoHoeffdingTreeType
Convenience typedef for INFO_HOEFFDING tree type.
HoeffdingTree< GiniImpurity, HoeffdingDoubleNumericSplit, HoeffdingCategoricalSplit > GiniHoeffdingTreeType
Convenience typedef for GINI_HOEFFDING tree type.
void BuildModel(const arma::mat &dataset, const data::DatasetInfo &datasetInfo, const arma::Row< size_t > &labels, const size_t numClasses, const bool batchTraining, const double successProbability, const size_t maxSamples, const size_t checkInterval, const size_t minSamples, const size_t bins, const size_t observationsBeforeBinning)
Train the model on the given dataset with the given labels.
BinaryNumericSplit< FitnessFunction, double > BinaryDoubleNumericSplit
TreeType
This enumerates the four types of trees we can hold.
void Train(const arma::mat &dataset, const arma::Row< size_t > &labels, const bool batchTraining)
Train in streaming mode on the given dataset.
This class is a serializable Hoeffding tree model that can hold four different types of Hoeffding tre...
size_t NumNodes() const
Get the number of nodes in the tree.
This is the standard Hoeffding-bound categorical feature proposed in the paper below: ...
HoeffdingNumericSplit< FitnessFunction, double > HoeffdingDoubleNumericSplit
Convenience typedef.
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.