mse_gain.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_MSE_GAIN_HPP
14 #define MLPACK_METHODS_DECISION_TREE_MSE_GAIN_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "utils.hpp"
18 
19 namespace mlpack {
20 namespace tree {
21 
/**
 * The MSE (mean squared error) gain, a measure of set purity based on the
 * variance of the response values.  Every gain is reported as the *negated*
 * MSE, so a gain closer to zero means a less spread-out set of responses.
 *
 * Besides the stateless Evaluate() overloads, the class caches running
 * statistics (sizes, means and sums of squares) so that the gain of every
 * candidate binary split can be obtained incrementally:
 * BinaryScanInitialize() sets up the cache, BinaryStep() moves one element
 * from the right child to the left child, and BinaryGains() reads the gains of
 * the current split.
 */
class MSEGain
{
 public:
  /**
   * Evaluate the mean squared error gain of values[begin], ...,
   * values[end - 1].
   *
   * @tparam UseWeights If true, each squared deviation is scaled by the
   *     matching entry of `weights` and the result is normalised by the total
   *     weight instead of the element count.
   * @param values Response values.
   * @param weights Per-element weights; only read when UseWeights is true.
   * @param begin Index of the first element to use.
   * @param end One past the index of the last element to use.
   * @return The negated (weighted) MSE of the range, or 0.0 when the range is
   *     empty or its total weight is zero.
   */
  template<bool UseWeights, typename VecType, typename WeightVecType>
  static double Evaluate(const VecType& values,
                         const WeightVecType& weights,
                         const size_t begin,
                         const size_t end)
  {
    // An empty range has no spread.  Returning early also avoids dividing by
    // zero and computing `end - 1` on an unsigned zero further down.
    if (begin >= end)
      return 0.0;

    double mse = 0.0;

    if (UseWeights)
    {
      double accWeights = 0.0;
      double weightedMean = 0.0;
      WeightedSum(values, weights, begin, end, accWeights, weightedMean);

      // Catch edge case: if there are no weights, the impurity is zero.
      if (accWeights == 0.0)
        return 0.0;

      // WeightedSum() produced the weighted sum; turn it into the mean.
      weightedMean /= accWeights;

      for (size_t i = begin; i < end; ++i)
      {
        const double diff = values[i] - weightedMean;
        mse += weights[i] * (diff * diff);
      }

      mse /= accWeights;
    }
    else
    {
      const double count = static_cast<double>(end - begin);

      // Sum() produces the plain sum; turn it into the mean.
      double mean = 0.0;
      Sum(values, begin, end, mean);
      mean /= count;

      for (size_t i = begin; i < end; ++i)
      {
        const double diff = values[i] - mean;
        mse += diff * diff;
      }

      mse /= count;
    }

    return -mse;
  }

  /**
   * Evaluate the MSE gain on the complete vector.
   *
   * @param values Response values.
   * @param weights Per-element weights; only read when UseWeights is true.
   * @return The negated (weighted) MSE of all values, or 0.0 if `values` is
   *     empty.
   */
  template<bool UseWeights, typename VecType, typename WeightVecType>
  static double Evaluate(const VecType& values,
                         const WeightVecType& weights)
  {
    // Corner case: if there are no elements, the impurity is zero.
    if (values.n_elem == 0)
      return 0.0;

    return Evaluate<UseWeights>(values, weights, 0, values.n_elem);
  }

  /**
   * Returns the output value of a leaf node: the (weighted) mean of the
   * responses that fall into it.
   *
   * @param responses Responses of the points in the leaf.
   * @param weights Per-point weights; only read when UseWeights is true.
   * @return The (weighted) mean response.  When there are no responses, or
   *     the total weight is zero, 0.0 is returned instead of the NaN that a
   *     division by zero would produce.
   */
  template<bool UseWeights, typename ResponsesType, typename WeightsType>
  double OutputLeafValue(const ResponsesType& responses,
                         const WeightsType& weights)
  {
    if (responses.n_elem == 0)
      return 0.0;

    if (UseWeights)
    {
      // Initialised like in Evaluate(), so the result never depends on
      // indeterminate values.
      double accWeights = 0.0;
      double weightedSum = 0.0;
      WeightedSum(responses, weights, 0, responses.n_elem, accWeights,
          weightedSum);

      if (accWeights == 0.0)
        return 0.0;

      return weightedSum / accWeights;
    }
    else
    {
      double sum = 0.0;
      Sum(responses, 0, responses.n_elem, sum);
      return sum / responses.n_elem;
    }
  }

  /**
   * Calculates the MSE gain of the left and right children for the current
   * split position, using only the cached statistics.
   *
   * Each child's MSE is computed as E[x^2] - (E[x])^2.  The right child's sum
   * of squares is obtained as the total minus the left child's.
   *
   * NOTE: a child whose cached size is zero yields a non-finite gain; callers
   * are expected to query only splits where both children are non-empty.
   *
   * @return Tuple (left gain, right gain), each being the negated MSE.
   */
  std::tuple<double, double> BinaryGains()
  {
    double mseLeft = leftSumSquares / leftSize - leftMean * leftMean;
    double mseRight = (totalSumSquares - leftSumSquares) / rightSize
        - rightMean * rightMean;

    return std::make_tuple(-mseLeft, -mseRight);
  }

  /**
   * Initialises the cached statistics for a scan over all binary splits.  The
   * first `minimum - 1` elements start in the left child and all remaining
   * elements start in the right child.
   *
   * @param responses Responses, in the order in which BinaryStep() will move
   *     them to the left child.
   * @param weights Per-element weights; only read when UseWeights is true.
   * @param minimum Minimum number of elements in a child.  A value of 0 is
   *     treated like 1 (empty left child), and the initial left child never
   *     holds more than `responses.n_elem` elements.
   */
  template<bool UseWeights, typename ResponsesType, typename WeightVecType>
  void BinaryScanInitialize(const ResponsesType& responses,
                            const WeightVecType& weights,
                            const size_t minimum)
  {
    typedef typename ResponsesType::elem_type RType;
    typedef typename WeightVecType::elem_type WType;

    // Initializing data members to cache statistics.
    leftMean = 0.0;
    rightMean = 0.0;
    leftSize = 0.0;
    rightSize = 0.0;
    leftSumSquares = 0.0;
    totalSumSquares = 0.0;

    // `minimum - 1` would wrap around for minimum == 0, and must not run past
    // the end of the data either.
    const size_t total = responses.n_elem;
    size_t leftCount = (minimum > 0) ? (minimum - 1) : 0;
    if (leftCount > total)
      leftCount = total;

    if (UseWeights)
    {
      for (size_t i = 0; i < leftCount; ++i)
      {
        const WType w = weights[i];
        const RType x = responses[i];

        // Calculating initial weighted mean of responses for the left child.
        leftSize += w;
        leftMean += w * x;
        leftSumSquares += w * x * x;
      }

      for (size_t i = leftCount; i < total; ++i)
      {
        const WType w = weights[i];
        const RType x = responses[i];

        // Calculating initial weighted mean of responses for the right child.
        rightSize += w;
        rightMean += w * x;
        totalSumSquares += w * x * x;
      }
    }
    else
    {
      for (size_t i = 0; i < leftCount; ++i)
      {
        const RType x = responses[i];

        // Calculating the initial mean of responses for the left child.
        ++leftSize;
        leftMean += x;
        leftSumSquares += x * x;
      }

      for (size_t i = leftCount; i < total; ++i)
      {
        const RType x = responses[i];

        // Calculating the initial mean of responses for the right child.
        ++rightSize;
        rightMean += x;
        totalSumSquares += x * x;
      }
    }

    // So far totalSumSquares only holds the right child's share; adding the
    // left child's share gives the total without a separate pass.
    totalSumSquares += leftSumSquares;

    // The means were accumulated as sums; normalise them unless the child is
    // (numerically) empty, in which case the mean stays at zero.
    if (leftSize > 1e-9)
      leftMean /= leftSize;
    if (rightSize > 1e-9)
      rightMean /= rightSize;
  }

  /**
   * Updates the cached statistics by moving the element at `index` from the
   * right child to the left child.
   *
   * @param responses Responses, as passed to BinaryScanInitialize().
   * @param weights Per-element weights; only read when UseWeights is true.
   * @param index Index of the element that changes sides.
   */
  template<bool UseWeights, typename ResponsesType, typename WeightVecType>
  void BinaryStep(const ResponsesType& responses,
                  const WeightVecType& weights,
                  const size_t index)
  {
    typedef typename ResponsesType::elem_type RType;
    typedef typename WeightVecType::elem_type WType;

    if (UseWeights)
    {
      const WType w = weights[index];
      const RType x = responses[index];

      // A zero-weight element changes none of the statistics.  Skipping it
      // also avoids the 0 / 0 that would otherwise turn a mean into NaN (and
      // keep it NaN for the rest of the scan) when the child's accumulated
      // weight is still zero.
      if (w == 0)
        return;

      // Update weighted sum of squares for left child.
      leftSumSquares += w * x * x;

      // Update weighted mean for both children.
      leftMean = (leftMean * leftSize + w * x) / (leftSize + w);
      leftSize += w;

      rightMean = (rightMean * rightSize - w * x) / (rightSize - w);
      rightSize -= w;
    }
    else
    {
      const RType x = responses[index];

      // Update sum of squares for left child.
      leftSumSquares += x * x;

      // Update mean for both children.
      leftMean = (leftMean * leftSize + x) / (leftSize + 1);
      ++leftSize;

      rightMean = (rightMean * rightSize - x) / (rightSize - 1);
      --rightSize;
    }
  }

 private:
  // Stores the sum of squares / weighted sum of squares for the left child.
  double leftSumSquares = 0.0;
  // For unweighted data, stores the number of elements in each child.
  // For weighted data, stores the sum of weights of elements in each
  // child.
  double leftSize = 0.0;
  double rightSize = 0.0;
  // Stores the mean / weighted mean.
  double leftMean = 0.0;
  double rightMean = 0.0;
  // Stores the total sum of squares / total weighted sum of squares.
  double totalSumSquares = 0.0;
};
286 
287 } // namespace tree
288 } // namespace mlpack
289 
290 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
static double Evaluate(const VecType &values, const WeightVecType &weights)
Evaluate the MSE gain on the complete vector.
Definition: mse_gain.hpp:88
void BinaryStep(const ResponsesType &responses, const WeightVecType &weights, const size_t index)
Updates the statistics for the given index.
Definition: mse_gain.hpp:230
static double Evaluate(const VecType &values, const WeightVecType &weights, const size_t begin, const size_t end)
Evaluate the mean squared error gain of values from begin to end index.
Definition: mse_gain.hpp:44
void WeightedSum(const VecType &values, const WeightVecType &weights, const size_t begin, const size_t end, double &accWeights, double &weightedMean)
Calculates the weighted sum and total weight of labels.
Definition: utils.hpp:19
The MSE (mean squared error) gain is a measure of set purity based on the variance of response values.
Definition: mse_gain.hpp:28
void BinaryScanInitialize(const ResponsesType &responses, const WeightVecType &weights, const size_t minimum)
Caches the prefix sum of squares to efficiently compute gain value for each split.
Definition: mse_gain.hpp:151
double OutputLeafValue(const ResponsesType &responses, const WeightsType &weights)
Returns the output value for each leaf node for prediction.
Definition: mse_gain.hpp:104
void Sum(const VecType &values, const size_t begin, const size_t end, double &mean)
Sums up the labels vector.
Definition: utils.hpp:96
std::tuple< double, double > BinaryGains()
Calculates the mean squared error gain for the left and right children for the current index...
Definition: mse_gain.hpp:133