cv_base.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_CV_CV_BASE_HPP
13 #define MLPACK_CORE_CV_CV_BASE_HPP
14 
16 
17 namespace mlpack {
18 namespace cv {
19 
35 template<typename MLAlgorithm,
36  typename MatType,
37  typename PredictionsType,
38  typename WeightsType>
39 class CVBase
40 {
41  public:
43  using MIE =
45 
50  CVBase();
51 
57  CVBase(const size_t numClasses);
58 
66  CVBase(const data::DatasetInfo& datasetInfo,
67  const size_t numClasses);
68 
72  static void AssertDataConsistency(const MatType& xs,
73  const PredictionsType& ys);
74 
79  static void AssertWeightsConsistency(const MatType& xs,
80  const WeightsType& weights);
81 
86  template<typename... MLAlgorithmArgs>
87  MLAlgorithm Train(const MatType& xs,
88  const PredictionsType& ys,
89  const MLAlgorithmArgs&... args);
90 
95  template<typename... MLAlgorithmArgs>
96  MLAlgorithm Train(const MatType& xs,
97  const PredictionsType& ys,
98  const WeightsType& weights,
99  const MLAlgorithmArgs&... args);
100 
101  private:
102  static_assert(MIE::IsSupported,
103  "The given MLAlgorithm is not supported by MetaInfoExtractor");
104 
106  const data::DatasetInfo datasetInfo;
108  const bool isDatasetInfoPassed;
110  size_t numClasses;
111 
115  static void AssertSizeEquality(const MatType& xs,
116  const PredictionsType& ys);
117 
121  static void AssertWeightsSize(const MatType& xs,
122  const WeightsType& weights);
123 
128  template<typename... MLAlgorithmArgs,
129  bool Enabled = !MIE::TakesNumClasses,
130  typename = typename std::enable_if<Enabled>::type>
131  MLAlgorithm TrainModel(const MatType& xs,
132  const PredictionsType& ys,
133  const MLAlgorithmArgs&... args);
134 
139  template<typename... MLAlgorithmArgs,
141  typename = typename std::enable_if<Enabled>::type,
142  typename = void>
143  MLAlgorithm TrainModel(const MatType& xs,
144  const PredictionsType& ys,
145  const MLAlgorithmArgs&... args);
146 
151  template<typename... MLAlgorithmArgs,
153  typename = typename std::enable_if<Enabled>::type,
154  typename = void,
155  typename = void>
156  MLAlgorithm TrainModel(const MatType& xs,
157  const PredictionsType& ys,
158  const MLAlgorithmArgs&... args);
159 
164  template<typename... MLAlgorithmArgs,
165  bool Enabled = !MIE::TakesNumClasses,
166  typename = typename std::enable_if<Enabled>::type>
167  MLAlgorithm TrainModel(const MatType& xs,
168  const PredictionsType& ys,
169  const WeightsType& weights,
170  const MLAlgorithmArgs&... args);
171 
176  template<typename... MLAlgorithmArgs,
178  typename = typename std::enable_if<Enabled>::type,
179  typename = void>
180  MLAlgorithm TrainModel(const MatType& xs,
181  const PredictionsType& ys,
182  const WeightsType& weights,
183  const MLAlgorithmArgs&... args);
184 
189  template<typename... MLAlgorithmArgs,
191  typename = typename std::enable_if<Enabled>::type,
192  typename = void,
193  typename = void>
194  MLAlgorithm TrainModel(const MatType& xs,
195  const PredictionsType& ys,
196  const WeightsType& weights,
197  const MLAlgorithmArgs&... args);
198 
208  template<bool ConstructableWithoutDatasetInfo,
209  typename... MLAlgorithmArgs,
210  typename =
211  typename std::enable_if<ConstructableWithoutDatasetInfo>::type>
212  MLAlgorithm TrainModel(const MatType& xs,
213  const PredictionsType& ys,
214  const MLAlgorithmArgs&... args);
215 
220  template<bool ConstructableWithoutDatasetInfo,
221  typename... MLAlgorithmArgs,
222  typename =
223  typename std::enable_if<!ConstructableWithoutDatasetInfo>::type,
224  typename = void>
225  MLAlgorithm TrainModel(const MatType& xs,
226  const PredictionsType& ys,
227  const MLAlgorithmArgs&... args);
228 };
229 
230 } // namespace cv
231 } // namespace mlpack
232 
233 // Include implementation
234 #include "cv_base_impl.hpp"
235 
236 #endif
static const bool IsSupported
An indication whether PredictionsType has been identified (i.e. whether MLAlgorithm is supported by MetaInfoExtractor).
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
Linear algebra utility functions, generally performed on matrices or vectors.
MetaInfoExtractor is a tool for extracting meta information about a given machine learning algorithm...
MLAlgorithm Train(const MatType &xs, const PredictionsType &ys, const MLAlgorithmArgs &... args)
Train MLAlgorithm with given data points, predictions, and hyperparameters depending on which CVBase constructor has been called.
static const bool TakesNumClasses
An indication whether MLAlgorithm takes the numClasses (size_t) parameter.
CVBase()
Assert that MLAlgorithm doesn't take any additional basic parameters like numClasses.
static void AssertWeightsConsistency(const MatType &xs, const WeightsType &weights)
Assert that weighted learning is supported and that there is an equal number of data points and weights.
static const bool TakesDatasetInfo
An indication whether MLAlgorithm takes a data::DatasetInfo parameter.
An auxiliary class for cross-validation.
Definition: cv_base.hpp:39
static void AssertDataConsistency(const MatType &xs, const PredictionsType &ys)
Assert that there is an equal number of data points and predictions.