13 #ifndef MLPACK_METHODS_DECISION_TREE_MSE_GAIN_HPP 14 #define MLPACK_METHODS_DECISION_TREE_MSE_GAIN_HPP 43 template<
bool UseWeights,
typename VecType,
typename WeightVecType>
45 const WeightVecType& weights,
53 double accWeights = 0.0;
54 double weightedMean = 0.0;
55 WeightedSum(values, weights, begin, end, accWeights, weightedMean);
58 if (accWeights == 0.0)
61 weightedMean /= accWeights;
63 for (
size_t i = begin; i < end; ++i)
64 mse += weights[i] * std::pow(values[i] - weightedMean, 2);
71 Sum(values, begin, end, mean);
72 mean /= (double) (end - begin);
74 mse = arma::accu(arma::square(values.subvec(begin, end - 1) - mean));
75 mse /= (double) (end - begin);
87 template<
bool UseWeights,
typename VecType,
typename WeightVecType>
89 const WeightVecType& weights)
92 if (values.n_elem == 0)
95 return Evaluate<UseWeights>(values, weights, 0, values.n_elem);
103 template<
bool UseWeights,
typename ResponsesType,
typename WeightsType>
105 const WeightsType& weights)
109 double accWeights, weightedSum;
110 WeightedSum(responses, weights, 0, responses.n_elem, accWeights,
112 return weightedSum / accWeights;
117 Sum(responses, 0, responses.n_elem, sum);
118 return sum / responses.n_elem;
135 double mseLeft = leftSumSquares / leftSize - leftMean * leftMean;
136 double mseRight = (totalSumSquares - leftSumSquares) / rightSize
137 - rightMean * rightMean;
139 return std::make_tuple(-mseLeft, -mseRight);
150 template<
bool UseWeights,
typename ResponsesType,
typename WeightVecType>
152 const WeightVecType& weights,
153 const size_t minimum)
155 typedef typename ResponsesType::elem_type RType;
156 typedef typename WeightVecType::elem_type WType;
163 leftSumSquares = 0.0;
164 totalSumSquares = 0.0;
168 totalSumSquares = arma::accu(weights % arma::square(responses));
169 for (
size_t i = 0; i < minimum - 1; ++i)
171 const WType w = weights[i];
172 const RType x = responses[i];
177 leftSumSquares += w * x * x;
180 leftMean /= leftSize;
182 for (
size_t i = minimum - 1; i < responses.n_elem; ++i)
184 const WType w = weights[i];
185 const RType x = responses[i];
191 if (rightSize > 1e-9)
192 rightMean /= rightSize;
196 totalSumSquares = arma::accu(arma::square(responses));
197 for (
size_t i = 0; i < minimum - 1; ++i)
199 const RType x = responses[i];
204 leftSumSquares += x * x;
207 leftMean /= leftSize;
209 for (
size_t i = minimum - 1; i < responses.n_elem; ++i)
211 const RType x = responses[i];
217 if (rightSize > 1e-9)
218 rightMean /= rightSize;
229 template<
bool UseWeights,
typename ResponsesType,
typename WeightVecType>
231 const WeightVecType& weights,
234 typedef typename ResponsesType::elem_type RType;
235 typedef typename WeightVecType::elem_type WType;
239 const WType w = weights[index];
240 const RType x = responses[index];
243 leftSumSquares += w * x * x;
246 leftMean = (leftMean * leftSize + w * x) / (leftSize + w);
249 rightMean = (rightMean * rightSize - w * x) / (rightSize - w);
254 const RType x = responses[index];
257 leftSumSquares += x * x;
260 leftMean = (leftMean * leftSize + x) / (leftSize + 1);
263 rightMean = (rightMean * rightSize - x) / (rightSize - 1);
274 double leftSumSquares;
284 double totalSumSquares;
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects: standard C++ includes and Armadillo.
static double Evaluate(const VecType &values, const WeightVecType &weights)
Evaluate the MSE gain on the complete vector.
void BinaryStep(const ResponsesType &responses, const WeightVecType &weights, const size_t index)
Moves the point at the given index from the right child to the left child, updating the cached sizes, means, and sum of squares.
static double Evaluate(const VecType &values, const WeightVecType &weights, const size_t begin, const size_t end)
Evaluate the mean squared error gain of the values in the index range [begin, end).
void WeightedSum(const VecType &values, const WeightVecType &weights, const size_t begin, const size_t end, double &accWeights, double &weightedMean)
Calculates the weighted sum and total weight of labels.
The MSE (mean squared error) gain is a measure of set purity based on the variance of response values...
void BinaryScanInitialize(const ResponsesType &responses, const WeightVecType &weights, const size_t minimum)
Caches the prefix sum of squares (along with the left/right sizes and means) so that the gain for each candidate split can be computed efficiently.
double OutputLeafValue(const ResponsesType &responses, const WeightsType &weights)
Returns the prediction output of a leaf node: the (weighted, if weights are used) mean of its responses.
void Sum(const VecType &values, const size_t begin, const size_t end, double &mean)
Sums the values in the index range [begin, end) into the output argument.
std::tuple< double, double > BinaryGains()
Calculates the mean squared error gain for the left and right children for the current index...