random_forest.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
13 #define MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
14 
17 #include "bootstrap.hpp"
18 
19 namespace mlpack {
20 namespace tree {
21 
39 template<typename FitnessFunction = GiniGain,
40  typename DimensionSelectionType = MultipleRandomDimensionSelect,
41  template<typename> class NumericSplitType = BestBinaryNumericSplit,
42  template<typename> class CategoricalSplitType = AllCategoricalSplit,
43  bool UseBootstrap = true>
45 {
46  public:
48  typedef DecisionTree<FitnessFunction, NumericSplitType, CategoricalSplitType,
49  DimensionSelectionType> DecisionTreeType;
50 
55  RandomForest();
56 
73  template<typename MatType>
74  RandomForest(const MatType& dataset,
75  const arma::Row<size_t>& labels,
76  const size_t numClasses,
77  const size_t numTrees = 20,
78  const size_t minimumLeafSize = 1,
79  const double minimumGainSplit = 1e-7,
80  const size_t maximumDepth = 0,
81  DimensionSelectionType dimensionSelector =
82  DimensionSelectionType());
83 
102  template<typename MatType>
103  RandomForest(const MatType& dataset,
104  const data::DatasetInfo& datasetInfo,
105  const arma::Row<size_t>& labels,
106  const size_t numClasses,
107  const size_t numTrees = 20,
108  const size_t minimumLeafSize = 1,
109  const double minimumGainSplit = 1e-7,
110  const size_t maximumDepth = 0,
111  DimensionSelectionType dimensionSelector =
112  DimensionSelectionType());
113 
129  template<typename MatType>
130  RandomForest(const MatType& dataset,
131  const arma::Row<size_t>& labels,
132  const size_t numClasses,
133  const arma::rowvec& weights,
134  const size_t numTrees = 20,
135  const size_t minimumLeafSize = 1,
136  const double minimumGainSplit = 1e-7,
137  const size_t maximumDepth = 0,
138  DimensionSelectionType dimensionSelector =
139  DimensionSelectionType());
140 
160  template<typename MatType>
161  RandomForest(const MatType& dataset,
162  const data::DatasetInfo& datasetInfo,
163  const arma::Row<size_t>& labels,
164  const size_t numClasses,
165  const arma::rowvec& weights,
166  const size_t numTrees = 20,
167  const size_t minimumLeafSize = 1,
168  const double minimumGainSplit = 1e-7,
169  const size_t maximumDepth = 0,
170  DimensionSelectionType dimensionSelector =
171  DimensionSelectionType());
172 
192  template<typename MatType>
193  double Train(const MatType& data,
194  const arma::Row<size_t>& labels,
195  const size_t numClasses,
196  const size_t numTrees = 20,
197  const size_t minimumLeafSize = 1,
198  const double minimumGainSplit = 1e-7,
199  const size_t maximumDepth = 0,
200  const bool warmStart = false,
201  DimensionSelectionType dimensionSelector =
202  DimensionSelectionType());
203 
226  template<typename MatType>
227  double Train(const MatType& data,
228  const data::DatasetInfo& datasetInfo,
229  const arma::Row<size_t>& labels,
230  const size_t numClasses,
231  const size_t numTrees = 20,
232  const size_t minimumLeafSize = 1,
233  const double minimumGainSplit = 1e-7,
234  const size_t maximumDepth = 0,
235  const bool warmStart = false,
236  DimensionSelectionType dimensionSelector =
237  DimensionSelectionType());
238 
259  template<typename MatType>
260  double Train(const MatType& data,
261  const arma::Row<size_t>& labels,
262  const size_t numClasses,
263  const arma::rowvec& weights,
264  const size_t numTrees = 20,
265  const size_t minimumLeafSize = 1,
266  const double minimumGainSplit = 1e-7,
267  const size_t maximumDepth = 0,
268  const bool warmStart = false,
269  DimensionSelectionType dimensionSelector =
270  DimensionSelectionType());
271 
294  template<typename MatType>
295  double Train(const MatType& data,
296  const data::DatasetInfo& datasetInfo,
297  const arma::Row<size_t>& labels,
298  const size_t numClasses,
299  const arma::rowvec& weights,
300  const size_t numTrees = 20,
301  const size_t minimumLeafSize = 1,
302  const double minimumGainSplit = 1e-7,
303  const size_t maximumDepth = 0,
304  const bool warmStart = false,
305  DimensionSelectionType dimensionSelector =
306  DimensionSelectionType());
307 
314  template<typename VecType>
315  size_t Classify(const VecType& point) const;
316 
326  template<typename VecType>
327  void Classify(const VecType& point,
328  size_t& prediction,
329  arma::vec& probabilities) const;
330 
338  template<typename MatType>
339  void Classify(const MatType& data,
340  arma::Row<size_t>& predictions) const;
341 
351  template<typename MatType>
352  void Classify(const MatType& data,
353  arma::Row<size_t>& predictions,
354  arma::mat& probabilities) const;
355 
357  const DecisionTreeType& Tree(const size_t i) const { return trees[i]; }
359  DecisionTreeType& Tree(const size_t i) { return trees[i]; }
360 
362  size_t NumTrees() const { return trees.size(); }
363 
367  template<typename Archive>
368  void serialize(Archive& ar, const uint32_t /* version */);
369 
370  private:
393  template<bool UseWeights, bool UseDatasetInfo, typename MatType>
394  double Train(const MatType& data,
395  const data::DatasetInfo& datasetInfo,
396  const arma::Row<size_t>& labels,
397  const size_t numClasses,
398  const arma::rowvec& weights,
399  const size_t numTrees,
400  const size_t minimumLeafSize,
401  const double minimumGainSplit,
402  const size_t maximumDepth,
403  DimensionSelectionType& dimensionSelector,
404  const bool warmStart = false);
405 
407  std::vector<DecisionTreeType> trees;
408 
410  double avgGain;
411 };
412 
// ExtraTrees: alias for a RandomForest that uses RandomBinaryNumericSplit as
// its numeric split type and sets UseBootstrap to false.  The fitness
// function, dimension selection policy and categorical split type stay
// configurable and have the same defaults as RandomForest.
436 template<typename FitnessFunction = GiniGain,
437  typename DimensionSelectionType = MultipleRandomDimensionSelect,
438  template<typename> class CategoricalSplitType = AllCategoricalSplit>
439 using ExtraTrees = RandomForest<FitnessFunction,
440  DimensionSelectionType,
441  RandomBinaryNumericSplit,
442  CategoricalSplitType,
443  false>;
444 
445 } // namespace tree
446 } // namespace mlpack
447 
448 // Include implementation.
449 #include "random_forest_impl.hpp"
450 
451 #endif
The RandomForest class provides an implementation of random forests, described in Breiman's seminal p...
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType > DecisionTreeType
Allow access to the underlying decision tree type.
The RandomBinaryNumericSplit is a splitting function for decision trees that will split based on a ra...
size_t NumTrees() const
Get the number of trees in the forest.
const DecisionTreeType & Tree(const size_t i) const
Access a tree in the forest.
Linear algebra utility functions, generally performed on matrices or vectors.
This class implements a generic decision tree learner.
void serialize(Archive &ar, const uint32_t)
Serialize the random forest.
This dimension selection policy allows the selection from a few random dimensions.
The AllCategoricalSplit is a splitting function that will split categorical features into many childr...
RandomForest()
Construct the random forest without any training or specifying the number of trees.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision tr...
Definition: gini_gain.hpp:27
DecisionTreeType & Tree(const size_t i)
Modify a tree in the forest (be careful!).
size_t Classify(const VecType &point) const
Predict the class of the given point.
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, const bool warmStart=false, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the random forest on the given labeled training data with the given number of trees...