cf_model.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_CF_CF_MODEL_HPP
14 #define MLPACK_METHODS_CF_CF_MODEL_HPP
15 
16 #include <mlpack/core.hpp>
17 #include "cf.hpp"
18 
19 namespace mlpack {
20 namespace cf {
21 
27 {
31 };
32 
38 {
42 };
43 
50 {
51  public:
54 
56  virtual CFWrapperBase* Clone() const = 0;
57 
59  virtual ~CFWrapperBase() { }
60 
62  virtual void Predict(const NeighborSearchTypes nsType,
63  const InterpolationTypes interpolationType,
64  const arma::Mat<size_t>& combinations,
65  arma::vec& predictions) = 0;
66 
68  virtual void GetRecommendations(
69  const NeighborSearchTypes nsType,
70  const InterpolationTypes interpolationType,
71  const size_t numRecs,
72  arma::Mat<size_t>& recommendations) = 0;
73 
75  virtual void GetRecommendations(
76  const NeighborSearchTypes nsType,
77  const InterpolationTypes interpolationType,
78  const size_t numRecs,
79  arma::Mat<size_t>& recommendations,
80  const arma::Col<size_t>& users) = 0;
81 };
82 
87 template<typename DecompositionPolicy, typename NormalizationPolicy>
88 class CFWrapper : public CFWrapperBase
89 {
90  protected:
92 
93  public:
96  CFWrapper() { }
97 
99  CFWrapper(const arma::mat& data,
100  const DecompositionPolicy& decomposition,
101  const size_t numUsersForSimilarity,
102  const size_t rank,
103  const size_t maxIterations,
104  const size_t minResidue,
105  const bool mit) :
106  cf(data,
107  decomposition,
108  numUsersForSimilarity,
109  rank,
110  maxIterations,
111  minResidue,
112  mit)
113  {
114  // Nothing else to do.
115  }
116 
118  virtual CFWrapper* Clone() const { return new CFWrapper(*this); }
119 
121  virtual ~CFWrapper() { }
122 
124  CFModelType& CF() { return cf; }
125 
127  virtual void Predict(const NeighborSearchTypes nsType,
128  const InterpolationTypes interpolationType,
129  const arma::Mat<size_t>& combinations,
130  arma::vec& predictions);
131 
133  virtual void GetRecommendations(
134  const NeighborSearchTypes nsType,
135  const InterpolationTypes interpolationType,
136  const size_t numRecs,
137  arma::Mat<size_t>& recommendations);
138 
140  virtual void GetRecommendations(
141  const NeighborSearchTypes nsType,
142  const InterpolationTypes interpolationType,
143  const size_t numRecs,
144  arma::Mat<size_t>& recommendations,
145  const arma::Col<size_t>& users);
146 
148  template<typename Archive>
149  void serialize(Archive& ar, const uint32_t /* version */)
150  {
151  ar(CEREAL_NVP(cf));
152  }
153 
154  protected:
156  CFModelType cf;
157 };
158 
162 class CFModel
163 {
164  public:
166  {
174  SVD_PLUS_PLUS
175  };
176 
178  {
183  Z_SCORE_NORMALIZATION
184  };
185 
186  private:
188  DecompositionTypes decompositionType;
190  NormalizationTypes normalizationType;
191 
197  CFWrapperBase* cf;
198 
199  public:
201  CFModel();
202 
204  CFModel(const CFModel& other);
205 
207  CFModel(CFModel&& other);
208 
210  CFModel& operator=(const CFModel& other);
211 
213  CFModel& operator=(CFModel&& other);
214 
216  ~CFModel();
217 
219  CFWrapperBase* CF() const { return cf; }
220 
223  {
224  return decompositionType;
225  }
228  {
229  return decompositionType;
230  }
231 
234  {
235  return normalizationType;
236  }
239  {
240  return normalizationType;
241  }
242 
244  void Train(const arma::mat& data,
245  const size_t numUsersForSimilarity,
246  const size_t rank,
247  const size_t maxIterations,
248  const double minResidue,
249  const bool mit);
250 
252  void Predict(const NeighborSearchTypes nsType,
253  const InterpolationTypes interpolationType,
254  const arma::Mat<size_t>& combinations,
255  arma::vec& predictions);
256 
258  void GetRecommendations(const NeighborSearchTypes nsType,
259  const InterpolationTypes interpolationType,
260  const size_t numRecs,
261  arma::Mat<size_t>& recommendations,
262  const arma::Col<size_t>& users);
263 
265  void GetRecommendations(const NeighborSearchTypes nsType,
266  const InterpolationTypes interpolationType,
267  const size_t numRecs,
268  arma::Mat<size_t>& recommendations);
269 
271  template<typename Archive>
272  void serialize(Archive& ar, const uint32_t /* version */);
273 };
274 
275 } // namespace cf
276 } // namespace mlpack
277 
278 // Include implementation.
279 #include "cf_model_impl.hpp"
280 
281 #endif
CFWrapperBase()
Create the object. The base class has nothing to hold.
Definition: cf_model.hpp:53
Linear algebra utility functions, generally performed on matrices or vectors.
CFWrapper(const arma::mat &data, const DecompositionPolicy &decomposition, const size_t numUsersForSimilarity, const size_t rank, const size_t maxIterations, const size_t minResidue, const bool mit)
Create the CFWrapper object, initializing the held CF object.
Definition: cf_model.hpp:99
The CFWrapperBase class provides a unified interface that can be used by the CFModel class to interact with all of the different CF types at runtime.
Definition: cf_model.hpp:49
CFWrapperBase * CF() const
Get the CFWrapperBase object. (Be careful!)
Definition: cf_model.hpp:219
virtual void GetRecommendations(const NeighborSearchTypes nsType, const InterpolationTypes interpolationType, const size_t numRecs, arma::Mat< size_t > &recommendations)=0
Compute recommendations for all users.
DecompositionTypes & DecompositionType()
Set the decomposition type.
Definition: cf_model.hpp:227
NormalizationTypes & NormalizationType()
Set the normalization type.
Definition: cf_model.hpp:238
CFModelType cf
This is the CF object that we are wrapping.
Definition: cf_model.hpp:156
virtual CFWrapper * Clone() const
Clone the CFWrapper object. This handles polymorphism correctly.
Definition: cf_model.hpp:118
The model to save to disk.
Definition: cf_model.hpp:162
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
CFWrapper()
Create the CFWrapper object, using default parameters to initialize the held CF object.
Definition: cf_model.hpp:96
InterpolationTypes
InterpolationTypes contains the set of InterpolationPolicy classes that are usable by CFModel at prediction time.
Definition: cf_model.hpp:37
const NormalizationTypes & NormalizationType() const
Get the normalization type.
Definition: cf_model.hpp:233
virtual CFWrapperBase * Clone() const =0
Make a copy of the object.
virtual ~CFWrapperBase()
Delete the object.
Definition: cf_model.hpp:59
void serialize(Archive &ar, const uint32_t)
Serialize the model.
Definition: cf_model.hpp:149
NeighborSearchTypes
NeighborSearchTypes contains the set of NeighborSearchPolicy classes that are usable by CFModel at prediction time.
Definition: cf_model.hpp:26
virtual ~CFWrapper()
Destroy the CFWrapper object.
Definition: cf_model.hpp:121
The CFWrapper class wraps the functionality of all CF types.
Definition: cf_model.hpp:88
virtual void Predict(const NeighborSearchTypes nsType, const InterpolationTypes interpolationType, const arma::Mat< size_t > &combinations, arma::vec &predictions)=0
Compute predictions for users.
const DecompositionTypes & DecompositionType() const
Get the decomposition type.
Definition: cf_model.hpp:222
CFModelType & CF()
Get the CFType object.
Definition: cf_model.hpp:124
CFType< DecompositionPolicy, NormalizationPolicy > CFModelType
Definition: cf_model.hpp:91