k_fold_cv.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_CV_K_FOLD_CV_HPP
13 #define MLPACK_CORE_CV_K_FOLD_CV_HPP
14 
17 
18 namespace mlpack {
19 namespace cv {
20 
57 template<typename MLAlgorithm,
58  typename Metric,
59  typename MatType = arma::mat,
60  typename PredictionsType =
62  typename WeightsType =
63  typename MetaInfoExtractor<MLAlgorithm, MatType,
64  PredictionsType>::WeightsType>
65 class KFoldCV
66 {
67  public:
78  KFoldCV(const size_t k,
79  const MatType& xs,
80  const PredictionsType& ys,
81  const bool shuffle = true);
82 
92  KFoldCV(const size_t k,
93  const MatType& xs,
94  const PredictionsType& ys,
95  const size_t numClasses,
96  const bool shuffle = true);
97 
109  KFoldCV(const size_t k,
110  const MatType& xs,
111  const data::DatasetInfo& datasetInfo,
112  const PredictionsType& ys,
113  const size_t numClasses,
114  const bool shuffle = true);
115 
127  KFoldCV(const size_t k,
128  const MatType& xs,
129  const PredictionsType& ys,
130  const WeightsType& weights,
131  const bool shuffle = true);
132 
144  KFoldCV(const size_t k,
145  const MatType& xs,
146  const PredictionsType& ys,
147  const size_t numClasses,
148  const WeightsType& weights,
149  const bool shuffle = true);
150 
163  KFoldCV(const size_t k,
164  const MatType& xs,
165  const data::DatasetInfo& datasetInfo,
166  const PredictionsType& ys,
167  const size_t numClasses,
168  const WeightsType& weights,
169  const bool shuffle = true);
170 
177  template<typename... MLAlgorithmArgs>
178  double Evaluate(const MLAlgorithmArgs& ...args);
179 
181  MLAlgorithm& Model();
182 
183  private:
186 
187  public:
192  template<bool Enabled = !Base::MIE::SupportsWeights,
193  typename = typename std::enable_if<Enabled>::type>
194  void Shuffle();
195 
200  template<bool Enabled = Base::MIE::SupportsWeights,
201  typename = typename std::enable_if<Enabled>::type,
202  typename = void>
203  void Shuffle();
204 
205  private:
207  Base base;
208 
210  const size_t k;
211 
213  MatType xs;
215  PredictionsType ys;
217  WeightsType weights;
218 
220  size_t lastBinSize;
221 
223  size_t binSize;
224 
226  std::unique_ptr<MLAlgorithm> modelPtr;
227 
232  KFoldCV(Base&& base,
233  const size_t k,
234  const MatType& xs,
235  const PredictionsType& ys,
236  const bool shuffle);
237 
242  KFoldCV(Base&& base,
243  const size_t k,
244  const MatType& xs,
245  const PredictionsType& ys,
246  const WeightsType& weights,
247  const bool shuffle);
248 
253  template<typename DataType>
254  void InitKFoldCVMat(const DataType& source, DataType& destination);
255 
259  template<typename... MLAlgorithmArgs,
260  bool Enabled = !Base::MIE::SupportsWeights,
261  typename = typename std::enable_if<Enabled>::type>
262  double TrainAndEvaluate(const MLAlgorithmArgs& ...mlAlgorithmArgs);
263 
267  template<typename... MLAlgorithmArgs,
268  bool Enabled = Base::MIE::SupportsWeights,
269  typename = typename std::enable_if<Enabled>::type,
270  typename = void>
271  double TrainAndEvaluate(const MLAlgorithmArgs& ...mlAlgorithmArgs);
272 
279  inline size_t ValidationSubsetFirstCol(const size_t i);
280 
284  template<typename ElementType>
285  inline arma::Mat<ElementType> GetTrainingSubset(arma::Mat<ElementType>& m,
286  const size_t i);
287 
291  template<typename ElementType>
292  inline arma::Row<ElementType> GetTrainingSubset(arma::Row<ElementType>& r,
293  const size_t i);
294 
298  template<typename ElementType>
299  inline arma::Mat<ElementType> GetValidationSubset(arma::Mat<ElementType>& m,
300  const size_t i);
301 
305  template<typename ElementType>
306  inline arma::Row<ElementType> GetValidationSubset(arma::Row<ElementType>& r,
307  const size_t i);
308 };
309 
310 } // namespace cv
311 } // namespace mlpack
312 
313 // Include implementation
314 #include "k_fold_cv_impl.hpp"
315 
316 #endif
double Evaluate(const MLAlgorithmArgs &...args)
Run k-fold cross-validation.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
Linear algebra utility functions, generally performed on matrices or vectors.
MLAlgorithm & Model()
Access and modify a model from the last run of k-fold cross-validation.
static const bool SupportsWeights
An indication whether MLAlgorithm supports weighted learning.
KFoldCV(const size_t k, const MatType &xs, const PredictionsType &ys, const bool shuffle=true)
This constructor can be used for regression algorithms and for binary classification algorithms...
The class KFoldCV implements k-fold cross-validation for regression and classification algorithms...
Definition: k_fold_cv.hpp:65
typename Select< TF1, TF2, TF3, TF4, TF5 >::Type::PredictionsType PredictionsType
The type of predictions used in MLAlgorithm.
An auxiliary class for cross-validation.
Definition: cv_base.hpp:39
void Shuffle()
Shuffle the data.