svd_complete_incremental_learning.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_AMF_SVD_COMPLETE_INCREMENTAL_LEARNING_HPP
13 #define MLPACK_METHODS_AMF_SVD_COMPLETE_INCREMENTAL_LEARNING_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack
18 {
19 namespace amf
20 {
21 
44 template <class MatType>
46 {
47  public:
56  SVDCompleteIncrementalLearning(double u = 0.0001,
57  double kw = 0,
58  double kh = 0)
59  : u(u), kw(kw), kh(kh), currentUserIndex(0), currentItemIndex(0)
60  {
61  // Nothing to do.
62  }
63 
72  void Initialize(const MatType& /* dataset */, const size_t /* rank */)
73  {
74  // Initialize the current score counters.
75  currentUserIndex = 0;
76  currentItemIndex = 0;
77  }
78 
87  inline void WUpdate(const MatType& V,
88  arma::mat& W,
89  const arma::mat& H)
90  {
91  arma::mat deltaW;
92  deltaW.zeros(1, W.n_cols);
93 
94  // Loop until a non-zero entry is found.
95  while (true)
96  {
97  const double val = V(currentItemIndex, currentUserIndex);
98  // Update feature vector if current entry is non-zero and break the loop.
99  if (val != 0)
100  {
101  deltaW += (val - arma::dot(W.row(currentItemIndex),
102  H.col(currentUserIndex))) * H.col(currentUserIndex).t();
103 
104  // Add regularization.
105  if (kw != 0)
106  deltaW -= kw * W.row(currentItemIndex);
107  break;
108  }
109  }
110 
111  W.row(currentItemIndex) += u * deltaW;
112  }
113 
123  inline void HUpdate(const MatType& V,
124  const arma::mat& W,
125  arma::mat& H)
126  {
127  arma::mat deltaH;
128  deltaH.zeros(H.n_rows, 1);
129 
130  const double val = V(currentItemIndex, currentUserIndex);
131 
132  // Update H matrix based on the non-zero entry found in WUpdate function.
133  deltaH += (val - arma::dot(W.row(currentItemIndex),
134  H.col(currentUserIndex))) * W.row(currentItemIndex).t();
135  // Add regularization.
136  if (kh != 0)
137  deltaH -= kh * H.col(currentUserIndex);
138 
139  // Move on to the next entry.
140  currentUserIndex = currentUserIndex + 1;
141  if (currentUserIndex == V.n_rows)
142  {
143  currentUserIndex = 0;
144  currentItemIndex = (currentItemIndex + 1) % V.n_cols;
145  }
146 
147  H.col(currentUserIndex++) += u * deltaH;
148  }
149 
150  private:
152  double u;
154  double kw;
156  double kh;
157 
159  size_t currentUserIndex;
161  size_t currentItemIndex;
162 };
163 
166 
168 template<>
170 {
171  public:
173  double kw = 0,
174  double kh = 0)
175  : u(u), kw(kw), kh(kh), n(0), m(0), it(NULL), isStart(false)
176  {}
177 
179  {
180  delete it;
181  }
182 
183  void Initialize(const arma::sp_mat& dataset, const size_t rank)
184  {
185  (void)rank;
186  n = dataset.n_rows;
187  m = dataset.n_cols;
188 
189  it = new arma::sp_mat::const_iterator(dataset.begin());
190  isStart = true;
191  }
192 
202  inline void WUpdate(const arma::sp_mat& V,
203  arma::mat& W,
204  const arma::mat& H)
205  {
206  if (!isStart)
207  ++(*it);
208  else isStart = false;
209 
210  if (*it == V.end())
211  {
212  delete it;
213  it = new arma::sp_mat::const_iterator(V.begin());
214  }
215 
216  size_t currentUserIndex = it->col();
217  size_t currentItemIndex = it->row();
218 
219  arma::mat deltaW(1, W.n_cols);
220  deltaW.zeros();
221 
222  deltaW += (**it - arma::dot(W.row(currentItemIndex),
223  H.col(currentUserIndex))) * arma::trans(H.col(currentUserIndex));
224  if (kw != 0) deltaW -= kw * W.row(currentItemIndex);
225 
226  W.row(currentItemIndex) += u*deltaW;
227  }
228 
238  inline void HUpdate(const arma::sp_mat& /* V */,
239  const arma::mat& W,
240  arma::mat& H)
241  {
242  arma::mat deltaH(H.n_rows, 1);
243  deltaH.zeros();
244 
245  size_t currentUserIndex = it->col();
246  size_t currentItemIndex = it->row();
247 
248  deltaH += (**it - arma::dot(W.row(currentItemIndex),
249  H.col(currentUserIndex))) * arma::trans(W.row(currentItemIndex));
250  if (kh != 0) deltaH -= kh * H.col(currentUserIndex);
251 
252  H.col(currentUserIndex) += u * deltaH;
253  }
254 
255  private:
256  double u;
257  double kw;
258  double kh;
259 
260  size_t n;
261  size_t m;
262 
263  arma::sp_mat dummy;
264  arma::sp_mat::const_iterator* it;
265 
266  bool isStart;
267 }; // class SVDCompleteIncrementalLearning
268 
269 } // namespace amf
270 } // namespace mlpack
271 
272 #endif
SVDCompleteIncrementalLearning(double u=0.0001, double kw=0, double kh=0)
Initialize the SVDCompleteIncrementalLearning class with the given parameters.
This class computes SVD using complete incremental batch learning, as described in the following pape...
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Initialize(const MatType &, const size_t)
Initialize parameters before factorization.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
void Initialize(const arma::sp_mat &dataset, const size_t rank)
void HUpdate(const arma::sp_mat &, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
void WUpdate(const arma::sp_mat &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.