all_categorical_split.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_ALL_CATEGORICAL_SPLIT_HPP
14 #define MLPACK_METHODS_DECISION_TREE_ALL_CATEGORICAL_SPLIT_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace tree {
20 
29 template<typename FitnessFunction>
31 {
32  public:
33  // No extra info needed for split.
34  class AuxiliarySplitInfo { };
35 
59  template<bool UseWeights, typename VecType, typename LabelsType,
60  typename WeightVecType>
61  static double SplitIfBetter(
62  const double bestGain,
63  const VecType& data,
64  const size_t numCategories,
65  const LabelsType& labels,
66  const size_t numClasses,
67  const WeightVecType& weights,
68  const size_t minimumLeafSize,
69  const double minimumGainSplit,
70  arma::vec& splitInfo,
71  AuxiliarySplitInfo& aux);
72 
97  template<bool UseWeights, typename VecType, typename ResponsesType,
98  typename WeightVecType>
99  static double SplitIfBetter(
100  const double bestGain,
101  const VecType& data,
102  const size_t numCategories,
103  const ResponsesType& responses,
104  const WeightVecType& weights,
105  const size_t minimumLeafSize,
106  const double minimumGainSplit,
107  double& splitInfo,
108  AuxiliarySplitInfo& aux,
109  FitnessFunction& fitnessFunction);
110 
117  static size_t NumChildren(const double& splitInfo,
118  const AuxiliarySplitInfo& /* aux */);
119 
127  template<typename ElemType>
128  static size_t CalculateDirection(
129  const ElemType& point,
130  const double& splitInfo,
131  const AuxiliarySplitInfo& /* aux */);
132 };
133 
134 } // namespace tree
135 } // namespace mlpack
136 
137 // Include implementation.
138 #include "all_categorical_split_impl.hpp"
139 
140 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
The AllCategoricalSplit is a splitting function that will split categorical features into many children.
static size_t CalculateDirection(const ElemType &point, const double &splitInfo, const AuxiliarySplitInfo &)
Calculate the direction a point should percolate to.
static size_t NumChildren(const double &splitInfo, const AuxiliarySplitInfo &)
Return the number of children in the split.
static double SplitIfBetter(const double bestGain, const VecType &data, const size_t numCategories, const LabelsType &labels, const size_t numClasses, const WeightVecType &weights, const size_t minimumLeafSize, const double minimumGainSplit, arma::vec &splitInfo, AuxiliarySplitInfo &aux)
Check if we can split a node.