gini_impurity.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_HOEFFDING_TREES_GINI_INDEX_HPP
14 #define MLPACK_METHODS_HOEFFDING_TREES_GINI_INDEX_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace tree {
20 
22 {
23  public:
24  static double Evaluate(const arma::Mat<size_t>& counts)
25  {
26  // We need to sum over the difference between the un-split node and the
27  // split nodes. First we'll calculate the number of elements in each split
28  // and total.
29  size_t numElem = 0;
30  arma::vec splitCounts(counts.n_cols);
31  for (size_t i = 0; i < counts.n_cols; ++i)
32  {
33  splitCounts[i] = arma::accu(counts.col(i));
34  numElem += splitCounts[i];
35  }
36 
37  // Corner case: if there are no elements, the impurity is zero.
38  if (numElem == 0)
39  return 0.0;
40 
41  arma::Col<size_t> classCounts = arma::sum(counts, 1);
42 
43  // Calculate the Gini impurity of the un-split node.
44  double impurity = 0.0;
45  for (size_t i = 0; i < classCounts.n_elem; ++i)
46  {
47  const double f = ((double) classCounts[i] / (double) numElem);
48  impurity += f * (1.0 - f);
49  }
50 
51  // Now calculate the impurity of the split nodes and subtract them from the
52  // overall impurity.
53  for (size_t i = 0; i < counts.n_cols; ++i)
54  {
55  if (splitCounts[i] > 0)
56  {
57  double splitImpurity = 0.0;
58  for (size_t j = 0; j < counts.n_rows; ++j)
59  {
60  const double f = ((double) counts(j, i) / (double) splitCounts[i]);
61  splitImpurity += f * (1.0 - f);
62  }
63 
64  impurity -= ((double) splitCounts[i] / (double) numElem) *
65  splitImpurity;
66  }
67  }
68 
69  return impurity;
70  }
71 
/**
 * Return the range of the Gini impurity for the given number of classes,
 * i.e. the difference between its largest and smallest attainable values.
 *
 * The smallest value is 0, reached when every point belongs to one class.
 * The largest is reached when points are spread evenly over all n classes:
 * each class then contributes (1/n) * (1 - 1/n), for a total of 1 - 1/n.
 *
 * @param numClasses Number of classes.
 * @return 1 - 1 / numClasses.
 *
 * NOTE(review): numClasses == 0 divides by zero (IEEE gives -infinity here);
 * callers presumably always pass at least one class -- confirm.
 */
static double Range(const size_t numClasses)
{
  // Share of the points each class receives in the evenly-distributed case.
  const double evenShare = 1.0 / static_cast<double>(numClasses);
  return 1.0 - evenShare;
}
84 };
85 
86 } // namespace tree
87 } // namespace mlpack
88 
89 #endif
static double Range(const size_t numClasses)
Return the range of the Gini impurity for the given number of classes.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
static double Evaluate(const arma::Mat< size_t > &counts)