information_gain.hpp
#ifndef MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP
#define MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace tree {
/**
 * The standard information gain criterion, used for calculating gain in
 * decision trees.
 */
class InformationGain
{
 public:
  /**
   * Evaluate the information gain given a vector of class weight counts.
   */
  template<bool UseWeights, typename CountType>
  static double EvaluatePtr(const CountType* counts,
                            const size_t countLength,
                            const CountType totalCount)
  {
    double gain = 0.0;

    for (size_t i = 0; i < countLength; ++i)
    {
      const double f = ((double) counts[i] / (double) totalCount);
      if (f > 0.0)
        gain += f * std::log2(f);
    }

    return gain;
  }
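
  // For example (hypothetical values): with raw class counts { 3, 1 } and a
  // total count of 4, EvaluatePtr<false>(counts, 2, 4) returns
  // 0.75 * log2(0.75) + 0.25 * log2(0.25), roughly -0.811; a pure node with
  // counts { 4, 0 } returns 0.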
  /**
   * Given a set of labels, calculate the information gain of those labels.
   */
  template<bool UseWeights>
  static double Evaluate(const arma::Row<size_t>& labels,
                         const size_t numClasses,
                         const arma::Row<double>& weights)
  {
    // Edge case: if there are no elements, the gain is zero.
    if (labels.n_elem == 0)
      return 0.0;

    // Calculate the information gain.
    double gain = 0.0;

    // Count the number of elements in each class.  Use four auxiliary vectors
    // to exploit SIMD instructions if possible.
    arma::vec countSpace(4 * numClasses, arma::fill::zeros);
    arma::vec counts(countSpace.memptr(), numClasses, false, true);
    arma::vec counts2(countSpace.memptr() + numClasses, numClasses, false,
        true);
    arma::vec counts3(countSpace.memptr() + 2 * numClasses, numClasses, false,
        true);
    arma::vec counts4(countSpace.memptr() + 3 * numClasses, numClasses, false,
        true);

    if (UseWeights)
    {
      // Sum all the weights up.
      double accWeights[4] = { 0.0, 0.0, 0.0, 0.0 };

      // SIMD loop: add counts for four elements simultaneously (if the
      // compiler manages to vectorize the loop).
      for (size_t i = 3; i < labels.n_elem; i += 4)
      {
        const double weight1 = weights[i - 3];
        const double weight2 = weights[i - 2];
        const double weight3 = weights[i - 1];
        const double weight4 = weights[i];

        counts[labels[i - 3]] += weight1;
        counts2[labels[i - 2]] += weight2;
        counts3[labels[i - 1]] += weight3;
        counts4[labels[i]] += weight4;

        accWeights[0] += weight1;
        accWeights[1] += weight2;
        accWeights[2] += weight3;
        accWeights[3] += weight4;
      }

      // Handle leftovers.
      if (labels.n_elem % 4 == 1)
      {
        const double weight1 = weights[labels.n_elem - 1];
        counts[labels[labels.n_elem - 1]] += weight1;
        accWeights[0] += weight1;
      }
      else if (labels.n_elem % 4 == 2)
      {
        const double weight1 = weights[labels.n_elem - 2];
        const double weight2 = weights[labels.n_elem - 1];

        counts[labels[labels.n_elem - 2]] += weight1;
        counts2[labels[labels.n_elem - 1]] += weight2;

        accWeights[0] += weight1;
        accWeights[1] += weight2;
      }
      else if (labels.n_elem % 4 == 3)
      {
        const double weight1 = weights[labels.n_elem - 3];
        const double weight2 = weights[labels.n_elem - 2];
        const double weight3 = weights[labels.n_elem - 1];

        counts[labels[labels.n_elem - 3]] += weight1;
        counts2[labels[labels.n_elem - 2]] += weight2;
        counts3[labels[labels.n_elem - 1]] += weight3;

        accWeights[0] += weight1;
        accWeights[1] += weight2;
        accWeights[2] += weight3;
      }

      accWeights[0] += accWeights[1] + accWeights[2] + accWeights[3];
      counts += counts2 + counts3 + counts4;

      // Corner case: return 0 if no weight.
      if (accWeights[0] == 0.0)
        return 0.0;

      for (size_t i = 0; i < numClasses; ++i)
      {
        const double f = ((double) counts[i] / (double) accWeights[0]);
        if (f > 0.0)
          gain += f * std::log2(f);
      }
    }
    else
    {
      // SIMD loop: add counts for four elements simultaneously (if the
      // compiler manages to vectorize the loop).
      for (size_t i = 3; i < labels.n_elem; i += 4)
      {
        counts[labels[i - 3]]++;
        counts2[labels[i - 2]]++;
        counts3[labels[i - 1]]++;
        counts4[labels[i]]++;
      }

      // Handle leftovers.
      if (labels.n_elem % 4 == 1)
      {
        counts[labels[labels.n_elem - 1]]++;
      }
      else if (labels.n_elem % 4 == 2)
      {
        counts[labels[labels.n_elem - 2]]++;
        counts2[labels[labels.n_elem - 1]]++;
      }
      else if (labels.n_elem % 4 == 3)
      {
        counts[labels[labels.n_elem - 3]]++;
        counts2[labels[labels.n_elem - 2]]++;
        counts3[labels[labels.n_elem - 1]]++;
      }

      counts += counts2 + counts3 + counts4;

      for (size_t i = 0; i < numClasses; ++i)
      {
        const double f = ((double) counts[i] / (double) labels.n_elem);
        if (f > 0.0)
          gain += f * std::log2(f);
      }
    }

    return gain;
  }
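
  // For example (hypothetical data): labels = { 0, 0, 1, 1 } with
  // numClasses = 2 and UseWeights = false gives 2 * (0.5 * log2(0.5)) = -1,
  // the worst possible two-class value; a set containing only one class
  // gives 0.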
  /**
   * Return the range of the information gain for the given number of classes.
   */
  static double Range(const size_t numClasses)
  {
    // The best possible case gives an information gain of 0.  The worst
    // possible case is even distribution, which gives n * (1/n * log2(1/n)) =
    // log2(1/n) = -log2(n).  So, the range is log2(n).
    return std::log2(numClasses);
  }
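
  // For example, with numClasses = 4, Range() returns log2(4) = 2; the gain
  // itself then lies between -2 (an even four-way split) and 0 (a pure node).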
};

} // namespace tree
} // namespace mlpack

#endif
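
A minimal usage sketch, assuming an mlpack build that provides this header under
mlpack/methods/decision_tree/ (as shown above); the label data, weights, and
output are illustrative only:

#include <mlpack/core.hpp>
#include <mlpack/methods/decision_tree/information_gain.hpp>

#include <iostream>

int main()
{
  // Four points: two in class 0, one in class 1, one in class 2.
  arma::Row<size_t> labels = { 0, 0, 1, 2 };

  // Per-point weights; only used when UseWeights is true.
  arma::Row<double> weights = { 1.0, 1.0, 1.0, 1.0 };

  // Unweighted gain: 0.5 * log2(0.5) + 2 * (0.25 * log2(0.25)) = -1.5.
  const double unweighted =
      mlpack::tree::InformationGain::Evaluate<false>(labels, 3, weights);

  // With uniform unit weights the weighted gain is the same value.
  const double weighted =
      mlpack::tree::InformationGain::Evaluate<true>(labels, 3, weights);

  std::cout << "Unweighted gain: " << unweighted << "\n"
            << "Weighted gain:   " << weighted << "\n"
            << "Range for 3 classes: "
            << mlpack::tree::InformationGain::Range(3) << std::endl;

  return 0;
}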