information_gain.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_HOEFFDING_TREES_INFORMATION_GAIN_HPP
14 #define MLPACK_METHODS_HOEFFDING_TREES_INFORMATION_GAIN_HPP
15 
16 namespace mlpack {
17 namespace tree {
18 
20 {
21  public:
31  static double Evaluate(const arma::Mat<size_t>& counts)
32  {
33  // Calculate the number of elements in the unsplit node and also in each
34  // proposed child.
35  size_t numElem = 0;
36  arma::vec splitCounts(counts.n_elem);
37  for (size_t i = 0; i < counts.n_cols; ++i)
38  {
39  splitCounts[i] = arma::accu(counts.col(i));
40  numElem += splitCounts[i];
41  }
42 
43  // Corner case: if there are no elements, the gain is zero.
44  if (numElem == 0)
45  return 0.0;
46 
47  arma::Col<size_t> classCounts = arma::sum(counts, 1);
48 
49  // Calculate the gain of the unsplit node.
50  double gain = 0.0;
51  for (size_t i = 0; i < classCounts.n_elem; ++i)
52  {
53  const double f = ((double) classCounts[i] / (double) numElem);
54  if (f > 0.0)
55  gain -= f * std::log2(f);
56  }
57 
58  // Now calculate the impurity of the split nodes and subtract them from the
59  // overall gain.
60  for (size_t i = 0; i < counts.n_cols; ++i)
61  {
62  if (splitCounts[i] > 0)
63  {
64  double splitGain = 0.0;
65  for (size_t j = 0; j < counts.n_rows; ++j)
66  {
67  const double f = ((double) counts(j, i) / (double) splitCounts[i]);
68  if (f > 0.0)
69  splitGain += f * std::log2(f);
70  }
71 
72  gain += ((double) splitCounts[i] / (double) numElem) * splitGain;
73  }
74  }
75 
76  return gain;
77  }
78 
/**
 * Return the width of the interval of values the information gain can take
 * for the given number of classes.
 *
 * @param numClasses Number of classes in the dataset.
 * @return log2(numClasses).
 */
static double Range(const size_t numClasses)
{
  // With n classes, the evenly distributed case yields
  // n * (1/n * log2(1/n)) = log2(1/n) = -log2(n), while the other extreme
  // yields 0; the distance between the two is therefore log2(n).
  const double classes = static_cast<double>(numClasses);
  return std::log2(classes);
}
91 };
92 
93 } // namespace tree
94 } // namespace mlpack
95 
96 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
static double Evaluate(const arma::Mat< size_t > &counts)
Given the sufficient statistics of a proposed split, calculate the information gain if that split were used.
static double Range(const size_t numClasses)
Return the range of the information gain for the given number of classes.