13 #ifndef MLPACK_METHODS_DECISION_TREE_GINI_GAIN_HPP 14 #define MLPACK_METHODS_DECISION_TREE_GINI_GAIN_HPP 33 template<
bool UseWeights,
typename CountType>
35 const size_t countLength,
36 const CountType totalCount)
41 CountType impurity = 0.0;
42 for (
size_t i = 0; i < countLength; ++i)
43 impurity += counts[i] * (totalCount - counts[i]);
45 return -((double) impurity / ((
double) std::pow(totalCount, 2)));
61 template<
bool UseWeights,
typename RowType,
typename WeightVecType>
63 const size_t numClasses,
64 const WeightVecType& weights)
67 if (labels.n_elem == 0)
72 arma::vec countSpace(4 * numClasses, arma::fill::zeros);
73 arma::vec counts(countSpace.memptr(), numClasses,
false,
true);
74 arma::vec counts2(countSpace.memptr() + numClasses, numClasses,
false,
76 arma::vec counts3(countSpace.memptr() + 2 * numClasses, numClasses,
false,
78 arma::vec counts4(countSpace.memptr() + 3 * numClasses, numClasses,
false,
82 double impurity = 0.0;
87 double accWeights[4] = { 0.0, 0.0, 0.0, 0.0 };
91 for (
size_t i = 3; i < labels.n_elem; i += 4)
93 const double weight1 = weights[i - 3];
94 const double weight2 = weights[i - 2];
95 const double weight3 = weights[i - 1];
96 const double weight4 = weights[i];
98 counts[labels[i - 3]] += weight1;
99 counts2[labels[i - 2]] += weight2;
100 counts3[labels[i - 1]] += weight3;
101 counts4[labels[i]] += weight4;
103 accWeights[0] += weight1;
104 accWeights[1] += weight2;
105 accWeights[2] += weight3;
106 accWeights[3] += weight4;
110 if (labels.n_elem % 4 == 1)
112 const double weight1 = weights[labels.n_elem - 1];
113 counts[labels[labels.n_elem - 1]] += weight1;
114 accWeights[0] += weight1;
116 else if (labels.n_elem % 4 == 2)
118 const double weight1 = weights[labels.n_elem - 2];
119 const double weight2 = weights[labels.n_elem - 1];
121 counts[labels[labels.n_elem - 2]] += weight1;
122 counts2[labels[labels.n_elem - 1]] += weight2;
124 accWeights[0] += weight1;
125 accWeights[1] += weight2;
127 else if (labels.n_elem % 4 == 3)
129 const double weight1 = weights[labels.n_elem - 3];
130 const double weight2 = weights[labels.n_elem - 2];
131 const double weight3 = weights[labels.n_elem - 1];
133 counts[labels[labels.n_elem - 3]] += weight1;
134 counts2[labels[labels.n_elem - 2]] += weight2;
135 counts3[labels[labels.n_elem - 1]] += weight3;
137 accWeights[0] += weight1;
138 accWeights[1] += weight2;
139 accWeights[2] += weight3;
142 accWeights[0] += accWeights[1] + accWeights[2] + accWeights[3];
143 counts += counts2 + counts3 + counts4;
146 if (accWeights[0] == 0.0)
149 for (
size_t i = 0; i < numClasses; ++i)
151 const double f = ((double) counts[i] / (
double) accWeights[0]);
152 impurity += f * (1.0 - f);
159 for (
size_t i = 3; i < labels.n_elem; i += 4)
161 counts[labels[i - 3]]++;
162 counts2[labels[i - 2]]++;
163 counts3[labels[i - 1]]++;
164 counts4[labels[i]]++;
168 if (labels.n_elem % 4 == 1)
170 counts[labels[labels.n_elem - 1]]++;
172 else if (labels.n_elem % 4 == 2)
174 counts[labels[labels.n_elem - 2]]++;
175 counts2[labels[labels.n_elem - 1]]++;
177 else if (labels.n_elem % 4 == 3)
179 counts[labels[labels.n_elem - 3]]++;
180 counts2[labels[labels.n_elem - 2]]++;
181 counts3[labels[labels.n_elem - 1]]++;
184 counts += counts2 + counts3 + counts4;
186 for (
size_t i = 0; i < numClasses; ++i)
188 const double f = ((double) counts[i] / (
double) labels.n_elem);
189 impurity += f * (1.0 - f);
/**
 * Width of the interval of values the Gini impurity can take for a given
 * number of classes.  The minimum is 0 (every point in one class).  The
 * maximum, reached when the classes are evenly represented, is
 * n * (1/n) * (1 - 1/n) = 1 - 1/n.  So the width is 1 - 1/n.
 *
 * @param numClasses Number of classes in the dataset.
 * @return 1 - 1 / numClasses.
 */
static double Range(const size_t numClasses)
{
  const double inverseClassCount = 1.0 / static_cast<double>(numClasses);
  return 1.0 - inverseClassCount;
}
Linear algebra utility functions, generally performed on matrices or vectors.
static double Evaluate(const RowType &labels, const size_t numClasses, const WeightVecType &weights)
Evaluate the Gini impurity on the given set of labels.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
static double EvaluatePtr(const CountType *counts, const size_t countLength, const CountType totalCount)
Evaluate the Gini impurity given a vector of class weight counts.
static double Range(const size_t numClasses)
Return the range of the Gini impurity for the given number of classes.