decision_tree_regressor.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_REGRESSOR_HPP
14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_REGRESSOR_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "mad_gain.hpp"
18 #include "mse_gain.hpp"
22 #include "all_dimension_select.hpp"
23 #include <type_traits>
24 
25 
26 namespace mlpack {
27 namespace tree {
28 
/**
 * This class implements a generic decision tree learner.
 *
 * Template parameters, as declared below:
 *  - FitnessFunction: type handed to both split types (default MSEGain).
 *  - NumericSplitType / CategoricalSplitType: split policies, each
 *    instantiated with FitnessFunction; their nested AuxiliarySplitInfo types
 *    are public base classes of this class.
 *  - DimensionSelectionType: dimension selection policy (default
 *    AllDimensionSelect).
 *  - NoRecursion: boolean flag, default false (presumably limits training to
 *    a single level -- TODO confirm in the implementation file).
 *
 * NOTE(review): this text is a numbered listing; listing line 41 (presumably
 * the `class DecisionTreeRegressor :` head) is absent between 40 and 42.
 */
36 template<typename FitnessFunction = MSEGain,
37  template<typename> class NumericSplitType = BestBinaryNumericSplit,
38  template<typename> class CategoricalSplitType = AllCategoricalSplit,
39  typename DimensionSelectionType = AllDimensionSelect,
40  bool NoRecursion = false>
42  public NumericSplitType<FitnessFunction>::AuxiliarySplitInfo,
43  public CategoricalSplitType<FitnessFunction>::AuxiliarySplitInfo
44 {
45  public:
    //! Allow access to the numeric split type.
47  typedef NumericSplitType<FitnessFunction> NumericSplit;
    //! Allow access to the categorical split type.
49  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
    //! Allow access to the dimension selection type.
51  typedef DimensionSelectionType DimensionSelection;
52 
    // NOTE(review): listing line 57 is empty in this extract; the member index
    // lists a default constructor `DecisionTreeRegressor()` ("Construct a
    // decision tree without training it."), which presumably lived here.
57 
    /**
     * Constructor taking data, dataset information and responses
     * (presumably trains the tree on construction -- confirm in
     * decision_tree_regressor_impl.hpp).
     *
     * @param data Dataset.
     * @param datasetInfo Auxiliary information for the dataset.
     * @param responses Responses for the dataset.
     * @param minimumLeafSize Defaults to 10.
     * @param minimumGainSplit Defaults to 1e-7.
     * @param maximumDepth Defaults to 0 (presumably "no limit" -- TODO
     *     confirm).
     * @param dimensionSelector Instance of the dimension selection policy;
     *     default-constructed when omitted.
     */
74  template<typename MatType, typename ResponsesType>
75  DecisionTreeRegressor(MatType data,
76  const data::DatasetInfo& datasetInfo,
77  ResponsesType responses,
78  const size_t minimumLeafSize = 10,
79  const double minimumGainSplit = 1e-7,
80  const size_t maximumDepth = 0,
81  DimensionSelectionType dimensionSelector =
82  DimensionSelectionType());
83 
    /**
     * Constructor with the same parameters as the previous one, minus
     * datasetInfo (presumably all dimensions are then treated as numeric --
     * TODO confirm).
     */
99  template<typename MatType, typename ResponsesType>
100  DecisionTreeRegressor(MatType data,
101  ResponsesType responses,
102  const size_t minimumLeafSize = 10,
103  const double minimumGainSplit = 1e-7,
104  const size_t maximumDepth = 0,
105  DimensionSelectionType dimensionSelector =
106  DimensionSelectionType());
107 
    /**
     * Weighted constructor: data, dataset information, responses and
     * per-instance weights.  The trailing unnamed pointer parameter restricts
     * this overload to WeightsType arguments for which arma::is_arma_type
     * (after removing the reference) is true.
     *
     * NOTE(review): listing line 127 (presumably the `DecisionTreeRegressor(`
     * name) is absent from this extract.
     */
126  template<typename MatType, typename ResponsesType, typename WeightsType>
128  MatType data,
129  const data::DatasetInfo& datasetInfo,
130  ResponsesType responses,
131  WeightsType weights,
132  const size_t minimumLeafSize = 10,
133  const double minimumGainSplit = 1e-7,
134  const size_t maximumDepth = 0,
135  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
136  const std::enable_if_t<arma::is_arma_type<
137  typename std::remove_reference<WeightsType>::type>::value>* = 0);
138 
    /**
     * Weighted constructor without datasetInfo; same enable_if restriction on
     * WeightsType as the overload above.
     *
     * NOTE(review): listing line 157 (presumably the `DecisionTreeRegressor(`
     * name) is absent from this extract.
     */
156  template<typename MatType, typename ResponsesType, typename WeightsType>
158  MatType data,
159  ResponsesType responses,
160  WeightsType weights,
161  const size_t minimumLeafSize = 10,
162  const double minimumGainSplit = 1e-7,
163  const size_t maximumDepth = 0,
164  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
165  const std::enable_if_t<arma::is_arma_type<
166  typename std::remove_reference<WeightsType>::type>::value>* = 0);
167 
    /**
     * Weighted constructor that additionally takes an existing tree `other`
     * (how `other` is used is not visible here -- see the implementation
     * file), together with data, dataset information, responses and weights.
     *
     * NOTE(review): unlike the `other`-taking overload without datasetInfo
     * below, this one declares no maximumDepth or dimensionSelector
     * parameter; confirm whether that asymmetry is intentional.
     * NOTE(review): listing line 187 (presumably the `DecisionTreeRegressor(`
     * name) is absent from this extract.
     */
186  template<typename MatType, typename ResponsesType, typename WeightsType>
188  const DecisionTreeRegressor& other,
189  MatType data,
190  const data::DatasetInfo& datasetInfo,
191  ResponsesType responses,
192  WeightsType weights,
193  const size_t minimumLeafSize = 10,
194  const double minimumGainSplit = 1e-7,
195  const std::enable_if_t<arma::is_arma_type<
196  typename std::remove_reference<WeightsType>::type>::value>* = 0);
197 
    /**
     * Weighted constructor taking an existing tree `other`, data, responses
     * and weights, without datasetInfo; accepts maximumDepth and
     * dimensionSelector.  Same enable_if restriction on WeightsType.
     *
     * NOTE(review): listing line 216 (presumably the `DecisionTreeRegressor(`
     * name) is absent from this extract.
     */
