best_binary_numeric_split.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_DECISION_TREE_BEST_BINARY_NUMERIC_SPLIT_HPP
13 #define MLPACK_METHODS_DECISION_TREE_BEST_BINARY_NUMERIC_SPLIT_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 #include "mse_gain.hpp"
17 
19 
20 namespace mlpack {
21 namespace tree {
22 
23 // This gives us a HasBinaryGains<T, U> type (where U is a function pointer)
24 // we can use with SFINAE to catch when a type has a BinaryGains(...) function.
// NOTE(review): HAS_MEM_FUNC is not defined in this file; presumably it is
// brought in through <mlpack/prereqs.hpp> -- confirm.
25 HAS_MEM_FUNC(BinaryGains, HasBinaryGains);
26 
27 // This struct will have `value` set to `true` if a BinaryGains() function of
28 // the right signature is detected. We only check for BinaryGains(), and not
29 // BinaryScanInitialize() or BinaryStep(), because those two are template
30 // member functions and would make this check far more difficult.
31 //
32 // The unused UseWeights template parameter is necessary to ensure that the
33 // compiler thinks the result `value` depends on a parameter specific to the
34 // SplitIfBetter() function in BestBinaryNumericSplit().
//
// The signature tested below is a pointer to a non-const member function of T
// that takes no arguments and returns std::tuple<double, double>.
35 template<typename T, bool /* UseWeights */>
// NOTE(review): the listing jumps from line 35 to line 37 here, so the line
// that declares this struct (`struct <name>`) is missing from this copy. The
// name is not visible anywhere in this listing; restore it from the upstream
// header before attempting to compile.
37 {
    // Simply forwards the result of the HasBinaryGains detector for T.
38  const static bool value = HasBinaryGains<T,
39  std::tuple<double, double>(T::*)()>::value;
40 };
41 
// Splitting strategy for decision trees, parameterized on the FitnessFunction
// used to score candidate splits. Every member is static. Only NumChildren()
// is defined inline; the other members are declared here and defined in
// best_binary_numeric_split_impl.hpp, which is included at the bottom of this
// file.
48 template<typename FitnessFunction>
// NOTE(review): the listing jumps from line 48 to line 50, so the class
// declaration line is missing from this copy. Presumably it reads
// `class BestBinaryNumericSplit` (the name used by the comment on the trait
// struct earlier in this file and by the file name) -- confirm against the
// upstream header.
50 {
51  public:
52  // No extra info needed for split.
53  class AuxiliarySplitInfo { };
54 
    // SplitIfBetter(), overload taking integer class labels (arma::Row<size_t>)
    // together with the number of classes.
    //
    // Template parameters:
    //  - UseWeights: compile-time flag; presumably selects whether `weights`
    //    is consulted -- confirm in the implementation header.
    //  - VecType: type of the `data` vector.
    //  - WeightVecType: type of the `weights` vector.
    //
    // Parameters:
    //  - bestGain: gain to compare against; judging by the function name, a
    //    split is only reported if it beats this value (confirm).
    //  - data: values of the single dimension being considered.
    //  - labels: class label for each element of `data`.
    //  - numClasses: number of distinct classes.
    //  - weights: per-element weights.
    //  - minimumLeafSize: presumably the minimum number of points allowed in
    //    each child -- confirm.
    //  - minimumGainSplit: presumably the minimum gain required for a split to
    //    be accepted -- confirm.
    //  - splitInfo: output describing the chosen split. Note that this overload
    //    takes an arma::vec&, while the overloads below take a double&.
    //  - aux: auxiliary split information (AuxiliarySplitInfo is empty).
    //
    // Returns a double gain value; the exact meaning when no better split
    // exists is not visible from this declaration.
76  template<bool UseWeights, typename VecType, typename WeightVecType>
77  static double SplitIfBetter(
78  const double bestGain,
79  const VecType& data,
80  const arma::Row<size_t>& labels,
81  const size_t numClasses,
82  const WeightVecType& weights,
83  const size_t minimumLeafSize,
84  const double minimumGainSplit,
85  arma::vec& splitInfo,
86  AuxiliarySplitInfo& aux);
87 
    // SplitIfBetter(), overload taking generic responses plus an instance of
    // the fitness function (no labels/numClasses arguments). The return type
    // is written with std::enable_if, so this overload only takes part in
    // overload resolution when its condition holds.
    //
    // NOTE(review): the listing jumps from line 112 to line 114, so the
    // enable_if condition is missing from this copy; as shown,
    // `std::enable_if<double>::type` is not the intended constraint. This
    // overload and the next one have identical parameter lists, so the missing
    // conditions are what distinguishes them -- presumably they are the two
    // opposite outcomes of the BinaryGains() detection trait declared earlier
    // in this file. Restore the line from the upstream header.
    //
    // Parameters are as for the overload above, except that `responses`
    // replaces labels/numClasses, `splitInfo` is a double&, and
    // `fitnessFunction` is passed by non-const reference.
110  template<bool UseWeights, typename VecType, typename ResponsesType,
111  typename WeightVecType>
112  static typename std::enable_if<
114  double>::type
115  SplitIfBetter(
116  const double bestGain,
117  const VecType& data,
118  const ResponsesType& responses,
119  const WeightVecType& weights,
120  const size_t minimumLeafSize,
121  const double minimumGainSplit,
122  double& splitInfo,
123  AuxiliarySplitInfo& aux,
124  FitnessFunction& fitnessFunction);
125 
    // SplitIfBetter(), second responses-based overload. Its parameter list is
    // the same as the previous overload's; the only visible difference is that
    // the `aux` parameter name is commented out here.
    //
    // NOTE(review): the listing jumps from line 149 to line 151, so this
    // overload's enable_if condition is also missing from this copy. Restore
    // it from the upstream header; without it the two overloads are
    // indistinguishable.
147  template<bool UseWeights, typename VecType, typename ResponsesType,
148  typename WeightVecType>
149  static typename std::enable_if<
151  double>::type
152  SplitIfBetter(
153  const double bestGain,
154  const VecType& data,
155  const ResponsesType& responses,
156  const WeightVecType& weights,
157  const size_t minimumLeafSize,
158  const double minimumGainSplit,
159  double& splitInfo,
160  AuxiliarySplitInfo& /* aux */,
161  FitnessFunction& fitnessFunction);
162 
    // Returns the number of children produced by a split. Both arguments are
    // ignored: the result is always 2.
166  static size_t NumChildren(const double& /* splitInfo */,
167  const AuxiliarySplitInfo& /* aux */)
168  {
169  return 2;
170  }
171 
    // Given a point's value and the split information, returns which child the
    // point is assigned to. Declared here only; the mapping from `point` and
    // `splitInfo` to a child index lives in the implementation header.
    // Presumably the result is 0 or 1, since NumChildren() is always 2 --
    // confirm.
179  template<typename ElemType>
180  static size_t CalculateDirection(
181  const ElemType& point,
182  const double& splitInfo,
183  const AuxiliarySplitInfo& /* aux */);
184 };
185 
186 } // namespace tree
187 } // namespace mlpack
188 
189 // Include implementation.
190 #include "best_binary_numeric_split_impl.hpp"
191 
192 #endif
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a...
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
HAS_MEM_FUNC(BinaryGains, HasBinaryGains)
static size_t NumChildren(const double &, const AuxiliarySplitInfo &)
Returns 2, since the binary split always has two children.