KFoldCV(const size_t k, const MatType& xs, const PredictionsType& ys, const bool shuffle = true)
    This constructor can be used for regression algorithms and for binary classification algorithms.

KFoldCV(const size_t k, const MatType& xs, const PredictionsType& ys, const size_t numClasses, const bool shuffle = true)
    This constructor can be used for multiclass classification algorithms.

KFoldCV(const size_t k, const MatType& xs, const data::DatasetInfo& datasetInfo, const PredictionsType& ys, const size_t numClasses, const bool shuffle = true)
    This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter.

KFoldCV(const size_t k, const MatType& xs, const PredictionsType& ys, const WeightsType& weights, const bool shuffle = true)
    This constructor can be used for regression and binary classification algorithms that support weighted learning.

KFoldCV(const size_t k, const MatType& xs, const PredictionsType& ys, const size_t numClasses, const WeightsType& weights, const bool shuffle = true)
    This constructor can be used for multiclass classification algorithms that support weighted learning.

KFoldCV(const size_t k, const MatType& xs, const data::DatasetInfo& datasetInfo, const PredictionsType& ys, const size_t numClasses, const WeightsType& weights, const bool shuffle = true)
    This constructor can be used for multiclass classification algorithms that can take a data::DatasetInfo parameter and support weighted learning.

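template<typename... MLAlgorithmArgs>
double Evaluate(const MLAlgorithmArgs&... args)
    Run k-fold cross-validation. The args are the additional arguments used to train MLAlgorithm (for example, lambda for SoftmaxRegression), and the returned value is the metric averaged over the k folds.

MLAlgorithm& Model()
    Access and modify a model from the last run of k-fold cross-validation.
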
template<bool Enabled = !Base::MIE::SupportsWeights, typename = typename std::enable_if<Enabled>::type>
void Shuffle()
    Shuffle the data. This overload is enabled when MLAlgorithm does not support weighted learning.

template<bool Enabled = Base::MIE::SupportsWeights, typename = typename std::enable_if<Enabled>::type, typename = void>
void Shuffle()
    Shuffle the data together with the weights. This overload is enabled when MLAlgorithm supports weighted learning.

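The two Shuffle() declarations differ only in their std::enable_if conditions, so exactly one of them is usable for a given MLAlgorithm. The sketch below reproduces that SFINAE pattern in isolation. ShuffleDispatchSketch and its SupportsWeights parameter are hypothetical stand-ins for KFoldCV and Base::MIE::SupportsWeights, and the overloads return a description instead of void only so that the selected overload is observable.

```cpp
#include <string>
#include <type_traits>

// Hypothetical stand-in for KFoldCV; SupportsWeights plays the role of
// Base::MIE::SupportsWeights. The real Shuffle() returns void.
template<bool SupportsWeights>
struct ShuffleDispatchSketch
{
  // Viable only when weighted learning is NOT supported: if SupportsWeights is
  // true, Enabled defaults to false, enable_if<false> has no ::type, and the
  // substitution failure silently removes this overload.
  template<bool Enabled = !SupportsWeights,
           typename = typename std::enable_if<Enabled>::type>
  std::string Shuffle() const { return "data and labels"; }

  // Viable only when weighted learning IS supported. The extra unnamed
  // parameter gives this template a different signature from the one above;
  // without it, the two declarations would be an invalid redeclaration.
  template<bool Enabled = SupportsWeights,
           typename = typename std::enable_if<Enabled>::type,
           typename = void>
  std::string Shuffle() const { return "data, labels and weights"; }
};
```

With this pattern, a call such as ShuffleDispatchSketch<true>().Shuffle() resolves to the second overload at compile time, with no runtime branch on the weight-support flag.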
template<typename MLAlgorithm, typename Metric, typename MatType = arma::mat, typename PredictionsType = typename MetaInfoExtractor<MLAlgorithm, MatType>::PredictionsType, typename WeightsType = typename MetaInfoExtractor<MLAlgorithm, MatType, PredictionsType>::WeightsType>
class mlpack::cv::KFoldCV< MLAlgorithm, Metric, MatType, PredictionsType, WeightsType >
The class KFoldCV implements k-fold cross-validation for regression and classification algorithms.
To construct a KFoldCV object, pass the number of folds k and the arguments that specify the data. For example, 10-fold cross-validation of SoftmaxRegression can be run as follows:
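```cpp
// 100 random 5-dimensional points (one per column) with labels in [0, 4].
arma::mat data = arma::randu<arma::mat>(5, 100);
arma::Row<size_t> labels =
    arma::randi<arma::Row<size_t>>(100, arma::distr_param(0, 4));
size_t numClasses = 5;

// 10 folds, each trained model scored with the Accuracy metric.
KFoldCV<SoftmaxRegression<>, Accuracy> cv(10, data, labels, numClasses);
double lambda = 0.1;  // regularization parameter used when training each model
double softmaxAccuracy = cv.Evaluate(lambda);
```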
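Conceptually, a k-fold run trains k models, each on k - 1 folds, measures the metric on the held-out fold, and averages the k results. A minimal, library-free sketch of that loop follows. KFoldMeanScore and MeanPredictorMSE are hypothetical names, the toy "algorithm" simply predicts the mean training label, and the contiguous fold split is one simple choice that need not match how KFoldCV partitions the data.

```cpp
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <vector>

using Indices = std::vector<std::size_t>;

// Hypothetical helper: k-fold cross-validation over n points. trainAndScore
// receives the training indices and the held-out indices of one fold, trains
// a fresh model on the former and returns the metric measured on the latter.
// Folds are contiguous blocks of n / k points; the last fold also takes the
// n % k leftover points.
double KFoldMeanScore(
    std::size_t n, std::size_t k,
    const std::function<double(const Indices&, const Indices&)>& trainAndScore)
{
  if (k < 2 || k > n)
    throw std::invalid_argument("k must satisfy 2 <= k <= n");

  const std::size_t binSize = n / k;
  double total = 0.0;
  for (std::size_t fold = 0; fold < k; ++fold)
  {
    const std::size_t begin = fold * binSize;
    const std::size_t end = (fold + 1 == k) ? n : begin + binSize;

    Indices train, valid;
    for (std::size_t i = 0; i < n; ++i)
    {
      if (i >= begin && i < end)
        valid.push_back(i);
      else
        train.push_back(i);
    }
    total += trainAndScore(train, valid);
  }
  return total / static_cast<double>(k);  // mean of the k fold scores
}

// Toy "algorithm + metric": predict the mean training label, score with MSE.
double MeanPredictorMSE(const std::vector<double>& y,
                        const Indices& train, const Indices& valid)
{
  double mean = 0.0;
  for (std::size_t i : train)
    mean += y[i];
  mean /= static_cast<double>(train.size());

  double mse = 0.0;
  for (std::size_t i : valid)
    mse += (y[i] - mean) * (y[i] - mean);
  return mse / static_cast<double>(valid.size());
}
```

Every point lands in exactly one validation fold, and a fresh model is built per fold, which is why the score reflects performance on data the model did not see during training.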
Before calling Evaluate(), the data can be shuffled by calling Shuffle(). Shuffling is also performed at construction time if the constructor's shuffle parameter is true (the default).
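Shuffling has to reorder the points, the labels and, for weighted learning, the weights with one shared permutation; otherwise samples would be paired with the wrong labels. The sketch below shows that idea on std::vector. ShuffleTogether and ApplyPermutation are hypothetical helpers, not part of mlpack (which operates on Armadillo objects), and the caller-supplied std::mt19937 is an assumption made so results are reproducible.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <random>
#include <stdexcept>
#include <vector>

// Reorder v so that the new v[i] is the old v[order[i]].
template<typename T>
void ApplyPermutation(std::vector<T>& v, const std::vector<std::size_t>& order)
{
  const std::vector<T> copy = v;
  for (std::size_t i = 0; i < order.size(); ++i)
    v[i] = copy[order[i]];
}

// Hypothetical helper: shuffle points, labels and (optionally) weights with
// one shared random permutation, so that entry i of every array still
// describes the same sample afterwards. Pass nullptr when there are no weights.
template<typename X, typename Y>
void ShuffleTogether(std::vector<X>& xs, std::vector<Y>& ys,
                     std::vector<double>* weights, std::mt19937& rng)
{
  if (ys.size() != xs.size() ||
      (weights != nullptr && weights->size() != xs.size()))
    throw std::invalid_argument("every array needs one entry per point");

  std::vector<std::size_t> order(xs.size());
  std::iota(order.begin(), order.end(), std::size_t{0});
  std::shuffle(order.begin(), order.end(), rng);

  ApplyPermutation(xs, order);
  ApplyPermutation(ys, order);
  if (weights != nullptr)
    ApplyPermutation(*weights, order);
}
```

Drawing a single index permutation and applying it to each array keeps the arrays aligned without having to swap elements in several containers in lockstep.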
Template Parameters
MLAlgorithm | A machine learning algorithm. |
Metric | A metric to assess the quality of a trained model. |
MatType | The type of data. |
PredictionsType | The type of predictions (should be passed when the predictions type is a template parameter in the Train methods of MLAlgorithm). |
WeightsType | The type of weights (should be passed when weighted learning is supported and the weights type is a template parameter in the Train methods of MLAlgorithm). |
Definition at line 65 of file k_fold_cv.hpp.
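PredictionsType and WeightsType default to types extracted from MLAlgorithm by MetaInfoExtractor, which is why they only need to be passed explicitly when Train is a template. The sketch below illustrates that default-template-argument mechanism in a much simplified form; it is not how MetaInfoExtractor is implemented. SecondParam, ToyClassifier and KFoldCVSketch are hypothetical, and the trait only handles a single non-const, non-template Train member function (an overloaded or templated Train would make &MLAlgorithm::Train ambiguous).

```cpp
#include <type_traits>
#include <vector>

// Primary template left undefined: only member-function pointers are handled.
template<typename F>
struct SecondParam;

// Matches a non-const member function R C::f(A1, A2, Rest...) and exposes A2
// with references and cv-qualifiers stripped.
template<typename C, typename R, typename A1, typename A2, typename... Rest>
struct SecondParam<R (C::*)(A1, A2, Rest...)>
{
  using type = typename std::decay<A2>::type;
};

// Toy algorithm whose Train() takes (data, labels).
struct ToyClassifier
{
  void Train(const std::vector<double>& /* data */,
             const std::vector<int>& /* labels */) { }
};

// Hypothetical stand-in for KFoldCV's template header: PredictionsType
// defaults to the type Train() accepts as its second argument, but can still
// be passed explicitly, which is what the table above recommends when
// deduction is not possible.
template<typename MLAlgorithm,
         typename PredictionsType =
             typename SecondParam<decltype(&MLAlgorithm::Train)>::type>
struct KFoldCVSketch
{
  using Predictions = PredictionsType;
};
```

Here KFoldCVSketch<ToyClassifier>::Predictions is std::vector<int>, while KFoldCVSketch<ToyClassifier, std::vector<long>> overrides the deduced default.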