12 #ifndef MLPACK_CORE_CV_CV_BASE_HPP 13 #define MLPACK_CORE_CV_CV_BASE_HPP 35 template<
typename MLAlgorithm,
37 typename PredictionsType,
57 CVBase(
const size_t numClasses);
67 const size_t numClasses);
73 const PredictionsType& ys);
80 const WeightsType& weights);
86 template<
typename... MLAlgorithmArgs>
87 MLAlgorithm
Train(
const MatType& xs,
88 const PredictionsType& ys,
89 const MLAlgorithmArgs&... args);
95 template<
typename... MLAlgorithmArgs>
96 MLAlgorithm
Train(
const MatType& xs,
97 const PredictionsType& ys,
98 const WeightsType& weights,
99 const MLAlgorithmArgs&... args);
103 "The given MLAlgorithm is not supported by MetaInfoExtractor");
108 const bool isDatasetInfoPassed;
115 static void AssertSizeEquality(
const MatType& xs,
116 const PredictionsType& ys);
121 static void AssertWeightsSize(
const MatType& xs,
122 const WeightsType& weights);
128 template<
typename... MLAlgorithmArgs,
130 typename =
typename std::enable_if<Enabled>::type>
131 MLAlgorithm TrainModel(
const MatType& xs,
132 const PredictionsType& ys,
133 const MLAlgorithmArgs&... args);
139 template<
typename... MLAlgorithmArgs,
141 typename =
typename std::enable_if<Enabled>::type,
143 MLAlgorithm TrainModel(
const MatType& xs,
144 const PredictionsType& ys,
145 const MLAlgorithmArgs&... args);
151 template<
typename... MLAlgorithmArgs,
153 typename =
typename std::enable_if<Enabled>::type,
156 MLAlgorithm TrainModel(
const MatType& xs,
157 const PredictionsType& ys,
158 const MLAlgorithmArgs&... args);
164 template<
typename... MLAlgorithmArgs,
166 typename =
typename std::enable_if<Enabled>::type>
167 MLAlgorithm TrainModel(
const MatType& xs,
168 const PredictionsType& ys,
169 const WeightsType& weights,
170 const MLAlgorithmArgs&... args);
176 template<
typename... MLAlgorithmArgs,
178 typename =
typename std::enable_if<Enabled>::type,
180 MLAlgorithm TrainModel(
const MatType& xs,
181 const PredictionsType& ys,
182 const WeightsType& weights,
183 const MLAlgorithmArgs&... args);
189 template<
typename... MLAlgorithmArgs,
191 typename =
typename std::enable_if<Enabled>::type,
194 MLAlgorithm TrainModel(
const MatType& xs,
195 const PredictionsType& ys,
196 const WeightsType& weights,
197 const MLAlgorithmArgs&... args);
208 template<
bool ConstructableWithoutDatasetInfo,
209 typename... MLAlgorithmArgs,
211 typename std::enable_if<ConstructableWithoutDatasetInfo>::type>
212 MLAlgorithm TrainModel(
const MatType& xs,
213 const PredictionsType& ys,
214 const MLAlgorithmArgs&... args);
220 template<
bool ConstructableWithoutDatasetInfo,
221 typename... MLAlgorithmArgs,
223 typename std::enable_if<!ConstructableWithoutDatasetInfo>::type,
225 MLAlgorithm TrainModel(
const MatType& xs,
226 const PredictionsType& ys,
227 const MLAlgorithmArgs&... args);
234 #include "cv_base_impl.hpp"
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
Linear algebra utility functions, generally performed on matrices or vectors.
MLAlgorithm Train(const MatType &xs, const PredictionsType &ys, const MLAlgorithmArgs &... args)
Train MLAlgorithm with given data points, predictions, and hyperparameters depending on what CVBase c...
CVBase()
Assert that MLAlgorithm doesn't take any additional basic parameters like numClasses.
static void AssertWeightsConsistency(const MatType &xs, const WeightsType &weights)
Assert weighted learning is supported and there is the equal number of data points and weights...
An auxiliary class for cross-validation.
static void AssertDataConsistency(const MatType &xs, const PredictionsType &ys)
Assert there is the equal number of data points and predictions.