validation_rmse_termination.hpp
Go to the documentation of this file.
1 
12 #ifndef _MLPACK_METHODS_AMF_VALIDATIONRMSETERMINATION_HPP_INCLUDED
13 #define _MLPACK_METHODS_AMF_VALIDATIONRMSETERMINATION_HPP_INCLUDED
14 
#include <mlpack/prereqs.hpp>
#include <stdexcept>
16 
17 namespace mlpack
18 {
19 namespace amf
20 {
21 
36 template <class MatType>
38 {
39  public:
51  size_t num_test_points,
52  double tolerance = 1e-5,
53  size_t maxIterations = 10000,
54  size_t reverseStepTolerance = 3)
55  : tolerance(tolerance),
56  maxIterations(maxIterations),
57  num_test_points(num_test_points),
58  reverseStepTolerance(reverseStepTolerance)
59  {
60  size_t n = V.n_rows;
61  size_t m = V.n_cols;
62 
63  // initialize validation set matrix
64  test_points.zeros(num_test_points, 3);
65 
66  // fill validation set matrix with random chosen entries
67  for (size_t i = 0; i < num_test_points; ++i)
68  {
69  double t_val;
70  size_t t_row;
71  size_t t_col;
72 
73  // pick a random non-zero entry
74  do
75  {
76  t_row = rand() % n;
77  t_col = rand() % m;
78  } while ((t_val = V(t_row, t_col)) == 0);
79 
80  // add the entry to the validation set
81  test_points(i, 0) = t_row;
82  test_points(i, 1) = t_col;
83  test_points(i, 2) = t_val;
84 
85  // nullify the added entry from data matrix (training set)
86  V(t_row, t_col) = 0;
87  }
88  }
89 
95  void Initialize(const MatType& /* V */)
96  {
97  iteration = 1;
98 
99  rmse = DBL_MAX;
100  rmseOld = DBL_MAX;
101 
102  c_index = 0;
103  c_indexOld = 0;
104 
105  reverseStepCount = 0;
106  isCopy = false;
107  }
108 
115  bool IsConverged(arma::mat& W, arma::mat& H)
116  {
117  arma::mat WH;
118 
119  WH = W * H;
120 
121  // compute validation RMSE
122  if (iteration != 0)
123  {
124  rmseOld = rmse;
125  rmse = 0;
126  for (size_t i = 0; i < num_test_points; ++i)
127  {
128  size_t t_row = test_points(i, 0);
129  size_t t_col = test_points(i, 1);
130  double t_val = test_points(i, 2);
131  double temp = (t_val - WH(t_row, t_col));
132  temp *= temp;
133  rmse += temp;
134  }
135  rmse /= num_test_points;
136  rmse = sqrt(rmse);
137  }
138 
139  // increment iteration count
140  iteration++;
141 
142  // if RMSE tolerance is not satisfied
143  if ((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
144  {
145  // check if this is a first of successive drops
146  if (reverseStepCount == 0 && isCopy == false)
147  {
148  // store a copy of W and H matrix
149  isCopy = true;
150  this->W = W;
151  this->H = H;
152  // store residue values
153  c_indexOld = rmseOld;
154  c_index = rmse;
155  }
156  // increase successive drop count
157  reverseStepCount++;
158  }
159  // if tolerance is satisfied
160  else
161  {
162  // initialize successive drop count
163  reverseStepCount = 0;
164  // if residue is droped below minimum scrap stored values
165  if (rmse <= c_indexOld && isCopy == true)
166  {
167  isCopy = false;
168  }
169  }
170 
171  // check if termination criterion is met
172  if (reverseStepCount == reverseStepTolerance || iteration > maxIterations)
173  {
174  // if stored values are present replace them with current value as they
175  // represent the minimum residue point
176  if (isCopy)
177  {
178  W = this->W;
179  H = this->H;
180  rmse = c_index;
181  }
182  return true;
183  }
184  else return false;
185  }
186 
188  const double& Index() const { return rmse; }
189 
191  const size_t& Iteration() const { return iteration; }
192 
194  const size_t& NumTestPoints() const { return num_test_points; }
195 
197  const size_t& MaxIterations() const { return maxIterations; }
198  size_t& MaxIterations() { return maxIterations; }
199 
201  const double& Tolerance() const { return tolerance; }
202  double& Tolerance() { return tolerance; }
203 
204  private:
206  double tolerance;
208  size_t maxIterations;
210  size_t num_test_points;
211 
213  size_t iteration;
214 
216  arma::mat test_points;
217 
219  double rmseOld;
220  double rmse;
221 
223  size_t reverseStepTolerance;
225  size_t reverseStepCount;
226 
229  bool isCopy;
230 
232  arma::mat W;
233  arma::mat H;
234  double c_indexOld;
235  double c_index;
236 }; // class ValidationRMSETermination
237 
238 } // namespace amf
239 } // namespace mlpack
240 
241 
242 #endif // _MLPACK_METHODS_AMF_VALIDATIONRMSETERMINATION_HPP_INCLUDED
const size_t & Iteration() const
Get current iteration count.
void Initialize(const MatType &)
Initializes the termination policy before starting the factorization.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
This class implements validation termination policy based on RMSE index.
const double & Index() const
Get current value of residue.
const size_t & NumTestPoints() const
Get number of validation points.
const double & Tolerance() const
Access tolerance value.
const size_t & MaxIterations() const
Access upper limit of iteration count.
ValidationRMSETermination(MatType &V, size_t num_test_points, double tolerance=1e-5, size_t maxIterations=10000, size_t reverseStepTolerance=3)
Create a validation set according to the given parameters and nullify this set in the data matrix (training set).
bool IsConverged(arma::mat &W, arma::mat &H)
Check if the termination criterion is met.