svd_incomplete_incremental_learning.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_AMF_SVD_INCOMPLETE_INCREMENTAL_LEARNING_HPP
13 #define MLPACK_METHODS_AMF_SVD_INCOMPLETE_INCREMENTAL_LEARNING_HPP
14 
15 namespace mlpack
16 {
17 namespace amf
18 {
19 
44 {
45  public:
54  double kw = 0,
55  double kh = 0)
56  : u(u), kw(kw), kh(kh), currentUserIndex(0)
57  {
58  // Nothing to do.
59  }
60 
69  template<typename MatType>
70  void Initialize(const MatType& /* dataset */, const size_t /* rank */)
71  {
72  // Set the current user to 0.
73  currentUserIndex = 0;
74  }
75 
85  template<typename MatType>
86  inline void WUpdate(const MatType& V,
87  arma::mat& W,
88  const arma::mat& H)
89  {
90  arma::mat deltaW;
91  deltaW.zeros(V.n_rows, W.n_cols);
92 
93  // Iterate through all the rating by this user to update corresponding item
94  // feature feature vector.
95  for (size_t i = 0; i < V.n_rows; ++i)
96  {
97  const double val = V(i, currentUserIndex);
98  // Update only if the rating is non-zero.
99  if (val != 0)
100  {
101  deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
102  H.col(currentUserIndex).t();
103  }
104  // Add regularization.
105  if (kw != 0)
106  deltaW.row(i) -= kw * W.row(i);
107  }
108 
109  W += u * deltaW;
110  }
111 
120  template<typename MatType>
121  inline void HUpdate(const MatType& V,
122  const arma::mat& W,
123  arma::mat& H)
124  {
125  arma::vec deltaH;
126  deltaH.zeros(H.n_rows);
127 
128  // Iterate through all the rating by this user to update corresponding item
129  // feature feature vector.
130  for (size_t i = 0; i < V.n_rows; ++i)
131  {
132  const double val = V(i, currentUserIndex);
133  // Update only if the rating is non-zero.
134  if (val != 0)
135  {
136  deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
137  W.row(i).t();
138  }
139  }
140  // Add regularization.
141  if (kh != 0)
142  deltaH -= kh * H.col(currentUserIndex);
143 
144  // Update H matrix and move on to the next user.
145  H.col(currentUserIndex++) += u * deltaH;
146  currentUserIndex = currentUserIndex % V.n_cols;
147  }
148 
149  private:
151  double u;
153  double kw;
155  double kh;
156 
158  size_t currentUserIndex;
159 };
160 
163 
165 template<>
166 inline void SVDIncompleteIncrementalLearning::WUpdate<arma::sp_mat>(
167  const arma::sp_mat& V, arma::mat& W, const arma::mat& H)
168 {
169  arma::mat deltaW(V.n_rows, W.n_cols);
170  deltaW.zeros();
171  for (arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
172  it != V.end_col(currentUserIndex); ++it)
173  {
174  double val = *it;
175  size_t i = it.row();
176  deltaW.row(i) += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
177  arma::trans(H.col(currentUserIndex));
178  if (kw != 0) deltaW.row(i) -= kw * W.row(i);
179  }
180 
181  W += u*deltaW;
182 }
183 
184 template<>
185 inline void SVDIncompleteIncrementalLearning::HUpdate<arma::sp_mat>(
186  const arma::sp_mat& V, const arma::mat& W, arma::mat& H)
187 {
188  arma::mat deltaH(H.n_rows, 1);
189  deltaH.zeros();
190 
191  for (arma::sp_mat::const_iterator it = V.begin_col(currentUserIndex);
192  it != V.end_col(currentUserIndex); ++it)
193  {
194  double val = *it;
195  size_t i = it.row();
196  if ((val = V(i, currentUserIndex)) != 0)
197  {
198  deltaH += (val - arma::dot(W.row(i), H.col(currentUserIndex))) *
199  arma::trans(W.row(i));
200  }
201  }
202  if (kh != 0) deltaH -= kh * H.col(currentUserIndex);
203 
204  H.col(currentUserIndex++) += u * deltaH;
205  currentUserIndex = currentUserIndex % V.n_cols;
206 }
207 
208 } // namespace amf
209 } // namespace mlpack
210 
211 #endif
This class computes SVD using incomplete incremental batch learning, as described in the following pa...
Linear algebra utility functions, generally performed on matrices or vectors.
SVDIncompleteIncrementalLearning(double u=0.001, double kw=0, double kh=0)
Initialize the parameters of SVDIncompleteIncrementalLearning.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
void Initialize(const MatType &, const size_t)
Initialize parameters before factorization.