complete_incremental_termination.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_AMF_COMPLETE_INCREMENTAL_TERMINATION_HPP
13 #define MLPACK_METHODS_AMF_COMPLETE_INCREMENTAL_TERMINATION_HPP
14 
15 namespace mlpack {
16 namespace amf {
17 
28 template<class TerminationPolicy>
30 {
31  public:
38  TerminationPolicy tPolicy = TerminationPolicy()) :
39  tPolicy(tPolicy), incrementalIndex(0), iteration(0)
40  { /* Nothing to do here. */ }
41 
// Prepare the policy for a new factorization of the matrix V (generic
// matrix overload).
//
// Forwards V to the wrapped policy's Initialize(), records the number of
// non-zero entries of V as the incremental index, and resets the iteration
// counter to zero.
//
// V: the matrix about to be factorized.
47  template<class MatType>
48  void Initialize(const MatType& V)
49  {
50  tPolicy.Initialize(V);
51 
52  // Get the number of non-zero entries.
// (V != 0) marks each non-zero element; accu() sums the marks into a count.
53  incrementalIndex = arma::accu(V != 0);
54  iteration = 0;
55  }
56 
// Prepare the policy for a new factorization of the sparse matrix V.
//
// Same steps as the generic overload, except that the non-zero count is read
// directly from V.n_nonzero instead of being computed by scanning V.
//
// V: the sparse matrix about to be factorized.
63  void Initialize(const arma::sp_mat& V)
64  {
65  tPolicy.Initialize(V);
66 
67  // Get number of non-zero entries
68  incrementalIndex = V.n_nonzero;
69  iteration = 0;
70  }
71 
// Count one more iteration and, once every incrementalIndex iterations, ask
// the wrapped policy whether the factorization has converged.
//
// W, H: the current factor matrices; they are passed through unchanged to the
//       wrapped policy's IsConverged().
//
// Returns the wrapped policy's answer when the iteration count is a multiple
// of incrementalIndex, and false on every other call.
//
// NOTE(review): incrementalIndex is the non-zero count stored by
// Initialize(). If that count is 0 (a matrix with no non-zero entries), the
// modulo below has a zero divisor, which is undefined behavior in C++.
// Confirm that callers never pass such a matrix, or add a guard.
79  bool IsConverged(arma::mat& W, arma::mat& H)
80  {
81  // Increment iteration count.
82  iteration++;
83 
84  // If iteration count is multiple of incremental index, return wrapped class
85  // function.
86  if (iteration % incrementalIndex == 0)
87  return tPolicy.IsConverged(W, H);
88  else
89  return false;
90  }
91 
// Get the current value of the residue, as reported by the wrapped policy's
// Index().
93  const double& Index() const { return tPolicy.Index(); }
94 
// Get the current iteration count: the number of IsConverged() calls made
// since the last Initialize().
96  const size_t& Iteration() const { return iteration; }
97 
// Access the upper limit on the iteration count; the value is owned by the
// wrapped policy.
99  const size_t& MaxIterations() const { return tPolicy.MaxIterations(); }
// Modify the upper limit on the iteration count held by the wrapped policy.
101  size_t& MaxIterations() { return tPolicy.MaxIterations(); }
102 
// Access the wrapped termination policy (read-only).
104  const TerminationPolicy& TPolicy() const { return tPolicy; }
// Modify the wrapped termination policy.
106  TerminationPolicy& TPolicy() { return tPolicy; }
107 
108  private:
// The wrapped termination policy that makes the actual convergence decision.
110  TerminationPolicy tPolicy;
111 
// Number of non-zero entries of the matrix given to Initialize(); the wrapped
// policy is consulted once every incrementalIndex iterations.
114  size_t incrementalIndex;
// Number of IsConverged() calls since the last Initialize().
116  size_t iteration;
117 }; // class CompleteIncrementalTermination
118 
119 } // namespace amf
120 } // namespace mlpack
121 
122 #endif // MLPACK_METHODS_AMF_COMPLETE_INCREMENTAL_TERMINATION_HPP
bool IsConverged(arma::mat &W, arma::mat &H)
Check if termination criterion is met, if the current iteration means that each point has been visite...
Linear algebra utility functions, generally performed on matrices or vectors.
void Initialize(const arma::sp_mat &V)
Initializes the termination policy before starting the factorization.
const double & Index() const
Get current value of residue.
void Initialize(const MatType &V)
Initializes the termination policy before starting the factorization.
const size_t & Iteration() const
Get current iteration count.
This class acts as a wrapper for basic termination policies to be used by SVDCompleteIncrementalLearn...
CompleteIncrementalTermination(TerminationPolicy tPolicy=TerminationPolicy())
Empty constructor.
TerminationPolicy & TPolicy()
Modify the wrapped termination policy.
const size_t & MaxIterations() const
Access upper limit of iteration count.
const TerminationPolicy & TPolicy() const
Access the wrapped termination policy.
size_t & MaxIterations()
Modify maximum number of iterations.