// NOTE(review): this excerpt comes from a line-numbered listing. The leading
// integers are original source line numbers, and many lines (braces,
// signatures, comments) are missing from this view.
12 #ifndef MLPACK_METHODS_AMF_SVD_COMPLETE_INCREMENTAL_LEARNING_HPP 13 #define MLPACK_METHODS_AMF_SVD_COMPLETE_INCREMENTAL_LEARNING_HPP 44 template <
class MatType>
// Generic constructor initializer list. It stores:
//   - u: the step size that scales every W/H update;
//   - kw, kh: the regularization weights applied to W and H respectively.
// Both traversal cursors start at entry (0, 0).
59 : u(u), kw(kw), kh(kh), currentUserIndex(0), currentItemIndex(0)
// WUpdate (generic MatType): one stochastic step on a single row of W.
// It uses only the entry V(currentItemIndex, currentUserIndex).
// deltaW is a zero-initialized 1 x W.n_cols row vector.
92 deltaW.zeros(1, W.n_cols);
// Observed value at the current cursor. The item index selects the row of V
// and the user index selects the column.
97 const double val = V(currentItemIndex, currentUserIndex);
// Descent direction of the squared error for this single entry:
//   (val - W_i . H_u) * H_u^T,  with i = currentItemIndex, u = currentUserIndex.
101 deltaW += (val - arma::dot(W.row(currentItemIndex),
102 H.col(currentUserIndex))) * H.col(currentUserIndex).t();
// Regularization term pulling this row of W toward zero, weighted by kw.
106 deltaW -= kw * W.row(currentItemIndex);
// Apply the step scaled by u. Only this one row of W changes, and the lines
// shown here do not move the cursors.
111 W.row(currentItemIndex) += u * deltaW;
// HUpdate (generic MatType): one stochastic step on a single column of H.
// It uses the entry V(currentItemIndex, currentUserIndex) and then advances
// the traversal cursors.
// deltaH is a zero-initialized H.n_rows x 1 column vector.
128 deltaH.zeros(H.n_rows, 1);
130 const double val = V(currentItemIndex, currentUserIndex);
// Descent direction for this entry: (val - W_i . H_u) * W_i^T.
133 deltaH += (val - arma::dot(W.row(currentItemIndex),
134 H.col(currentUserIndex))) * W.row(currentItemIndex).t();
// Regularization term pulling this column of H toward zero, weighted by kh.
137 deltaH -= kh * H.col(currentUserIndex);
// Advance the cursor: step the user index, and when it hits the bound, wrap
// it to 0 and move on to the next item.
// NOTE(review): V is read above as V(item, user), i.e. item = row and
// user = column. This wrap instead bounds the user index by V.n_rows and
// takes the item index modulo V.n_cols, so the bounds look swapped. For a
// non-square V this can yield an out-of-range column index. Confirm.
140 currentUserIndex = currentUserIndex + 1;
141 if (currentUserIndex == V.n_rows)
143 currentUserIndex = 0;
144 currentItemIndex = (currentItemIndex + 1) % V.n_cols;
// NOTE(review): two apparent problems here:
//   1. deltaH was computed for the column at the *old* user index, but the
//      index was already advanced above, so the step lands on a different
//      column of H.
//   2. The post-increment advances the cursor a second time. That skips an
//      entry and can leave currentUserIndex == V.n_rows, after which the `==`
//      wrap check above never fires again.
// The likely intent is to update the old column first, then advance the
// cursors exactly once. Verify against the lines missing from this view
// (142, 145, 146).
147 H.col(currentUserIndex++) += u * deltaH;
// Cursor used by the generic WUpdate/HUpdate. V is read as
// V(currentItemIndex, currentUserIndex), so this is the column index of the
// current entry.
159 size_t currentUserIndex;
// Cursor used by the generic WUpdate/HUpdate: the row index of the current
// entry of V.
161 size_t currentItemIndex;
// Constructor initializer list for the arma::sp_mat specialization. It
// stores u, kw and kh, zeroes n and m, and starts with no iterator; the
// iterator is allocated later with `new`.
175 : u(u), kw(kw), kh(kh), n(0), m(0), it(NULL), isStart(false)
// Initialize (sparse specialization): heap-allocates an iterator positioned
// at the first non-zero of `dataset`. `rank` is not used in the lines shown.
// NOTE(review): `it` is a raw owning pointer. If Initialize can run more than
// once, any previous iterator must be deleted first (lines not shown here may
// do so). A std::unique_ptr would make the ownership explicit.
183 void Initialize(
const arma::sp_mat& dataset,
const size_t rank)
189 it =
new arma::sp_mat::const_iterator(dataset.begin());
// WUpdate (sparse specialization): one stochastic step on a single row of W.
// It uses the non-zero entry that the iterator `it` currently points to.
// NOTE(review): the `if` branch paired with this `else` is not shown. It
// presumably advances `it` on every call except the first.
208 else isStart =
false;
// Restart the traversal from V's first non-zero. The condition guarding this
// re-allocation is not shown in this excerpt.
// NOTE(review): confirm the previous iterator is deleted before this
// assignment; otherwise every pass over V leaks one iterator.
213 it =
new arma::sp_mat::const_iterator(V.begin());
// Coordinates of the current non-zero: column = user, row = item.
216 size_t currentUserIndex = it->col();
217 size_t currentItemIndex = it->row();
// 1 x W.n_cols row vector.
// NOTE(review): depending on the Armadillo version, the (rows, cols)
// constructor may leave memory uninitialized. A zeros() call must precede the
// `+=` below; it may be on line 220, which is not shown.
219 arma::mat deltaW(1, W.n_cols);
// Descent direction for this entry; **it is the stored non-zero value.
222 deltaW += (**it - arma::dot(W.row(currentItemIndex),
223 H.col(currentUserIndex))) * arma::trans(H.col(currentUserIndex));
// Regularization toward zero, skipped entirely when kw == 0.
224 if (kw != 0) deltaW -= kw * W.row(currentItemIndex);
226 W.row(currentItemIndex) += u*deltaW;
// HUpdate (sparse specialization): one stochastic step on a single column of
// H. It uses the non-zero entry that `it` currently points to; the lines
// shown here do not advance the iterator.
// H.n_rows x 1 column vector.
// NOTE(review): depending on the Armadillo version, the (rows, cols)
// constructor may leave memory uninitialized. A zeros() call must precede the
// `+=` below and may be on a line not shown.
242 arma::mat deltaH(H.n_rows, 1);
// Coordinates of the current non-zero: column = user, row = item.
245 size_t currentUserIndex = it->col();
246 size_t currentItemIndex = it->row();
// Descent direction for this entry: (value - W_i . H_u) * W_i^T.
248 deltaH += (**it - arma::dot(W.row(currentItemIndex),
249 H.col(currentUserIndex))) * arma::trans(W.row(currentItemIndex));
// Regularization toward zero, skipped entirely when kh == 0.
250 if (kh != 0) deltaH -= kh * H.col(currentUserIndex);
252 H.col(currentUserIndex) += u * deltaH;
// Heap-allocated iterator over the non-zeros of the sparse data matrix. It is
// created with `new`, so presumably owned by this object.
// NOTE(review): this is a raw owning pointer. Confirm the destructor deletes
// it; std::unique_ptr would make the ownership explicit.
264 arma::sp_mat::const_iterator* it;
SVDCompleteIncrementalLearning(double u=0.0001, double kw=0, double kh=0)
Initialize the SVDCompleteIncrementalLearning class with the given parameters.
This class computes SVD using complete incremental batch learning, as described in the following pape...
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
SVDCompleteIncrementalLearning(double u=0.01, double kw=0, double kh=0)
void Initialize(const MatType &, const size_t)
Initialize parameters before factorization.
~SVDCompleteIncrementalLearning()
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
void Initialize(const arma::sp_mat &dataset, const size_t rank)
void HUpdate(const arma::sp_mat &, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
void WUpdate(const arma::sp_mat &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.