svd_batch_learning.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP
13 #define MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace amf {
19 
42 {
43  public:
52  SVDBatchLearning(double u = 0.0002,
53  double kw = 0,
54  double kh = 0,
55  double momentum = 0.9)
56  : u(u), kw(kw), kh(kh), momentum(momentum)
57  {
58  // empty constructor
59  }
60 
68  template<typename MatType>
69  void Initialize(const MatType& dataset, const size_t rank)
70  {
71  const size_t n = dataset.n_rows;
72  const size_t m = dataset.n_cols;
73 
74  mW.zeros(n, rank);
75  mH.zeros(rank, m);
76  }
77 
87  template<typename MatType>
88  inline void WUpdate(const MatType& V,
89  arma::mat& W,
90  const arma::mat& H)
91  {
92  size_t n = V.n_rows;
93  size_t m = V.n_cols;
94 
95  size_t r = W.n_cols;
96 
97  // initialize the momentum of this iteration.
98  mW = momentum * mW;
99 
100  // Compute the step.
101  arma::mat deltaW;
102  deltaW.zeros(n, r);
103  for (size_t i = 0; i < n; ++i)
104  {
105  for (size_t j = 0; j < m; ++j)
106  {
107  const double val = V(i, j);
108  if (val != 0)
109  deltaW.row(i) += (val - arma::dot(W.row(i), H.col(j))) *
110  arma::trans(H.col(j));
111  }
112  // Add regularization.
113  if (kw != 0)
114  deltaW.row(i) -= kw * W.row(i);
115  }
116 
117  // Add the step to the momentum.
118  mW += u * deltaW;
119  // Add the momentum to the W matrix.
120  W += mW;
121  }
122 
132  template<typename MatType>
133  inline void HUpdate(const MatType& V,
134  const arma::mat& W,
135  arma::mat& H)
136  {
137  size_t n = V.n_rows;
138  size_t m = V.n_cols;
139 
140  size_t r = W.n_cols;
141 
142  // Initialize the momentum of this iteration.
143  mH = momentum * mH;
144 
145  // Compute the step.
146  arma::mat deltaH;
147  deltaH.zeros(r, m);
148  for (size_t j = 0; j < m; ++j)
149  {
150  for (size_t i = 0; i < n; ++i)
151  {
152  const double val = V(i, j);
153  if (val != 0)
154  deltaH.col(j) += (val - arma::dot(W.row(i), H.col(j))) * W.row(i).t();
155  }
156  // Add regularization.
157  if (kh != 0)
158  deltaH.col(j) -= kh * H.col(j);
159  }
160 
161  // Add this step to the momentum.
162  mH += u * deltaH;
163  // Add the momentum to H.
164  H += mH;
165  }
166 
168  template<typename Archive>
169  void serialize(Archive& ar, const uint32_t /* version */)
170  {
171  ar(CEREAL_NVP(u));
172  ar(CEREAL_NVP(kw));
173  ar(CEREAL_NVP(kh));
174  ar(CEREAL_NVP(momentum));
175  ar(CEREAL_NVP(mW));
176  ar(CEREAL_NVP(mH));
177  }
178 
179  private:
//! Step size: the computed step is multiplied by this before being added to
//! the momentum.
181  double u;
//! Regularization coefficient for W; the term is skipped when this is 0.
183  double kw;
//! Regularization coefficient for H; the term is skipped when this is 0.
185  double kh;
//! Factor by which the stored momentum is scaled at the start of each update.
187  double momentum;
188 
//! Momentum matrix added to W on every WUpdate() call.
190  arma::mat mW;
//! Momentum matrix added to H on every HUpdate() call.
192  arma::mat mH;
193 }; // class SVDBatchLearning
194 
197 
201 template<>
202 inline void SVDBatchLearning::WUpdate<arma::sp_mat>(const arma::sp_mat& V,
203  arma::mat& W,
204  const arma::mat& H)
205 {
206  const size_t n = V.n_rows;
207  const size_t r = W.n_cols;
208 
209  mW = momentum * mW;
210 
211  arma::mat deltaW;
212  deltaW.zeros(n, r);
213 
214  for (arma::sp_mat::const_iterator it = V.begin(); it != V.end(); ++it)
215  {
216  const size_t row = it.row();
217  const size_t col = it.col();
218  deltaW.row(it.row()) += (*it - arma::dot(W.row(row), H.col(col))) *
219  arma::trans(H.col(col));
220  }
221 
222  if (kw != 0)
223  deltaW -= kw * W;
224 
225  mW += u * deltaW;
226  W += mW;
227 }
228 
229 template<>
230 inline void SVDBatchLearning::HUpdate<arma::sp_mat>(const arma::sp_mat& V,
231  const arma::mat& W,
232  arma::mat& H)
233 {
234  const size_t m = V.n_cols;
235  const size_t r = W.n_cols;
236 
237  mH = momentum * mH;
238 
239  arma::mat deltaH;
240  deltaH.zeros(r, m);
241 
242  for (arma::sp_mat::const_iterator it = V.begin(); it != V.end(); ++it)
243  {
244  const size_t row = it.row();
245  const size_t col = it.col();
246  deltaH.col(col) += (*it - arma::dot(W.row(row), H.col(col))) *
247  W.row(row).t();
248  }
249 
250  if (kh != 0)
251  deltaH -= kh * H;
252 
253  mH += u * deltaH;
254  H += mH;
255 }
256 
257 } // namespace amf
258 } // namespace mlpack
259 
260 #endif // MLPACK_METHODS_AMF_UPDATE_RULES_SVD_BATCH_LEARNING_HPP
Linear algebra utility functions, generally performed on matrices or vectors.
void serialize(Archive &ar, const uint32_t)
Serialize the SVDBatchLearning object.
The core includes that mlpack expects; standard C++ includes and Armadillo.
This class implements SVD batch learning with momentum.
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
void Initialize(const MatType &dataset, const size_t rank)
Initialize parameters before factorization.
SVDBatchLearning(double u=0.0002, double kw=0, double kh=0, double momentum=0.9)
SVD Batch learning constructor.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.