13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP 14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP 23 #include <type_traits> 35 template<
typename FitnessFunction = GiniGain,
36 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
37 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
38 typename DimensionSelectionType = AllDimensionSelect,
39 bool NoRecursion =
false>
41 public NumericSplitType<FitnessFunction>::AuxiliarySplitInfo,
42 public CategoricalSplitType<FitnessFunction>::AuxiliarySplitInfo
69 template<
typename MatType,
typename LabelsType>
73 const size_t numClasses,
74 const size_t minimumLeafSize = 10,
75 const double minimumGainSplit = 1e-7,
76 const size_t maximumDepth = 0,
77 DimensionSelectionType dimensionSelector =
78 DimensionSelectionType());
96 template<
typename MatType,
typename LabelsType>
99 const size_t numClasses,
100 const size_t minimumLeafSize = 10,
101 const double minimumGainSplit = 1e-7,
102 const size_t maximumDepth = 0,
103 DimensionSelectionType dimensionSelector =
104 DimensionSelectionType());
125 template<
typename MatType,
typename LabelsType,
typename WeightsType>
130 const size_t numClasses,
132 const size_t minimumLeafSize = 10,
133 const double minimumGainSplit = 1e-7,
134 const size_t maximumDepth = 0,
135 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
137 typename std::remove_reference<WeightsType>::type>::value>* = 0);
158 template<
typename MatType,
typename LabelsType,
typename WeightsType>
164 const size_t numClasses,
166 const size_t minimumLeafSize = 10,
167 const double minimumGainSplit = 1e-7,
169 typename std::remove_reference<WeightsType>::type>::value>* = 0);
188 template<
typename MatType,
typename LabelsType,
typename WeightsType>
192 const size_t numClasses,
194 const size_t minimumLeafSize = 10,
195 const double minimumGainSplit = 1e-7,
196 const size_t maximumDepth = 0,
197 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
199 typename std::remove_reference<WeightsType>::type>::value>* = 0);
219 template<
typename MatType,
typename LabelsType,
typename WeightsType>
224 const size_t numClasses,
226 const size_t minimumLeafSize = 10,
227 const double minimumGainSplit = 1e-7,
228 const size_t maximumDepth = 0,
229 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
231 typename std::remove_reference<WeightsType>::type>::value>* = 0);
295 template<
typename MatType,
typename LabelsType>
296 double Train(MatType data,
299 const size_t numClasses,
300 const size_t minimumLeafSize = 10,
301 const double minimumGainSplit = 1e-7,
302 const size_t maximumDepth = 0,
303 DimensionSelectionType dimensionSelector =
304 DimensionSelectionType());
323 template<
typename MatType,
typename LabelsType>
324 double Train(MatType data,
326 const size_t numClasses,
327 const size_t minimumLeafSize = 10,
328 const double minimumGainSplit = 1e-7,
329 const size_t maximumDepth = 0,
330 DimensionSelectionType dimensionSelector =
331 DimensionSelectionType());
354 template<
typename MatType,
typename LabelsType,
typename WeightsType>
355 double Train(MatType data,
358 const size_t numClasses,
360 const size_t minimumLeafSize = 10,
361 const double minimumGainSplit = 1e-7,
362 const size_t maximumDepth = 0,
363 DimensionSelectionType dimensionSelector =
364 DimensionSelectionType(),
366 std::remove_reference<WeightsType>::type>::value>* = 0);
387 template<
typename MatType,
typename LabelsType,
typename WeightsType>
388 double Train(MatType data,
390 const size_t numClasses,
392 const size_t minimumLeafSize = 10,
393 const double minimumGainSplit = 1e-7,
394 const size_t maximumDepth = 0,
395 DimensionSelectionType dimensionSelector =
396 DimensionSelectionType(),
398 std::remove_reference<WeightsType>::type>::value>* = 0);
/**
 * Classify the given point, using the entire tree.
 *
 * @tparam VecType Type of the point being classified.
 * @param point Point to classify.
 * @return A size_t result for the point -- presumably the predicted class
 *     label; confirm against the implementation.
 */
