12 #ifndef MLPACK_METHODS_CF_REGRESSION_INTERPOLATION_HPP 13 #define MLPACK_METHODS_CF_REGRESSION_INTERPOLATION_HPP 71 const size_t userNum = cleanedData.n_cols;
72 a.set_size(userNum, userNum);
73 b.set_size(userNum, userNum);
/**
 * Compute interpolation weights for the neighbors of queryUser by building
 * a linear system and solving coeff * weights = constant with arma::solve().
 *
 * For each pair of neighbors (i, j):
 * - coeff(i, j) is the dot product of w * h.col(neighbors(i)) and
 *   w * h.col(neighbors(j)), divided by cleanedData.n_rows.
 * - These vectors are presumably the low-rank rating predictions for those
 *   neighbors; confirm against the DecompositionPolicy.
 *
 * For each neighbor i:
 * - constant(i) is the dot product of w * h.col(neighbors(i)) with the query
 *   user's column of cleanedData.
 * - It is divided by the number of nonzero entries in that column.
 *
 * Computed coefficients are stored in the member a, and computed constant
 * terms in the member b, so later calls can reuse them.
 *
 * @param weights Output weight vector. It must already hold neighbors.n_elem
 *     elements; a size mismatch is reported through Log::Fatal.
 * @param decomposition Object providing the factor matrices W() and H().
 * @param queryUser Column index of the query user in cleanedData.
 * @param neighbors Indices of the query user's neighbors.
 * @param cleanedData Rating data. Column queryUser is read as the query
 *     user's ratings.
 *
 * NOTE(review): this text is an extracted copy of the original source.
 * - The leading numbers on each line (94, 96, 105, ...) are line-number
 *   residue, not code.
 * - Several original lines are absent here, including the function-name
 *   line, the braces, the else-branches and the guard around the
 *   uniform-weight fallback.
 */
template <
typename VectorType,
94 typename DecompositionPolicy>
96 const DecompositionPolicy& decomposition,
97 const size_t queryUser,
98 const arma::Col<size_t>& neighbors,
100 const arma::sp_mat& cleanedData)
// The caller must pre-size weights to one entry per neighbor.
102 if (weights.n_elem != neighbors.n_elem)
104 Log::Fatal <<
"The size of the first parameter (weights) should " 105 <<
"be set to the number of neighbors before calling GetWeights()." 109 const arma::mat& w = decomposition.W();
110 const arma::mat& h = decomposition.H();
111 const size_t itemNum = cleanedData.n_rows;
112 const size_t neighborNum = neighbors.size();
// Left-hand side (coeff) and right-hand side (constant) of the linear system
// coeff * weights = constant.
115 arma::mat coeff(neighborNum, neighborNum);
117 arma::vec constant(neighborNum);
// The query user's ratings; support counts how many of them are nonzero.
119 arma::vec userRating(cleanedData.col(queryUser));
120 const size_t support = arma::accu(userRating != 0);
// Fallback: uniform weights 1 / neighbors.n_elem.
// NOTE(review): the condition guarding this line is not visible in this view.
// It presumably applies when support == 0, since the constant terms below
// divide by support. Confirm against the original source.
125 weights.fill(1.0 / neighbors.n_elem);
129 for (
size_t i = 0; i < neighborNum; ++i)
// Filled lazily (only while still empty), so it is computed at most once
// per i. It is shared by the coefficient and constant-term computations.
132 arma::vec iPrediction;
// Only j >= i is visited; each computed value is mirrored into (j, i), which
// fills the whole symmetric coeff matrix.
133 for (
size_t j = i; j < neighborNum; ++j)
// A nonzero entry in a means this coefficient was cached by an earlier
// computation.
135 if (a(neighbors(i), neighbors(j)) != 0)
138 coeff(i, j) = a(neighbors(i), neighbors(j));
139 coeff(j, i) = coeff(i, j);
// Not cached: compute the coefficient from the decomposition factors.
144 if (iPrediction.size() == 0)
146 iPrediction = w * h.col(neighbors(i));
147 arma::vec jPrediction = w * h.col(neighbors(j));
148 coeff(i, j) = arma::dot(iPrediction, jPrediction) / itemNum;
// A computed zero would look "uncached" to the != 0 check above. It is
// therefore replaced by the smallest positive normalized double.
149 if (coeff(i, j) == 0)
150 coeff(i, j) = std::numeric_limits<double>::min();
151 coeff(j, i) = coeff(i, j);
// Cache the symmetric pair in a for later calls.
153 a(neighbors(i), neighbors(j)) = coeff(i, j);
154 a(neighbors(j), neighbors(i)) = coeff(i, j);
// Constant term: reuse the value cached in b when it is nonzero.
159 if (b(neighbors(i), queryUser) != 0)
161 constant(i) = b(neighbors(i), queryUser);
// Not cached: project neighbor i's vector onto the query user's ratings,
// normalized by the number of ratings the user has.
165 if (iPrediction.size() == 0)
167 iPrediction = w * h.col(neighbors(i));
168 constant(i) = arma::dot(iPrediction, userRating) / support;
// Same zero-vs-uncached disambiguation as for coeff above.
169 if (constant(i) == 0)
170 constant(i) = std::numeric_limits<double>::min();
// Cache the constant term in b for later calls.
172 b(neighbors(i), queryUser) = constant(i);
// Solve coeff * weights = constant for the interpolation weights.
175 weights = arma::solve(coeff, constant);
void GetWeights(VectorType &&weights, const DecompositionPolicy &decomposition, const size_t queryUser, const arma::Col< size_t > &neighbors, const arma::vec &, const arma::sp_mat &cleanedData)
The regression-based interpolation problem can be solved by a linear system of equations.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects: standard C++ headers and Armadillo.
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
RegressionInterpolation(const arma::sp_mat &cleanedData)
Uses cleanedData to perform the necessary preprocessing.
Implementation of regression-based interpolation method.
RegressionInterpolation()
Empty constructor.