gini_gain.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_GINI_GAIN_HPP
14 #define MLPACK_METHODS_DECISION_TREE_GINI_GAIN_HPP
15 
16 #include <mlpack/core.hpp>
17 
18 namespace mlpack {
19 namespace tree {
20 
27 class GiniGain
28 {
29  public:
33  template<bool UseWeights, typename CountType>
34  static double EvaluatePtr(const CountType* counts,
35  const size_t countLength,
36  const CountType totalCount)
37  {
38  if (totalCount == 0)
39  return 0.0;
40 
41  CountType impurity = 0.0;
42  for (size_t i = 0; i < countLength; ++i)
43  impurity += counts[i] * (totalCount - counts[i]);
44 
45  return -((double) impurity / ((double) std::pow(totalCount, 2)));
46  }
47 
61  template<bool UseWeights, typename RowType, typename WeightVecType>
62  static double Evaluate(const RowType& labels,
63  const size_t numClasses,
64  const WeightVecType& weights)
65  {
66  // Corner case: if there are no elements, the impurity is zero.
67  if (labels.n_elem == 0)
68  return 0.0;
69 
70  // Count the number of elements in each class. Use four auxiliary vectors
71  // to exploit SIMD instructions if possible.
72  arma::vec countSpace(4 * numClasses, arma::fill::zeros);
73  arma::vec counts(countSpace.memptr(), numClasses, false, true);
74  arma::vec counts2(countSpace.memptr() + numClasses, numClasses, false,
75  true);
76  arma::vec counts3(countSpace.memptr() + 2 * numClasses, numClasses, false,
77  true);
78  arma::vec counts4(countSpace.memptr() + 3 * numClasses, numClasses, false,
79  true);
80 
81  // Calculate the Gini impurity of the un-split node.
82  double impurity = 0.0;
83 
84  if (UseWeights)
85  {
86  // Sum all the weights up.
87  double accWeights[4] = { 0.0, 0.0, 0.0, 0.0 };
88 
89  // SIMD loop: add counts for four elements simultaneously (if the compiler
90  // manages to vectorize the loop).
91  for (size_t i = 3; i < labels.n_elem; i += 4)
92  {
93  const double weight1 = weights[i - 3];
94  const double weight2 = weights[i - 2];
95  const double weight3 = weights[i - 1];
96  const double weight4 = weights[i];
97 
98  counts[labels[i - 3]] += weight1;
99  counts2[labels[i - 2]] += weight2;
100  counts3[labels[i - 1]] += weight3;
101  counts4[labels[i]] += weight4;
102 
103  accWeights[0] += weight1;
104  accWeights[1] += weight2;
105  accWeights[2] += weight3;
106  accWeights[3] += weight4;
107  }
108 
109  // Handle leftovers.
110  if (labels.n_elem % 4 == 1)
111  {
112  const double weight1 = weights[labels.n_elem - 1];
113  counts[labels[labels.n_elem - 1]] += weight1;
114  accWeights[0] += weight1;
115  }
116  else if (labels.n_elem % 4 == 2)
117  {
118  const double weight1 = weights[labels.n_elem - 2];
119  const double weight2 = weights[labels.n_elem - 1];
120 
121  counts[labels[labels.n_elem - 2]] += weight1;
122  counts2[labels[labels.n_elem - 1]] += weight2;
123 
124  accWeights[0] += weight1;
125  accWeights[1] += weight2;
126  }
127  else if (labels.n_elem % 4 == 3)
128  {
129  const double weight1 = weights[labels.n_elem - 3];
130  const double weight2 = weights[labels.n_elem - 2];
131  const double weight3 = weights[labels.n_elem - 1];
132 
133  counts[labels[labels.n_elem - 3]] += weight1;
134  counts2[labels[labels.n_elem - 2]] += weight2;
135  counts3[labels[labels.n_elem - 1]] += weight3;
136 
137  accWeights[0] += weight1;
138  accWeights[1] += weight2;
139  accWeights[2] += weight3;
140  }
141 
142  accWeights[0] += accWeights[1] + accWeights[2] + accWeights[3];
143  counts += counts2 + counts3 + counts4;
144 
145  // Catch edge case: if there are no weights, the impurity is zero.
146  if (accWeights[0] == 0.0)
147  return 0.0;
148 
149  for (size_t i = 0; i < numClasses; ++i)
150  {
151  const double f = ((double) counts[i] / (double) accWeights[0]);
152  impurity += f * (1.0 - f);
153  }
154  }
155  else
156  {
157  // SIMD loop: add counts for four elements simultaneously (if the compiler
158  // manages to vectorize the loop).
159  for (size_t i = 3; i < labels.n_elem; i += 4)
160  {
161  counts[labels[i - 3]]++;
162  counts2[labels[i - 2]]++;
163  counts3[labels[i - 1]]++;
164  counts4[labels[i]]++;
165  }
166 
167  // Handle leftovers.
168  if (labels.n_elem % 4 == 1)
169  {
170  counts[labels[labels.n_elem - 1]]++;
171  }
172  else if (labels.n_elem % 4 == 2)
173  {
174  counts[labels[labels.n_elem - 2]]++;
175  counts2[labels[labels.n_elem - 1]]++;
176  }
177  else if (labels.n_elem % 4 == 3)
178  {
179  counts[labels[labels.n_elem - 3]]++;
180  counts2[labels[labels.n_elem - 2]]++;
181  counts3[labels[labels.n_elem - 1]]++;
182  }
183 
184  counts += counts2 + counts3 + counts4;
185 
186  for (size_t i = 0; i < numClasses; ++i)
187  {
188  const double f = ((double) counts[i] / (double) labels.n_elem);
189  impurity += f * (1.0 - f);
190  }
191  }
192 
193  return -impurity;
194  }
195 
203  static double Range(const size_t numClasses)
204  {
205  // The best possible case is that only one class exists, which gives a Gini
206  // impurity of 0. The worst possible case is that the classes are evenly
207  // distributed, which gives n * (1/n * (1 - 1/n)) = 1 - 1/n.
208  return 1.0 - (1.0 / double(numClasses));
209  }
210 };
211 
212 } // namespace tree
213 } // namespace mlpack
214 
215 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
static double Evaluate(const RowType &labels, const size_t numClasses, const WeightVecType &weights)
Evaluate the Gini impurity on the given set of labels.
Definition: gini_gain.hpp:62
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision trees.
Definition: gini_gain.hpp:27
static double EvaluatePtr(const CountType *counts, const size_t countLength, const CountType totalCount)
Evaluate the Gini impurity given a vector of class weight counts.
Definition: gini_gain.hpp:34
static double Range(const size_t numClasses)
Return the range of the Gini impurity for the given number of classes.
Definition: gini_gain.hpp:203