12 #ifndef MLPACK_CORE_CV_SIMPLE_CV_HPP 13 #define MLPACK_CORE_CV_SIMPLE_CV_HPP 60 template<
typename MLAlgorithm,
62 typename MatType = arma::mat,
63 typename PredictionsType =
65 typename WeightsType =
66 typename MetaInfoExtractor<MLAlgorithm, MatType,
67 PredictionsType>::WeightsType>
84 template<
typename MatInType,
typename PredictionsInType>
85 SimpleCV(
const double validationSize,
87 PredictionsInType&& ys);
101 template<
typename MatInType,
typename PredictionsInType>
102 SimpleCV(
const double validationSize,
104 PredictionsInType&& ys,
105 const size_t numClasses);
121 template<
typename MatInType,
typename PredictionsInType>
122 SimpleCV(
const double validationSize,
125 PredictionsInType&& ys,
126 const size_t numClasses);
143 template<
typename MatInType,
144 typename PredictionsInType,
145 typename WeightsInType>
146 SimpleCV(
const double validationSize,
148 PredictionsInType&& ys,
149 WeightsInType&& weights);
166 template<
typename MatInType,
167 typename PredictionsInType,
168 typename WeightsInType>
169 SimpleCV(
const double validationSize,
171 PredictionsInType&& ys,
172 const size_t numClasses,
173 WeightsInType&& weights);
191 template<
typename MatInType,
192 typename PredictionsInType,
193 typename WeightsInType>
194 SimpleCV(
const double validationSize,
197 PredictionsInType&& ys,
198 const size_t numClasses,
199 WeightsInType&& weights);
/**
 * Train on the training set and assess performance on the validation set by
 * using the class Metric.
 *
 * @param args Arguments of types MLAlgorithmArgs; presumably the
 *     hyper-parameters forwarded to MLAlgorithm when it is trained (confirm
 *     in simple_cv_impl.hpp).
 * @return The performance estimate as a double (presumably the value Metric
 *     computes on the validation set).
 */
template<typename... MLAlgorithmArgs>
double Evaluate(const MLAlgorithmArgs&... args);
212 MLAlgorithm&
Model();
231 PredictionsType trainingYs;
233 WeightsType trainingWeights;
236 MatType validationXs;
238 PredictionsType validationYs;
241 std::unique_ptr<MLAlgorithm> modelPtr;
247 template<
typename MatInType,
248 typename PredictionsInType>
250 const double validationSize,
252 PredictionsInType&& ys);
258 template<
typename MatInType,
259 typename PredictionsInType,
260 typename WeightsInType>
262 const double validationSize,
264 PredictionsInType&& ys,
265 WeightsInType&& weights);
/**
 * Compute a size_t from validationSize. Judging by the name, the result is
 * the number of training points, and the function also checks that this
 * number is valid. Confirm the exact check and the failure behavior in
 * simple_cv_impl.hpp.
 *
 * @param validationSize Presumably the value given to the constructors'
 *     validationSize parameter. Its unit (e.g. a fraction of the data) is not
 *     visible here.
 * @return The computed count.
 */
size_t CalculateAndAssertNumberOfTrainingPoints(
    const double validationSize);
275 template<
typename ElementType>
276 arma::Mat<ElementType> GetSubset(arma::Mat<ElementType>& m,
277 const size_t firstCol,
278 const size_t lastCol);
283 template<
typename ElementType>
284 arma::Row<ElementType> GetSubset(arma::Row<ElementType>& r,
285 const size_t firstCol,
286 const size_t lastCol);
291 template<
typename... MLAlgorithmArgs,
293 typename =
typename std::enable_if<Enabled>::type>
294 double TrainAndEvaluate(
const MLAlgorithmArgs&... args);
299 template<
typename... MLAlgorithmArgs,
301 typename =
typename std::enable_if<Enabled>::type,
303 double TrainAndEvaluate(
const MLAlgorithmArgs&... args);
310 #include "simple_cv_impl.hpp" Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
MLAlgorithm & Model()
Access and modify the last trained model.
SimpleCV splits data into two sets (a training set and a validation set), then trains on the training set and assesses performance on the validation set.
Linear algebra utility functions, generally performed on matrices or vectors.
double Evaluate(const MLAlgorithmArgs &... args)
Train on the training set and assess performance on the validation set by using the class Metric...
SimpleCV(const double validationSize, MatInType &&xs, PredictionsInType &&ys)
This constructor can be used for regression algorithms and for binary classification algorithms...
An auxiliary class for cross-validation.