12 #ifndef MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP 13 #define MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP 55 double momentum = 0.9)
56 : u(u), kw(kw), kh(kh), momentum(momentum)
68 template<
typename MatType>
69 void Initialize(
const MatType& dataset,
const size_t rank)
71 const size_t n = dataset.n_rows;
72 const size_t m = dataset.n_cols;
87 template<
typename MatType>
103 for (
size_t i = 0; i < n; ++i)
105 for (
size_t j = 0; j < m; ++j)
107 const double val = V(i, j);
109 deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) *
110 arma::trans(H.col(j));
114 deltaW.row(i) -= kw * W.row(i);
132 template<
typename MatType>
148 for (
size_t j = 0; j < m; ++j)
150 for (
size_t i = 0; i < n; ++i)
152 const double val = V(i, j);
154 deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) * W.row(i).t();
158 deltaH.col(j) -= kh * H.col(j);
168 template<
typename Archive>
174 ar(CEREAL_NVP(momentum));
202 inline void SVDBatchLearning::WUpdate<arma::sp_mat>(
const arma::sp_mat& V,
206 const size_t n = V.n_rows;
207 const size_t r = W.n_cols;
214 for (arma::sp_mat::const_iterator it = V.begin(); it != V.end(); ++it)
216 const size_t row = it.row();
217 const size_t col = it.col();
218 deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
219 arma::trans(H.col(col));
230 inline void SVDBatchLearning::HUpdate<arma::sp_mat>(
const arma::sp_mat& V,
234 const size_t m = V.n_cols;
235 const size_t r = W.n_cols;
242 for (arma::sp_mat::const_iterator it = V.begin(); it != V.end(); ++it)
244 const size_t row = it.row();
245 const size_t col = it.col();
246 deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
260 #endif // MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP Linear algebra utility functions, generally performed on matrices or vectors.
void serialize(Archive &ar, const uint32_t)
Serialize the SVDBatchLearning object.
The core includes that mlpack expects: standard C++ headers and Armadillo.
This class implements SVD batch learning with momentum.
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
void Initialize(const MatType &dataset, const size_t rank)
Initialize parameters before factorization.
SVDBatchLearning(double u=0.0002, double kw=0, double kh=0, double momentum=0.9)
SVDBatchLearning constructor.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.