406 template<
typename VecType>
407 size_t Classify(
const VecType& point)
const;
418 template<
typename VecType>
421 arma::vec& probabilities)
const;
430 template<
typename MatType>
432 arma::Row<size_t>& predictions)
const;
444 template<
typename MatType>
446 arma::Row<size_t>& predictions,
447 arma::mat& probabilities)
const;
/**
 * Serialize the tree.
 *
 * @tparam Archive Type of the archive object used for serialization.
 * @param ar Archive to serialize with.
 *
 * NOTE(review): the second, unnamed uint32_t parameter is presumably the
 * serialization version and appears to be unused -- confirm against the
 * implementation.
 */
452 template<
typename Archive>
453 void serialize(Archive& ar,
const uint32_t );
474 template<
typename VecType>
//! Pointers to the child nodes of this node.
//! NOTE(review): presumably owned by this node and released by the
//! destructor -- confirm against the implementation.
484 std::vector<DecisionTree*> children;
//! The split dimension of this node (only meaningful if this is a non-leaf in
//! a trained tree).
486 size_t splitDimension;
//! NOTE(review): judging by the name, this field does double duty -- the type
//! of the split dimension for a non-leaf node, or the majority class for a
//! leaf -- confirm against the implementation.
489 size_t dimensionTypeOrMajorityClass;
//! NOTE(review): presumably the per-class probabilities at this node, as
//! filled in by CalculateClassProbabilities() -- confirm against the
//! implementation.
497 arma::vec classProbabilities;
//! Shorthand for the auxiliary split information type defined by the numeric
//! split type (NumericSplit).
502 typedef typename NumericSplit::AuxiliarySplitInfo
503 NumericAuxiliarySplitInfo;
//! Shorthand for the auxiliary split information type defined by the
//! categorical split type (CategoricalSplit).
504 typedef typename CategoricalSplit::AuxiliarySplitInfo
505 CategoricalAuxiliarySplitInfo;
/**
 * Calculate the class probabilities for the given labels.
 *
 * NOTE(review): only the declaration is visible here; presumably the result is
 * stored in this node's classProbabilities member -- confirm against the
 * implementation.
 *
 * @tparam UseWeights Compile-time switch; presumably selects whether the
 *     `weights` argument is taken into account -- confirm.
 * @tparam RowType Type of the labels container.
 * @tparam WeightsRowType Type of the weights container.
 * @param labels Labels to compute the class probabilities from.
 * @param numClasses Number of classes.
 * @param weights Weights accompanying the labels (presumably one per label).
 */
510 template<
bool UseWeights,
typename RowType,
typename WeightsRowType>
511 void CalculateClassProbabilities(
const RowType& labels,
512 const size_t numClasses,
513 const WeightsRowType& weights);
532 template<
bool UseWeights,
typename MatType>
533 double Train(MatType& data,
537 arma::Row<size_t>& labels,
538 const size_t numClasses,
539 arma::rowvec& weights,
540 const size_t minimumLeafSize,
541 const double minimumGainSplit,
542 const size_t maximumDepth,
543 DimensionSelectionType& dimensionSelector);
561 template<
bool UseWeights,
typename MatType>
562 double Train(MatType& data,
565 arma::Row<size_t>& labels,
566 const size_t numClasses,
567 arma::rowvec& weights,
568 const size_t minimumLeafSize,
569 const double minimumGainSplit,
570 const size_t maximumDepth,
571 DimensionSelectionType& dimensionSelector);
577 template<
typename FitnessFunction =
GiniGain,
583 CategoricalSplitType,
600 #include "decision_tree_impl.hpp" size_t NumChildren() const
Get the number of children.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a numeric dimension for the best binary split.
typename enable_if< B, T >::type enable_if_t
Linear algebra utility functions, generally performed on matrices or vectors.
This class implements a generic decision tree learner.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go to.
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
The AllCategoricalSplit is a splitting function that will split categorical features into many children: one child for each category.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
DecisionTree< InformationGain, BestBinaryNumericSplit, AllCategoricalSplit, AllDimensionSelect, true > ID3DecisionStump
Convenience typedef for ID3 decision stumps (single level decision trees made with the ID3 algorithm).
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t NumClasses() const
Get the number of classes in the tree.
This dimension selection policy allows any dimension to be selected for splitting.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
void serialize(Archive &ar, const uint32_t)
Serialize the tree.
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
~DecisionTree()
Clean up memory.