decision_tree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
15 
#include <mlpack/prereqs.hpp>
#include "gini_gain.hpp"
#include "information_gain.hpp"
#include "best_binary_numeric_split.hpp"
#include "all_categorical_split.hpp"
#include "all_dimension_select.hpp"
#include <type_traits>
24 
25 namespace mlpack {
26 namespace tree {
27 
35 template<typename FitnessFunction = GiniGain,
36  template<typename> class NumericSplitType = BestBinaryNumericSplit,
37  template<typename> class CategoricalSplitType = AllCategoricalSplit,
38  typename DimensionSelectionType = AllDimensionSelect,
39  bool NoRecursion = false>
40 class DecisionTree :
41  public NumericSplitType<FitnessFunction>::AuxiliarySplitInfo,
42  public CategoricalSplitType<FitnessFunction>::AuxiliarySplitInfo
43 {
44  public:
46  typedef NumericSplitType<FitnessFunction> NumericSplit;
48  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
50  typedef DimensionSelectionType DimensionSelection;
51 
69  template<typename MatType, typename LabelsType>
70  DecisionTree(MatType data,
71  const data::DatasetInfo& datasetInfo,
72  LabelsType labels,
73  const size_t numClasses,
74  const size_t minimumLeafSize = 10,
75  const double minimumGainSplit = 1e-7,
76  const size_t maximumDepth = 0,
77  DimensionSelectionType dimensionSelector =
78  DimensionSelectionType());
79 
96  template<typename MatType, typename LabelsType>
97  DecisionTree(MatType data,
98  LabelsType labels,
99  const size_t numClasses,
100  const size_t minimumLeafSize = 10,
101  const double minimumGainSplit = 1e-7,
102  const size_t maximumDepth = 0,
103  DimensionSelectionType dimensionSelector =
104  DimensionSelectionType());
105 
125  template<typename MatType, typename LabelsType, typename WeightsType>
126  DecisionTree(
127  MatType data,
128  const data::DatasetInfo& datasetInfo,
129  LabelsType labels,
130  const size_t numClasses,
131  WeightsType weights,
132  const size_t minimumLeafSize = 10,
133  const double minimumGainSplit = 1e-7,
134  const size_t maximumDepth = 0,
135  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
136  const std::enable_if_t<arma::is_arma_type<
137  typename std::remove_reference<WeightsType>::type>::value>* = 0);
138 
158  template<typename MatType, typename LabelsType, typename WeightsType>
159  DecisionTree(
160  const DecisionTree& other,
161  MatType data,
162  const data::DatasetInfo& datasetInfo,
163  LabelsType labels,
164  const size_t numClasses,
165  WeightsType weights,
166  const size_t minimumLeafSize = 10,
167  const double minimumGainSplit = 1e-7,
168  const std::enable_if_t<arma::is_arma_type<
169  typename std::remove_reference<WeightsType>::type>::value>* = 0);
188  template<typename MatType, typename LabelsType, typename WeightsType>
189  DecisionTree(
190  MatType data,
191  LabelsType labels,
192  const size_t numClasses,
193  WeightsType weights,
194  const size_t minimumLeafSize = 10,
195  const double minimumGainSplit = 1e-7,
196  const size_t maximumDepth = 0,
197  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
198  const std::enable_if_t<arma::is_arma_type<
199  typename std::remove_reference<WeightsType>::type>::value>* = 0);
200 
219  template<typename MatType, typename LabelsType, typename WeightsType>
220  DecisionTree(
221  const DecisionTree& other,
222  MatType data,
223  LabelsType labels,
224  const size_t numClasses,
225  WeightsType weights,
226  const size_t minimumLeafSize = 10,
227  const double minimumGainSplit = 1e-7,
228  const size_t maximumDepth = 0,
229  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
230  const std::enable_if_t<arma::is_arma_type<
231  typename std::remove_reference<WeightsType>::type>::value>* = 0);
232 
239  DecisionTree(const size_t numClasses = 1);
240 
247  DecisionTree(const DecisionTree& other);
248 
254  DecisionTree(DecisionTree&& other);
255 
262  DecisionTree& operator=(const DecisionTree& other);
263 
270 
274  ~DecisionTree();
275 
295  template<typename MatType, typename LabelsType>
296  double Train(MatType data,
297  const data::DatasetInfo& datasetInfo,
298  LabelsType labels,
299  const size_t numClasses,
300  const size_t minimumLeafSize = 10,
301  const double minimumGainSplit = 1e-7,
302  const size_t maximumDepth = 0,
303  DimensionSelectionType dimensionSelector =
304  DimensionSelectionType());
305 
323  template<typename MatType, typename LabelsType>
324  double Train(MatType data,
325  LabelsType labels,
326  const size_t numClasses,
327  const size_t minimumLeafSize = 10,
328  const double minimumGainSplit = 1e-7,
329  const size_t maximumDepth = 0,
330  DimensionSelectionType dimensionSelector =
331  DimensionSelectionType());
332 
354  template<typename MatType, typename LabelsType, typename WeightsType>
355  double Train(MatType data,
356  const data::DatasetInfo& datasetInfo,
357  LabelsType labels,
358  const size_t numClasses,
359  WeightsType weights,
360  const size_t minimumLeafSize = 10,
361  const double minimumGainSplit = 1e-7,
362  const size_t maximumDepth = 0,
363  DimensionSelectionType dimensionSelector =
364  DimensionSelectionType(),
365  const std::enable_if_t<arma::is_arma_type<typename
366  std::remove_reference<WeightsType>::type>::value>* = 0);
367 
387  template<typename MatType, typename LabelsType, typename WeightsType>
388  double Train(MatType data,
389  LabelsType labels,
390  const size_t numClasses,
391  WeightsType weights,
392  const size_t minimumLeafSize = 10,
393  const double minimumGainSplit = 1e-7,
394  const size_t maximumDepth = 0,
395  DimensionSelectionType dimensionSelector =
396  DimensionSelectionType(),
397  const std::enable_if_t<arma::is_arma_type<typename
398  std::remove_reference<WeightsType>::type>::value>* = 0);
399 
406  template<typename VecType>
407  size_t Classify(const VecType& point) const;
408 
418  template<typename VecType>
419  void Classify(const VecType& point,
420  size_t& prediction,
421  arma::vec& probabilities) const;
422 
430  template<typename MatType>
431  void Classify(const MatType& data,
432  arma::Row<size_t>& predictions) const;
433 
444  template<typename MatType>
445  void Classify(const MatType& data,
446  arma::Row<size_t>& predictions,
447  arma::mat& probabilities) const;
448 
452  template<typename Archive>
453  void serialize(Archive& ar, const uint32_t /* version */);
454 
456  size_t NumChildren() const { return children.size(); }
457 
459  const DecisionTree& Child(const size_t i) const { return *children[i]; }
461  DecisionTree& Child(const size_t i) { return *children[i]; }
462 
465  size_t SplitDimension() const { return splitDimension; }
466 
474  template<typename VecType>
475  size_t CalculateDirection(const VecType& point) const;
476 
480  size_t NumClasses() const;
481 
482  private:
484  std::vector<DecisionTree*> children;
486  size_t splitDimension;
489  size_t dimensionTypeOrMajorityClass;
497  arma::vec classProbabilities;
498 
502  typedef typename NumericSplit::AuxiliarySplitInfo
503  NumericAuxiliarySplitInfo;
504  typedef typename CategoricalSplit::AuxiliarySplitInfo
505  CategoricalAuxiliarySplitInfo;
506 
510  template<bool UseWeights, typename RowType, typename WeightsRowType>
511  void CalculateClassProbabilities(const RowType& labels,
512  const size_t numClasses,
513  const WeightsRowType& weights);
514 
532  template<bool UseWeights, typename MatType>
533  double Train(MatType& data,
534  const size_t begin,
535  const size_t count,
536  const data::DatasetInfo& datasetInfo,
537  arma::Row<size_t>& labels,
538  const size_t numClasses,
539  arma::rowvec& weights,
540  const size_t minimumLeafSize,
541  const double minimumGainSplit,
542  const size_t maximumDepth,
543  DimensionSelectionType& dimensionSelector);
544 
561  template<bool UseWeights, typename MatType>
562  double Train(MatType& data,
563  const size_t begin,
564  const size_t count,
565  arma::Row<size_t>& labels,
566  const size_t numClasses,
567  arma::rowvec& weights,
568  const size_t minimumLeafSize,
569  const double minimumGainSplit,
570  const size_t maximumDepth,
571  DimensionSelectionType& dimensionSelector);
572 };
573 
577 template<typename FitnessFunction = GiniGain,
578  template<typename> class NumericSplitType = BestBinaryNumericSplit,
579  template<typename> class CategoricalSplitType = AllCategoricalSplit,
580  typename DimensionSelectType = AllDimensionSelect>
581 using DecisionStump = DecisionTree<FitnessFunction,
582  NumericSplitType,
583  CategoricalSplitType,
584  DimensionSelectType,
585  false>;
586 
596 } // namespace tree
597 } // namespace mlpack
598 
599 // Include implementation.
600 #include "decision_tree_impl.hpp"
601 
602 #endif
size_t NumChildren() const
Get the number of children.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a numeric dimension for the best binary split.
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:70
Linear algebra utility functions, generally performed on matrices or vectors.
This class implements a generic decision tree learner.
size_t CalculateDirection(const VecType &point) const
Given a point, and assuming this node is not a leaf, calculate the index of the child node this point would go to.
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
The AllCategoricalSplit is a splitting function that will split categorical features into many children.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
DecisionTree< InformationGain, BestBinaryNumericSplit, AllCategoricalSplit, AllDimensionSelect, true > ID3DecisionStump
Convenience typedef for ID3 decision stumps (single level decision trees made with the ID3 algorithm)...
The standard information gain criterion, used for calculating gain in decision trees.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t NumClasses() const
Get the number of classes in the tree.
This dimension selection policy allows any dimension to be selected for splitting.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
Definition: gini_gain.hpp:27
void serialize(Archive &ar, const uint32_t)
Serialize the tree.
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
~DecisionTree()
Clean up memory.