215  template<typename MatType, typename ResponsesType, typename WeightsType>
217  const DecisionTreeRegressor& other,
218  MatType data,
219  ResponsesType responses,
220  WeightsType weights,
221  const size_t minimumLeafSize = 10,
222  const double minimumGainSplit = 1e-7,
223  const size_t maximumDepth = 0,
224  DimensionSelectionType dimensionSelector = DimensionSelectionType(),
225  const std::enable_if_t<arma::is_arma_type<
226  typename std::remove_reference<WeightsType>::type>::value>* = 0);
227 
    // NOTE(review): listing lines 235-262 are empty in this extract.  The
    // member index lists `operator=(const DecisionTreeRegressor& other)`
    // ("Copy another tree.") and `~DecisionTreeRegressor()` ("Clean up
    // memory."); presumably the copy/move members and the destructor were
    // declared here -- restore from upstream.
235 
242 
250 
257 
262 
    /**
     * Train the decision tree on the given data.
     *
     * @param data Dataset to train on.
     * @param datasetInfo Auxiliary information for the dataset.
     * @param responses Responses for the dataset.
     * @param minimumLeafSize Defaults to 10.
     * @param minimumGainSplit Defaults to 1e-7.
     * @param maximumDepth Defaults to 0 (presumably "no limit" -- TODO
     *     confirm).
     * @param dimensionSelector Instance of the dimension selection policy.
     * @param fitnessFunction Instance of the fitness function.
     * @return A double; its meaning is defined in the implementation file and
     *     is not visible here.
     */
