simple_cv.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_CV_SIMPLE_CV_HPP
13 #define MLPACK_CORE_CV_SIMPLE_CV_HPP
14 
17 
18 namespace mlpack {
19 namespace cv {
20 
60 template<typename MLAlgorithm,
61  typename Metric,
62  typename MatType = arma::mat,
63  typename PredictionsType =
65  typename WeightsType =
66  typename MetaInfoExtractor<MLAlgorithm, MatType,
67  PredictionsType>::WeightsType>
68 class SimpleCV
69 {
70  public:
84  template<typename MatInType, typename PredictionsInType>
85  SimpleCV(const double validationSize,
86  MatInType&& xs,
87  PredictionsInType&& ys);
88 
101  template<typename MatInType, typename PredictionsInType>
102  SimpleCV(const double validationSize,
103  MatInType&& xs,
104  PredictionsInType&& ys,
105  const size_t numClasses);
106 
121  template<typename MatInType, typename PredictionsInType>
122  SimpleCV(const double validationSize,
123  MatInType&& xs,
124  const data::DatasetInfo& datasetInfo,
125  PredictionsInType&& ys,
126  const size_t numClasses);
127 
143  template<typename MatInType,
144  typename PredictionsInType,
145  typename WeightsInType>
146  SimpleCV(const double validationSize,
147  MatInType&& xs,
148  PredictionsInType&& ys,
149  WeightsInType&& weights);
150 
166  template<typename MatInType,
167  typename PredictionsInType,
168  typename WeightsInType>
169  SimpleCV(const double validationSize,
170  MatInType&& xs,
171  PredictionsInType&& ys,
172  const size_t numClasses,
173  WeightsInType&& weights);
174 
191  template<typename MatInType,
192  typename PredictionsInType,
193  typename WeightsInType>
194  SimpleCV(const double validationSize,
195  MatInType&& xs,
196  const data::DatasetInfo& datasetInfo,
197  PredictionsInType&& ys,
198  const size_t numClasses,
199  WeightsInType&& weights);
200 
208  template<typename... MLAlgorithmArgs>
209  double Evaluate(const MLAlgorithmArgs&... args);
210 
212  MLAlgorithm& Model();
213 
214  private:
217 
219  Base base;
220 
222  MatType xs;
224  PredictionsType ys;
226  WeightsType weights;
227 
229  MatType trainingXs;
231  PredictionsType trainingYs;
233  WeightsType trainingWeights;
234 
236  MatType validationXs;
238  PredictionsType validationYs;
239 
241  std::unique_ptr<MLAlgorithm> modelPtr;
242 
247  template<typename MatInType,
248  typename PredictionsInType>
249  SimpleCV(Base&& base,
250  const double validationSize,
251  MatInType&& xs,
252  PredictionsInType&& ys);
253 
258  template<typename MatInType,
259  typename PredictionsInType,
260  typename WeightsInType>
261  SimpleCV(Base&& base,
262  const double validationSize,
263  MatInType&& xs,
264  PredictionsInType&& ys,
265  WeightsInType&& weights);
266 
270  size_t CalculateAndAssertNumberOfTrainingPoints(const double validationSize);
271 
275  template<typename ElementType>
276  arma::Mat<ElementType> GetSubset(arma::Mat<ElementType>& m,
277  const size_t firstCol,
278  const size_t lastCol);
279 
283  template<typename ElementType>
284  arma::Row<ElementType> GetSubset(arma::Row<ElementType>& r,
285  const size_t firstCol,
286  const size_t lastCol);
287 
291  template<typename... MLAlgorithmArgs,
292  bool Enabled = !Base::MIE::SupportsWeights,
293  typename = typename std::enable_if<Enabled>::type>
294  double TrainAndEvaluate(const MLAlgorithmArgs&... args);
295 
299  template<typename... MLAlgorithmArgs,
300  bool Enabled = Base::MIE::SupportsWeights,
301  typename = typename std::enable_if<Enabled>::type,
302  typename = void>
303  double TrainAndEvaluate(const MLAlgorithmArgs&... args);
304 };
305 
306 } // namespace cv
307 } // namespace mlpack
308 
309 // Include implementation
310 #include "simple_cv_impl.hpp"
311 
312 #endif
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
MLAlgorithm & Model()
Access and modify the last trained model.
SimpleCV splits data into two sets - training and validation sets - and then runs training on the training set and evaluates performance on the validation set.
Definition: simple_cv.hpp:68
Linear algebra utility functions, generally performed on matrices or vectors.
double Evaluate(const MLAlgorithmArgs &... args)
Train on the training set and assess performance on the validation set by using the class Metric.
static const bool SupportsWeights
An indication whether MLAlgorithm supports weighted learning.
SimpleCV(const double validationSize, MatInType &&xs, PredictionsInType &&ys)
This constructor can be used for regression algorithms and for binary classification algorithms...
typename Select< TF1, TF2, TF3, TF4, TF5 >::Type::PredictionsType PredictionsType
The type of predictions used in MLAlgorithm.
An auxiliary class for cross-validation.
Definition: cv_base.hpp:39