// NOTE(review): this chunk appears to be an extraction of an mlpack header
// (decision-tree information-gain criterion, per the include-guard macro).
// Numeric tokens such as "13", "14", "31", "33" look like line-number residue.
// Several original lines (function name, braces, declarations, returns) are
// not present, so this text will not compile as-is. The comments below
// describe only what the visible lines do.
13 #ifndef MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP 14 #define MLPACK_METHODS_DECISION_TREE_INFORMATION_GAIN_HPP 31 template<
bool UseWeights,
typename CountType>
// Accumulates sum_i f_i * log2(f_i), where f_i = counts[i] / totalCount, over
// the first countLength entries of `counts`. If totalCount equals the sum of
// those counts, the result is the negated Shannon entropy in bits.
// Template parameters:
//   UseWeights - not referenced by any visible line of this block.
//   CountType  - element type of the counts and of totalCount.
// NOTE(review): the line carrying the function name and the `counts`
// parameter is missing here. `counts` is presumably a `const CountType*`;
// confirm against the upstream source.
// countLength: number of entries of `counts` to read.
// totalCount:  normalizer for every count. It is assumed to be nonzero; no
//              guard is visible, and 0 would produce inf/NaN fractions.
33 const size_t countLength,
34 const CountType totalCount)
// NOTE(review): the declaration/initialization of `gain` and the final
// `return` are not visible in this chunk.
38 for (
size_t i = 0; i < countLength; ++i)
// Both operands are cast to double so an integral CountType is not divided
// with integer (truncating) division.
40 const double f = ((double) counts[i] / (
double) totalCount);
// NOTE(review): no `f > 0` guard is visible. When counts[i] == 0, f == 0 and
// log2(0) == -inf, so 0 * -inf evaluates to NaN under IEEE 754 and poisons
// `gain`. The usual convention 0 * log2(0) := 0 needs a guard such as
// `if (f > 0.0)`; check whether that line was lost in extraction.
42 gain += f * std::log2(f);
// Computes sum_c f_c * log2(f_c). Here f_c is the fraction of points whose
// label is c: weighted by `weights` in the weighted path, a plain count in
// the unweighted path. Since the fractions sum to 1, this is the negated
// Shannon entropy of the label distribution, in bits.
// labels:     class index of each point. Each must be < numClasses, because
//             Armadillo's operator[] performs no bounds checking.
// numClasses: number of classes; sizes the count buffers.
// weights:    per-point weights, read at the same indices as `labels` in the
//             weighted path. It is assumed that weights.n_elem >= labels.n_elem;
//             no check is visible.
// NOTE(review): several lines are not visible in this chunk:
//   - braces and the `gain` declaration;
//   - the statements guarded by the two early-exit `if`s;
//   - the branch that selects the weighted vs. unweighted path (presumably
//     on UseWeights);
//   - the final `return`.
59 template<
bool UseWeights>
60 static double Evaluate(
const arma::Row<size_t>& labels,
61 const size_t numClasses,
62 const arma::Row<double>& weights)
// Empty input. The statement this guards is not visible (presumably an early
// return of 0).
65 if (labels.n_elem == 0)
// One zero-filled buffer of 4 * numClasses doubles, viewed as four count
// vectors over disjoint slices. Passing copy_aux_mem = false makes each
// vector use the buffer's memory directly instead of copying it.
73 arma::vec countSpace(4 * numClasses, arma::fill::zeros);
74 arma::vec counts(countSpace.memptr(), numClasses,
false,
true);
75 arma::vec counts2(countSpace.memptr() + numClasses, numClasses,
false,
// NOTE(review): the final constructor-argument lines for counts2, counts3 and
// counts4 are missing from this chunk.
77 arma::vec counts3(countSpace.memptr() + 2 * numClasses, numClasses,
false,
79 arma::vec counts4(countSpace.memptr() + 3 * numClasses, numClasses,
false,
// Weighted path: one running total of the weights per lane.
85 double accWeights[4] = { 0.0, 0.0, 0.0, 0.0 };
// Handles points in groups of four (indices 0 .. 4*floor(n/4) - 1). Each
// lane has its own count vector and weight accumulator. Presumably this keeps
// consecutive updates independent (e.g. runs of the same class), which can
// help the compiler pipeline or vectorize the loop.
89 for (
size_t i = 3; i < labels.n_elem; i += 4)
91 const double weight1 = weights[i - 3];
92 const double weight2 = weights[i - 2];
93 const double weight3 = weights[i - 1];
94 const double weight4 = weights[i];
96 counts[labels[i - 3]] += weight1;
97 counts2[labels[i - 2]] += weight2;
98 counts3[labels[i - 1]] += weight3;
99 counts4[labels[i]] += weight4;
101 accWeights[0] += weight1;
102 accWeights[1] += weight2;
103 accWeights[2] += weight3;
104 accWeights[3] += weight4;
// Leftover 1-3 points not covered by the four-wide loop.
108 if (labels.n_elem % 4 == 1)
110 const double weight1 = weights[labels.n_elem - 1];
111 counts[labels[labels.n_elem - 1]] += weight1;
112 accWeights[0] += weight1;
114 else if (labels.n_elem % 4 == 2)
116 const double weight1 = weights[labels.n_elem - 2];
117 const double weight2 = weights[labels.n_elem - 1];
119 counts[labels[labels.n_elem - 2]] += weight1;
120 counts2[labels[labels.n_elem - 1]] += weight2;
122 accWeights[0] += weight1;
123 accWeights[1] += weight2;
125 else if (labels.n_elem % 4 == 3)
127 const double weight1 = weights[labels.n_elem - 3];
128 const double weight2 = weights[labels.n_elem - 2];
129 const double weight3 = weights[labels.n_elem - 1];
131 counts[labels[labels.n_elem - 3]] += weight1;
132 counts2[labels[labels.n_elem - 2]] += weight2;
133 counts3[labels[labels.n_elem - 1]] += weight3;
135 accWeights[0] += weight1;
136 accWeights[1] += weight2;
137 accWeights[2] += weight3;
// Reduce the four lanes into accWeights[0] and `counts`.
140 accWeights[0] += accWeights[1] + accWeights[2] + accWeights[3];
141 counts += counts2 + counts3 + counts4;
// Zero total weight. The guarded statement is not visible; presumably it
// returns early, which also avoids dividing by zero below.
144 if (accWeights[0] == 0.0)
147 for (
size_t i = 0; i < numClasses; ++i)
149 const double f = ((double) counts[i] / (
double) accWeights[0]);
// NOTE(review): no `f > 0` guard is visible. A class with zero weight gives
// f == 0 and 0 * log2(0) == 0 * -inf == NaN. Check whether an
// `if (f > 0.0)` line was lost in extraction.
151 gain += f * std::log2(f);
// Unweighted path: the same four-lane counting, adding one per point.
158 for (
size_t i = 3; i < labels.n_elem; i += 4)
160 counts[labels[i - 3]]++;
161 counts2[labels[i - 2]]++;
162 counts3[labels[i - 1]]++;
163 counts4[labels[i]]++;
// Leftover 1-3 points not covered by the four-wide loop.
167 if (labels.n_elem % 4 == 1)
169 counts[labels[labels.n_elem - 1]]++;
171 else if (labels.n_elem % 4 == 2)
173 counts[labels[labels.n_elem - 2]]++;
174 counts2[labels[labels.n_elem - 1]]++;
176 else if (labels.n_elem % 4 == 3)
178 counts[labels[labels.n_elem - 3]]++;
179 counts2[labels[labels.n_elem - 2]]++;
180 counts3[labels[labels.n_elem - 1]]++;
// Reduce the four lanes into `counts`.
183 counts += counts2 + counts3 + counts4;
// Fractions are taken over the total point count. This relies on
// labels.n_elem > 0, presumably ensured by the empty-input check above.
185 for (
size_t i = 0; i < numClasses; ++i)
187 const double f = ((double) counts[i] / (
double) labels.n_elem);
// NOTE(review): same missing `f > 0` guard as in the weighted path. An empty
// class yields NaN.
189 gain += f * std::log2(f);
// Returns log2(numClasses). Over numClasses classes, Shannon entropy in bits
// lies in [0, log2(numClasses)]. This value is presumably the width of the
// range of the entropy-based gain, used by callers for normalization;
// confirm against the callers.
// NOTE(review): numClasses == 0 gives log2(0) == -inf. The function's braces
// and any lines around the return are not visible in this chunk.
203 static double Range(
const size_t numClasses)
208 return std::log2(numClasses);
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects: standard C++ includes and Armadillo.