12 #ifndef MLPACK_CORE_CV_K_FOLD_CV_HPP 13 #define MLPACK_CORE_CV_K_FOLD_CV_HPP 57 template<
typename MLAlgorithm,
59 typename MatType = arma::mat,
60 typename PredictionsType =
62 typename WeightsType =
63 typename MetaInfoExtractor<MLAlgorithm, MatType,
64 PredictionsType>::WeightsType>
80 const PredictionsType& ys,
81 const bool shuffle =
true);
94 const PredictionsType& ys,
95 const size_t numClasses,
96 const bool shuffle =
true);
112 const PredictionsType& ys,
113 const size_t numClasses,
114 const bool shuffle =
true);
129 const PredictionsType& ys,
130 const WeightsType& weights,
131 const bool shuffle =
true);
146 const PredictionsType& ys,
147 const size_t numClasses,
148 const WeightsType& weights,
149 const bool shuffle =
true);
166 const PredictionsType& ys,
167 const size_t numClasses,
168 const WeightsType& weights,
169 const bool shuffle =
true);
177 template<
typename... MLAlgorithmArgs>
178 double Evaluate(
const MLAlgorithmArgs& ...args);
181 MLAlgorithm&
Model();
193 typename =
typename std::enable_if<Enabled>::type>
201 typename =
typename std::enable_if<Enabled>::type,
226 std::unique_ptr<MLAlgorithm> modelPtr;
235 const PredictionsType& ys,
245 const PredictionsType& ys,
246 const WeightsType& weights,
253 template<
typename DataType>
254 void InitKFoldCVMat(
const DataType& source, DataType& destination);
259 template<
typename... MLAlgorithmArgs,
261 typename =
typename std::enable_if<Enabled>::type>
262 double TrainAndEvaluate(
const MLAlgorithmArgs& ...mlAlgorithmArgs);
267 template<
typename... MLAlgorithmArgs,
269 typename =
typename std::enable_if<Enabled>::type,
271 double TrainAndEvaluate(
const MLAlgorithmArgs& ...mlAlgorithmArgs);
279 inline size_t ValidationSubsetFirstCol(
const size_t i);
284 template<
typename ElementType>
285 inline arma::Mat<ElementType> GetTrainingSubset(arma::Mat<ElementType>& m,
291 template<
typename ElementType>
292 inline arma::Row<ElementType> GetTrainingSubset(arma::Row<ElementType>& r,
298 template<
typename ElementType>
299 inline arma::Mat<ElementType> GetValidationSubset(arma::Mat<ElementType>& m,
305 template<
typename ElementType>
306 inline arma::Row<ElementType> GetValidationSubset(arma::Row<ElementType>& r,
314 #include "k_fold_cv_impl.hpp" double Evaluate(const MLAlgorithmArgs &...args)
Run k-fold cross-validation.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
Linear algebra utility functions, generally performed on matrices or vectors.
MLAlgorithm & Model()
Access and modify a model from the last run of k-fold cross-validation.
KFoldCV(const size_t k, const MatType &xs, const PredictionsType &ys, const bool shuffle=true)
This constructor can be used for regression algorithms and for binary classification algorithms...
The class KFoldCV implements k-fold cross-validation for regression and classification algorithms...
An auxiliary class for cross-validation.
void Shuffle()
Shuffle the data.