13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_REGRESSOR_HPP 14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_REGRESSOR_HPP 23 #include <type_traits> 36 template<
typename FitnessFunction = MSEGain,
37 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
38 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
39 typename DimensionSelectionType = AllDimensionSelect,
40 bool NoRecursion =
false>
42 public NumericSplitType<FitnessFunction>::AuxiliarySplitInfo,
43 public CategoricalSplitType<FitnessFunction>::AuxiliarySplitInfo
74 template<
typename MatType,
typename ResponsesType>
77 ResponsesType responses,
78 const size_t minimumLeafSize = 10,
79 const double minimumGainSplit = 1e-7,
80 const size_t maximumDepth = 0,
81 DimensionSelectionType dimensionSelector =
82 DimensionSelectionType());
99 template<
typename MatType,
typename ResponsesType>
101 ResponsesType responses,
102 const size_t minimumLeafSize = 10,
103 const double minimumGainSplit = 1e-7,
104 const size_t maximumDepth = 0,
105 DimensionSelectionType dimensionSelector =
106 DimensionSelectionType());
126 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
130 ResponsesType responses,
132 const size_t minimumLeafSize = 10,
133 const double minimumGainSplit = 1e-7,
134 const size_t maximumDepth = 0,
135 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
137 typename std::remove_reference<WeightsType>::type>::value>* = 0);
156 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
159 ResponsesType responses,
161 const size_t minimumLeafSize = 10,
162 const double minimumGainSplit = 1e-7,
163 const size_t maximumDepth = 0,
164 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
166 typename std::remove_reference<WeightsType>::type>::value>* = 0);
186 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
191 ResponsesType responses,
193 const size_t minimumLeafSize = 10,
194 const double minimumGainSplit = 1e-7,
196 typename std::remove_reference<WeightsType>::type>::value>* = 0);
215 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
219 ResponsesType responses,
221 const size_t minimumLeafSize = 10,
222 const double minimumGainSplit = 1e-7,
223 const size_t maximumDepth = 0,
224 DimensionSelectionType dimensionSelector = DimensionSelectionType(),
226 typename std::remove_reference<WeightsType>::type>::value>* = 0);
283 template<
typename MatType,
typename ResponsesType>
284 double Train(MatType data,
286 ResponsesType responses,
287 const size_t minimumLeafSize = 10,
288 const double minimumGainSplit = 1e-7,
289 const size_t maximumDepth = 0,
290 DimensionSelectionType dimensionSelector =
291 DimensionSelectionType(),
292 FitnessFunction fitnessFunction = FitnessFunction());
// Train the decision tree on the given data and responses (overload without
// a data::DatasetInfo argument and without instance weights).
//
// Parameters, as visible in this declaration:
//   data              - training data (taken by value).
//   responses         - responses for the training data (taken by value).
//   minimumLeafSize   - defaults to 10.
//   minimumGainSplit  - defaults to 1e-7.
//   maximumDepth      - defaults to 0.
//                       NOTE(review): 0 presumably means "no depth limit" — confirm.
//   dimensionSelector - defaults to a default-constructed DimensionSelectionType.
//   fitnessFunction   - defaults to a default-constructed FitnessFunction.
//
// Returns a double.
// NOTE(review): presumably the final gain/error of the trained tree — confirm
// against the implementation.
// NOTE(review): with no DatasetInfo, all dimensions are presumably treated as
// numeric — confirm.
312 template<
typename MatType,
typename ResponsesType>
313 double Train(MatType data,
314 ResponsesType responses,
315 const size_t minimumLeafSize = 10,
316 const double minimumGainSplit = 1e-7,
317 const size_t maximumDepth = 0,
318 DimensionSelectionType dimensionSelector =
319 DimensionSelectionType(),
320 FitnessFunction fitnessFunction = FitnessFunction());
344 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
345 double Train(MatType data,
347 ResponsesType responses,
349 const size_t minimumLeafSize = 10,
350 const double minimumGainSplit = 1e-7,
351 const size_t maximumDepth = 0,
352 DimensionSelectionType dimensionSelector =
353 DimensionSelectionType(),
354 FitnessFunction fitnessFunction = FitnessFunction(),
356 std::remove_reference<WeightsType>::type>::value>* = 0);
378 template<
typename MatType,
typename ResponsesType,
typename WeightsType>
379 double Train(MatType data,
380 ResponsesType responses,
382 const size_t minimumLeafSize = 10,
383 const double minimumGainSplit = 1e-7,
384 const size_t maximumDepth = 0,
385 DimensionSelectionType dimensionSelector =
386 DimensionSelectionType(),
387 FitnessFunction fitnessFunction = FitnessFunction(),
389 std::remove_reference<WeightsType>::type>::value>* = 0);
// Make a prediction for the given point, using the entire tree.
//
//   point - the point to predict for (taken by const reference).
//
// Returns the prediction as a double. Declared const, so it can be called on
// a const tree.
397 template<
typename VecType>
398 double Predict(
const VecType& point)
const;
// Batch overload of Predict(): takes a whole dataset and writes the results
// into an output row vector instead of returning a value.
//
//   data        - the dataset to predict for (taken by const reference).
//   predictions - output, passed by non-const reference.
//                 NOTE(review): presumably resized to hold one prediction per
//                 point (column) of `data` — confirm against the implementation.
//
// Declared const, so it can be called on a const tree.
407 template<
typename MatType>
408 void Predict(
const MatType& data,
409 arma::Row<double>& predictions)
const;
// Serialize the tree using the given archive.
//
//   ar - the archive to serialize to or from.
//
// The second parameter (a uint32_t) is unnamed in this declaration.
// NOTE(review): presumably the serialization version number — confirm.
414 template<
typename Archive>
415 void serialize(Archive& ar,
const uint32_t );
442 template<
typename VecType>
// Child nodes of this node, held as raw pointers.
// NOTE(review): presumably owned by this node and released in the
// destructor — confirm.
447 std::vector<DecisionTreeRegressor*> children;
// Dimension this node splits on; per SplitDimension(), only meaningful if
// this is a non-leaf in a trained tree.
449 size_t splitDimension;
// NOTE(review): presumably the type (numeric vs. categorical) of the split
// dimension — confirm.
452 size_t dimensionType;
// NOTE(review): the name suggests this holds the split point for a non-leaf
// node and the prediction value for a leaf — confirm.
460 double splitPointOrPrediction;
// Shorthand for the AuxiliarySplitInfo type nested in the numeric split type
// (NumericSplit).
465 typedef typename NumericSplit::AuxiliarySplitInfo
466 NumericAuxiliarySplitInfo;
// Shorthand for the AuxiliarySplitInfo type nested in the categorical split
// type (CategoricalSplit).
467 typedef typename CategoricalSplit::AuxiliarySplitInfo
468 CategoricalAuxiliarySplitInfo;
488 template<
bool UseWeights,
typename MatType,
typename ResponsesType>
489 double Train(MatType& data,
493 ResponsesType& responses,
494 arma::rowvec& weights,
495 const size_t minimumLeafSize,
496 const double minimumGainSplit,
497 const size_t maximumDepth,
498 DimensionSelectionType& dimensionSelector,
499 FitnessFunction fitnessFunction = FitnessFunction());
518 template<
bool UseWeights,
typename MatType,
typename ResponsesType>
519 double Train(MatType& data,
522 ResponsesType& responses,
523 arma::rowvec& weights,
524 const size_t minimumLeafSize,
525 const double minimumGainSplit,
526 const size_t maximumDepth,
527 DimensionSelectionType& dimensionSelector,
528 FitnessFunction fitnessFunction = FitnessFunction());
536 #include "decision_tree_regressor_impl.hpp" void serialize(Archive &ar, const uint32_t)
Serialize the tree.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
double Predict(const VecType &point) const
Make prediction for the given point, using the entire tree.
typename enable_if< B, T >::type enable_if_t
Linear algebra utility functions, generally performed on matrices or vectors.
const DecisionTreeRegressor & Child(const size_t i) const
Get the child of the given index.
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
DecisionTreeRegressor & operator=(const DecisionTreeRegressor &other)
Copy another tree.
The core includes that mlpack expects; standard C++ includes and Armadillo.
~DecisionTreeRegressor()
Clean up memory.
double Train(MatType data, const data::DatasetInfo &datasetInfo, ResponsesType responses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), FitnessFunction fitnessFunction=FitnessFunction())
Train the decision tree on the given data.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go towards.
size_t NumLeaves() const
Get the number of leaves in the tree.
DecisionTreeRegressor & Child(const size_t i)
Modify the child of the given index (be careful!).
This class implements a generic decision tree learner.
DecisionTreeRegressor()
Construct a decision tree without training it.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
size_t NumChildren() const
Get the number of children.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.