283  template<typename MatType, typename ResponsesType>
284  double Train(MatType data,
285  const data::DatasetInfo& datasetInfo,
286  ResponsesType responses,
287  const size_t minimumLeafSize = 10,
288  const double minimumGainSplit = 1e-7,
289  const size_t maximumDepth = 0,
290  DimensionSelectionType dimensionSelector =
291  DimensionSelectionType(),
292  FitnessFunction fitnessFunction = FitnessFunction());
293 
    /**
     * Train the decision tree on the given data, without datasetInfo.
     * Remaining parameters and return type match the overload above.
     */
312  template<typename MatType, typename ResponsesType>
313  double Train(MatType data,
314  ResponsesType responses,
315  const size_t minimumLeafSize = 10,
316  const double minimumGainSplit = 1e-7,
317  const size_t maximumDepth = 0,
318  DimensionSelectionType dimensionSelector =
319  DimensionSelectionType(),
320  FitnessFunction fitnessFunction = FitnessFunction());
321 
    /**
     * Train the decision tree on the given data with per-instance weights and
     * dataset information.  The trailing unnamed pointer parameter restricts
     * this overload to WeightsType arguments for which arma::is_arma_type
     * (after removing the reference) is true.
     */
344  template<typename MatType, typename ResponsesType, typename WeightsType>
345  double Train(MatType data,
346  const data::DatasetInfo& datasetInfo,
347  ResponsesType responses,
348  WeightsType weights,
349  const size_t minimumLeafSize = 10,
350  const double minimumGainSplit = 1e-7,
351  const size_t maximumDepth = 0,
352  DimensionSelectionType dimensionSelector =
353  DimensionSelectionType(),
354  FitnessFunction fitnessFunction = FitnessFunction(),
355  const std::enable_if_t<arma::is_arma_type<typename
356  std::remove_reference<WeightsType>::type>::value>* = 0);
357 
    /**
     * Train the decision tree on the given data with per-instance weights,
     * without datasetInfo.  Same enable_if restriction on WeightsType as the
     * overload above.
     */
378  template<typename MatType, typename ResponsesType, typename WeightsType>
379  double Train(MatType data,
380  ResponsesType responses,
381  WeightsType weights,
382  const size_t minimumLeafSize = 10,
383  const double minimumGainSplit = 1e-7,
384  const size_t maximumDepth = 0,
385  DimensionSelectionType dimensionSelector =
386  DimensionSelectionType(),
387  FitnessFunction fitnessFunction = FitnessFunction(),
388  const std::enable_if_t<arma::is_arma_type<typename
389  std::remove_reference<WeightsType>::type>::value>* = 0);
390 
    //! Make prediction for the given point, using the entire tree.
397  template<typename VecType>
398  double Predict(const VecType& point) const;
399 
    /**
     * Batch overload of Predict(): takes a data matrix and a non-const
     * `predictions` row (presumably filled with one prediction per point in
     * `data` -- confirm in the implementation file).
     */
407  template<typename MatType>
408  void Predict(const MatType& data,
409  arma::Row<double>& predictions) const;
410 
    //! Serialize the tree.  The version argument is unnamed (unused here).
414  template<typename Archive>
415  void serialize(Archive& ar, const uint32_t /* version */);
416 
    //! Get the number of children.
418  size_t NumChildren() const { return children.size(); }
419 
    //! Get the number of leaves in the tree.
421  size_t NumLeaves() const;
422 
    //! Get the child of the given index.  No bounds check is performed here.
424  const DecisionTreeRegressor& Child(const size_t i) const
425  {
426  return *children[i];
427  }
    //! Modify the child of the given index (be careful!).
