mad_gain.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_DECISION_TREE_MAD_GAIN_HPP
15 #define MLPACK_METHODS_DECISION_TREE_MAD_GAIN_HPP
16 
17 #include <mlpack/prereqs.hpp>
18 #include "utils.hpp"
19 
20 namespace mlpack {
21 namespace tree {
22 
30 class MADGain
31 {
32  public:
45  template<bool UseWeights, typename VecType, typename WeightVecType>
46  static double Evaluate(const VecType& values,
47  const WeightVecType& weights,
48  const size_t begin,
49  const size_t end)
50  {
51  double mad = 0.0;
52 
53  if (UseWeights)
54  {
55  double accWeights = 0.0;
56  double weightedMean = 0.0;
57 
58  WeightedSum(values, weights, begin, end, accWeights, weightedMean);
59 
60  // Catch edge case: if there are no weights, the impurity is zero.
61  if (accWeights == 0.0)
62  return 0.0;
63 
64  weightedMean /= accWeights;
65 
66  for (size_t i = begin; i < end; ++i)
67  {
68  mad += weights[i] * (std::abs(values[i] - weightedMean));
69  }
70  mad /= accWeights;
71  }
72  else
73  {
74  double mean = 0.0;
75  Sum(values, begin, end, mean);
76  mean /= (double) (end - begin);
77 
78  mad = arma::accu(arma::abs(values.subvec(begin, end - 1) - mean));
79  mad /= (double) (end - begin);
80  }
81 
82  return -mad;
83  }
84 
91  template<bool UseWeights, typename VecType, typename WeightVecType>
92  static double Evaluate(const VecType& values,
93  const WeightVecType& weights)
94  {
95  // Corner case: if there are no elements, the impurity is zero.
96  if (values.n_elem == 0)
97  return 0.0;
98 
99  return Evaluate<UseWeights>(values, weights, 0, values.n_elem);
100  }
101 
107  template<bool UseWeights, typename ResponsesType, typename WeightsType>
108  double OutputLeafValue(const ResponsesType& responses,
109  const WeightsType& weights)
110  {
111  if (UseWeights)
112  {
113  double accWeights, weightedSum;
114  WeightedSum(responses, weights, 0, responses.n_elem, accWeights,
115  weightedSum);
116  return weightedSum / accWeights;
117  }
118  else
119  {
120  double sum;
121  Sum(responses, 0, responses.n_elem, sum);
122  return sum / responses.n_elem;
123  }
124  }
125 };
126 
127 } // namespace tree
128 } // namespace mlpack
129 
130 #endif
static double Evaluate(const VecType &values, const WeightVecType &weights, const size_t begin, const size_t end)
Evaluate the mean absolute deviation gain over the half-open index range [begin, end).
Definition: mad_gain.hpp:46
double OutputLeafValue(const ResponsesType &responses, const WeightsType &weights)
Returns the prediction value for a leaf node: the (possibly weighted) mean of its responses.
Definition: mad_gain.hpp:108
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void WeightedSum(const VecType &values, const WeightVecType &weights, const size_t begin, const size_t end, double &accWeights, double &weightedMean)
Calculates the weighted sum and total weight of labels.
Definition: utils.hpp:19
The MAD (mean absolute deviation) gain is a measure of set purity based on how far the dependent (response) values in a node deviate from their mean...
Definition: mad_gain.hpp:30
static double Evaluate(const VecType &values, const WeightVecType &weights)
Evaluate the MAD gain on the complete vector.
Definition: mad_gain.hpp:92
void Sum(const VecType &values, const size_t begin, const size_t end, double &mean)
Sums up the labels vector.
Definition: utils.hpp:96