hoeffding_tree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_HOEFFDING_TREES_HOEFFDING_TREE_HPP
14 #define MLPACK_METHODS_HOEFFDING_TREES_HOEFFDING_TREE_HPP
15 
16 #include <mlpack/prereqs.hpp>
18 #include "gini_impurity.hpp"
21 
22 namespace mlpack {
23 namespace tree {
24 
// HoeffdingTree: a node of a Hoeffding-bound-based decision tree.  Each node
// holds pre-split bookkeeping, post-split results, and its child nodes.
//
// Template parameters:
//   FitnessFunction      - defaults to GiniImpurity.
//   NumericSplitType     - one-argument class template, instantiated below
//                          with FitnessFunction.
//   CategoricalSplitType - one-argument class template, defaults to
//                          HoeffdingCategoricalSplit.
//
// NOTE(review): every line of this listing starts with the line number it had
// in the original header, and documentation-only lines are skipped.  Two of the
// skipped lines held code, so this text does not compile as shown:
//   * 57 - the default argument that must follow `NumericSplitType =`
//          (presumably HoeffdingDoubleNumericSplit - confirm);
//   * 61 - the class-head line (presumably `class HoeffdingTree`).
55 template<typename FitnessFunction = GiniImpurity,
56  template<typename> class NumericSplitType =
58  template<typename> class CategoricalSplitType =
59  HoeffdingCategoricalSplit
60 >
62 {
63  public:
    // Allow access to the numeric split type.
65  typedef NumericSplitType<FitnessFunction> NumericSplit;
    // Allow access to the categorical split type.
67  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
68 
    // Construct a tree from a data matrix, its dataset description, and labels.
    //
    // data               - input points (matrix type is a template parameter).
    // datasetInfo        - dataset description for `data`.
    // labels             - one size_t label per point (arma::Row<size_t>).
    // numClasses         - number of classes.
    // batchTraining      - defaults to true.
    // successProbability - defaults to 0.95.
    // maxSamples         - defaults to 0.
    // checkInterval      - defaults to 100.
    // minSamples         - defaults to 100.
    // categoricalSplitIn - defaults to CategoricalSplitType<...>(0, 0).
    // numericSplitIn     - defaults to NumericSplitType<...>(0).
    //
    // NOTE(review): presumably trains on `data` immediately - confirm in
    // hoeffding_tree_impl.hpp.
93  template<typename MatType>
94  HoeffdingTree(const MatType& data,
95  const data::DatasetInfo& datasetInfo,
96  const arma::Row<size_t>& labels,
97  const size_t numClasses,
98  const bool batchTraining = true,
99  const double successProbability = 0.95,
100  const size_t maxSamples = 0,
101  const size_t checkInterval = 100,
102  const size_t minSamples = 100,
103  const CategoricalSplitType<FitnessFunction>& categoricalSplitIn
104  = CategoricalSplitType<FitnessFunction>(0, 0),
105  const NumericSplitType<FitnessFunction>& numericSplitIn =
106  NumericSplitType<FitnessFunction>(0));
107 
    // Construct a tree from a dataset description only (no training data).
    //
    // The leading parameters have the same names and defaults as in the
    // constructor above.  Two extra trailing parameters:
    // dimensionMappings - optional pointer to a dimension map, default NULL.
    //                     NOTE(review): ownership is not visible here; see the
    //                     ownsMappings member and confirm in the implementation.
    // copyDatasetInfo   - defaults to true.  Presumably selects between copying
    //                     `datasetInfo` and keeping a pointer to it - confirm.
130  HoeffdingTree(const data::DatasetInfo& datasetInfo,
131  const size_t numClasses,
132  const double successProbability = 0.95,
133  const size_t maxSamples = 0,
134  const size_t checkInterval = 100,
135  const size_t minSamples = 100,
136  const CategoricalSplitType<FitnessFunction>& categoricalSplitIn
137  = CategoricalSplitType<FitnessFunction>(0, 0),
138  const NumericSplitType<FitnessFunction>& numericSplitIn =
139  NumericSplitType<FitnessFunction>(0),
140  std::unordered_map<size_t, std::pair<size_t, size_t>>*
141  dimensionMappings = NULL,
142  const bool copyDatasetInfo = true);
143 
    // Construct a Hoeffding tree with no data and no information.
148  HoeffdingTree();
149 
    // Copy constructor.
156  HoeffdingTree(const HoeffdingTree& other);
157 
    // Move constructor.
163  HoeffdingTree(HoeffdingTree&& other);
164 
    // Copy assignment operator.
170  HoeffdingTree& operator=(const HoeffdingTree& other);
171 
    // NOTE(review): source lines 172-177 are not shown in this listing, and no
    // move-assignment operator is visible - confirm against the real header.
178 
    // Clean up memory.
182  ~HoeffdingTree();
183 
    // Train on a set of points, either in streaming mode or in batch mode, with
    // the given labels.
    //
    // batchTraining - defaults to true.
    // resetTree     - defaults to false.
    // numClasses    - defaults to 0.
    // NOTE(review): the meaning of numClasses == 0 is not visible here.
201  template<typename MatType>
202  void Train(const MatType& data,
203  const arma::Row<size_t>& labels,
204  const bool batchTraining = true,
205  const bool resetTree = false,
206  const size_t numClasses = 0);
207 
    // Train overload that also takes a dataset description (`info`).
    // It has no resetTree parameter.
    // NOTE(review): presumably always resets the tree - confirm.
223  template<typename MatType>
224  void Train(const MatType& data,
225  const data::DatasetInfo& info,
226  const arma::Row<size_t>& labels,
227  const bool batchTraining = true,
228  const size_t numClasses = 0);
229 
    // Train on a single point with its label.
237  template<typename VecType>
238  void Train(const VecType& point, const size_t label);
239 
    // Check whether a split would satisfy the conditions of the Hoeffding bound.
    // NOTE(review): the meaning of the returned size_t is not visible here.
245  size_t SplitCheck();
246 
    // Get the splitting dimension (size_t(-1) if no split).
248  size_t SplitDimension() const { return splitDimension; }
249 
    // Get the majority class.
251  size_t MajorityClass() const { return majorityClass; }
    // Modify the majority class.
253  size_t& MajorityClass() { return majorityClass; }
254 
    // Get the probability of the majority class (based on training samples).
256  double MajorityProbability() const { return majorityProbability; }
    // Modify the probability of the majority class.
258  double& MajorityProbability() { return majorityProbability; }
259 
    // Get the number of children.
261  size_t NumChildren() const { return children.size(); }
262 
    // Get a child.  Dereferences children[i] with no bounds or null check, so
    // the caller must pass i < NumChildren().
264  const HoeffdingTree& Child(const size_t i) const { return *children[i]; }
    // Modify a child.  Same precondition as the const overload.
266  HoeffdingTree& Child(const size_t i) { return *children[i]; }
267 
    // Get the confidence required for a split.
269  double SuccessProbability() const { return successProbability; }
    // Setter overload for the value above; defined outside this listing.
271  void SuccessProbability(const double successProbability);
272 
    // Get the minimum number of samples for a split.
274  size_t MinSamples() const { return minSamples; }
    // Setter overload for the value above; defined outside this listing.
276  void MinSamples(const size_t minSamples);
277 
    // Get the maximum number of samples before a split is forced.
279  size_t MaxSamples() const { return maxSamples; }
    // Setter overload for the value above; defined outside this listing.
281  void MaxSamples(const size_t maxSamples);
282 
    // Get the number of samples before a split check is performed.
284  size_t CheckInterval() const { return checkInterval; }
    // Setter overload for the value above; defined outside this listing.
286  void CheckInterval(const size_t checkInterval);
287 
    // Given a point, and given that this node is not a leaf, calculate the
    // index of a child node for that point.
295  template<typename VecType>
296  size_t CalculateDirection(const VecType& point) const;
297 
    // Classify the given point, using this node and the entire (sub)tree
    // beneath it.  Returns the predicted class.
305  template<typename VecType>
306  size_t Classify(const VecType& point) const;
307 
    // Get the size of the Hoeffding Tree.
309  size_t NumDescendants() const;
310 
    // Classify a single point, writing the results to the `prediction` and
    // `probability` output references.
322  template<typename VecType>
323  void Classify(const VecType& point, size_t& prediction, double& probability)
324  const;
325 
    // Classify a set of points, writing the results to `predictions`.
333  template<typename MatType>
334  void Classify(const MatType& data, arma::Row<size_t>& predictions) const;
335 
    // Classify a set of points, writing the results to `predictions` and
    // `probabilities`.
347  template<typename MatType>
348  void Classify(const MatType& data,
349  arma::Row<size_t>& predictions,
350  arma::rowvec& probabilities) const;
351 
    // Given that this node should split, create the children.
355  void CreateChildren();
356 
    // Serialize this object through the archive `ar`.  The version argument is
    // left unnamed, i.e. unused by the signature.
358  template<typename Archive>
359  void serialize(Archive& ar, const uint32_t /* version */);
360 
361  private:
362  // We need to keep some information for before we have split.
363 
    // Numeric split objects kept while the node is unsplit.
    // NOTE(review): presumably one per numeric dimension - confirm.
365  std::vector<NumericSplitType<FitnessFunction>> numericSplits;
    // Categorical split objects kept while the node is unsplit.
    // NOTE(review): presumably one per categorical dimension - confirm.
367  std::vector<CategoricalSplitType<FitnessFunction>> categoricalSplits;
368 
    // Pointer to a map from size_t to a pair of size_t.
    // NOTE(review): the meaning of the key and pair is not visible here.
370  std::unordered_map<size_t, std::pair<size_t, size_t>>* dimensionMappings;
    // NOTE(review): presumably true when this object must free
    // dimensionMappings - confirm in the destructor.
372  bool ownsMappings;
373 
    // Sample counter.
    // NOTE(review): presumably samples seen by this node - confirm.
375  size_t numSamples;
    // Number of classes.
377  size_t numClasses;
    // Value returned by MaxSamples().
379  size_t maxSamples;
    // Value returned by CheckInterval().
381  size_t checkInterval;
    // Value returned by MinSamples().
383  size_t minSamples;
    // Pointer to the (const) dataset description.
385  const data::DatasetInfo* datasetInfo;
    // NOTE(review): presumably true when this object must free datasetInfo -
    // confirm in the destructor.
387  bool ownsInfo;
    // Value returned by SuccessProbability().
389  double successProbability;
390 
391  // And we need to keep some information for after we have split.
392 
    // Value returned by SplitDimension().
394  size_t splitDimension;
    // Value returned by MajorityClass().
396  size_t majorityClass;
    // Value returned by MajorityProbability().
399  double majorityProbability;
    // Split information, in the categorical split type's SplitInfo form.
401  typename CategoricalSplitType<FitnessFunction>::SplitInfo categoricalSplit;
    // Split information, in the numeric split type's SplitInfo form.
403  typename NumericSplitType<FitnessFunction>::SplitInfo numericSplit;
    // Child nodes, held as raw pointers and dereferenced by Child().
    // NOTE(review): presumably owned and deleted by this node - confirm.
405  std::vector<HoeffdingTree*> children;
406 
    // Private training helper.  Takes the same data, labels and batchTraining
    // arguments as the public Train() overloads.
411  template<typename MatType>
412  void TrainInternal(const MatType& data,
413  const arma::Row<size_t>& labels,
414  const bool batchTraining);
415 
    // Private reset helper.  Takes the same split-prototype arguments, with the
    // same defaults, as the constructors.
419  void ResetTree(
420  const CategoricalSplitType<FitnessFunction>& categoricalSplitIn =
421  CategoricalSplitType<FitnessFunction>(0, 0),
422  const NumericSplitType<FitnessFunction>& numericSplitIn =
423  NumericSplitType<FitnessFunction>(0));
424 };
425 
426 } // namespace tree
427 } // namespace mlpack
428 
429 #include "hoeffding_tree_impl.hpp"
430 
431 #endif
HoeffdingTree()
Construct a Hoeffding tree with no data and no information.
const HoeffdingTree & Child(const size_t i) const
Get a child.
size_t NumChildren() const
Get the number of children.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
~HoeffdingTree()
Clean up memory.
The HoeffdingTree object represents all of the necessary information for a Hoeffding-bound-based decision tree.
size_t CheckInterval() const
Get the number of samples before a split check is performed.
Linear algebra utility functions, generally performed on matrices or vectors.
void Train(const MatType &data, const arma::Row< size_t > &labels, const bool batchTraining=true, const bool resetTree=false, const size_t numClasses=0)
Train on a set of points, either in streaming mode or in batch mode, with the given labels...
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go towards.
size_t Classify(const VecType &point) const
Classify the given point, using this node and the entire (sub)tree beneath it.
size_t MaxSamples() const
Get the maximum number of samples before a split is forced.
double SuccessProbability() const
Get the confidence required for a split.
void serialize(Archive &ar, const uint32_t)
Serialize the split.
void CreateChildren()
Given that this node should split, create the children.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
HoeffdingTree & operator=(const HoeffdingTree &other)
Copy assignment operator.
double MajorityProbability() const
Get the probability of the majority class (based on training samples).
size_t NumDescendants() const
Get the size of the Hoeffding Tree.
HoeffdingNumericSplit< FitnessFunction, double > HoeffdingDoubleNumericSplit
Convenience typedef.
size_t & MajorityClass()
Modify the majority class.
size_t SplitDimension() const
Get the splitting dimension (size_t(-1) if no split).
size_t MinSamples() const
Get the minimum number of samples for a split.
size_t SplitCheck()
Check if a split would satisfy the conditions of the Hoeffding bound with the node's specified success probability.
HoeffdingTree & Child(const size_t i)
Modify a child.
size_t MajorityClass() const
Get the majority class.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
double & MajorityProbability()
Modify the probability of the majority class.