429  DecisionTreeRegressor& Child(const size_t i) { return *children[i]; }
430 
    //! Get the split dimension (only meaningful if this is a non-leaf in a
    //! trained tree).
433  size_t SplitDimension() const { return splitDimension; }
434 
    //! Given a point and that this node is not a leaf, calculate the index of
    //! the child node for this point.
442  template<typename VecType>
443  size_t CalculateDirection(const VecType& point) const;
444 
445  private:
    //! Child nodes, held as raw pointers (presumably owned by this node and
    //! released by the destructor -- confirm in the implementation file).
447  std::vector<DecisionTreeRegressor*> children;
    //! Value returned by SplitDimension().
449  size_t splitDimension;
    //! Presumably the datatype of the split dimension -- TODO confirm.
452  size_t dimensionType;
    //! Judging by the name, holds either a split point or a leaf's
    //! prediction -- TODO confirm which applies when.
460  double splitPointOrPrediction;
461 
    //! Shorthand for the auxiliary split info types; these are the same types
    //! this class derives from.
465  typedef typename NumericSplit::AuxiliarySplitInfo
466  NumericAuxiliarySplitInfo;
467  typedef typename CategoricalSplit::AuxiliarySplitInfo
468  CategoricalAuxiliarySplitInfo;
469 
    /**
     * Private training overload with dataset information.  Takes `data`,
     * `responses` and `weights` by non-const reference plus a
     * [`begin`, `count`] pair (presumably the range of points this node is
     * trained on -- TODO confirm).  UseWeights is a compile-time flag
     * (presumably selects whether `weights` is consulted).
     */
488  template<bool UseWeights, typename MatType, typename ResponsesType>
489  double Train(MatType& data,
490  const size_t begin,
491  const size_t count,
492  const data::DatasetInfo& datasetInfo,
493  ResponsesType& responses,
494  arma::rowvec& weights,
495  const size_t minimumLeafSize,
496  const double minimumGainSplit,
497  const size_t maximumDepth,
498  DimensionSelectionType& dimensionSelector,
499  FitnessFunction fitnessFunction = FitnessFunction());
500 
    /**
     * Private training overload without dataset information; otherwise the
     * same parameter list as the overload above.
     */
518  template<bool UseWeights, typename MatType, typename ResponsesType>
519  double Train(MatType& data,
520  const size_t begin,
521  const size_t count,
522  ResponsesType& responses,
523  arma::rowvec& weights,
524  const size_t minimumLeafSize,
525  const double minimumGainSplit,
526  const size_t maximumDepth,
527  DimensionSelectionType& dimensionSelector,
528  FitnessFunction fitnessFunction = FitnessFunction());
529 };
530 
531 
532 } // namespace tree
533 } // namespace mlpack
534 
535 // Include implementation.
536 #include "decision_tree_regressor_impl.hpp"
537 
538 #endif
void serialize(Archive &ar, const uint32_t)
Serialize the tree.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
double Predict(const VecType &point) const
Make prediction for the given point, using the entire tree.
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:70
Linear algebra utility functions, generally performed on matrices or vectors.
const DecisionTreeRegressor & Child(const size_t i) const
Get the child of the given index.
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
DecisionTreeRegressor & operator=(const DecisionTreeRegressor &other)
Copy another tree.
The core includes that mlpack expects; standard C++ includes and Armadillo.
~DecisionTreeRegressor()
Clean up memory.
double Train(MatType data, const data::DatasetInfo &datasetInfo, ResponsesType responses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType(), FitnessFunction fitnessFunction=FitnessFunction())
Train the decision tree on the given data.
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point would go towards.
size_t NumLeaves() const
Get the number of leaves in the tree.
DecisionTreeRegressor & Child(const size_t i)
Modify the child of the given index (be careful!).
This class implements a generic decision tree learner.
DecisionTreeRegressor()
Construct a decision tree without training it.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
size_t NumChildren() const
Get the number of children.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.