13 #ifndef MLPACK_CORE_DATA_SPLIT_DATA_HPP 14 #define MLPACK_CORE_DATA_SPLIT_DATA_HPP 26 template<
typename InputType>
30 const double testRatio,
31 const arma::uvec& order = arma::uvec())
33 const size_t testSize =
static_cast<size_t>(input.n_cols * testRatio);
34 const size_t trainSize = input.n_cols - testSize;
37 train.set_size(input.n_rows, trainSize);
38 test.set_size(input.n_rows, testSize);
41 if (!order.is_empty())
45 for (
size_t i = 0; i < trainSize; ++i)
46 train.col(i) = input.col(order(i));
48 if (trainSize < input.n_cols)
50 for (
size_t i = trainSize; i < input.n_cols; ++i)
51 test.col(i - trainSize) = input.col(order(i));
58 train = input.cols(0, trainSize - 1);
60 if (trainSize < input.n_cols)
61 test = input.cols(trainSize, input.n_cols - 1);
101 template<
typename T,
typename LabelsType,
104 const LabelsType& inputLabel,
105 arma::Mat<T>& trainData,
106 arma::Mat<T>& testData,
107 LabelsType& trainLabel,
108 LabelsType& testLabel,
109 const double testRatio,
110 const bool shuffleData =
true)
145 const bool typeCheck = (arma::is_Row<LabelsType>::value)
146 || (arma::is_Col<LabelsType>::value);
148 throw std::runtime_error(
"data::Split(): when stratified sampling is done, " 149 "labels must have type `arma::Row<>`!");
153 size_t trainSize = 0;
155 arma::uvec labelCounts;
156 arma::uvec testLabelCounts;
157 typename LabelsType::elem_type maxLabel = inputLabel.max();
159 labelCounts.zeros(maxLabel+1);
160 testLabelCounts.zeros(maxLabel+1);
162 for (
typename LabelsType::elem_type label : inputLabel)
163 ++labelCounts[label];
165 for (arma::uword labelCount : labelCounts)
167 testSize += floor(labelCount * testRatio);
168 trainSize += labelCount - floor(labelCount * testRatio);
171 trainData.set_size(input.n_rows, trainSize);
172 testData.set_size(input.n_rows, testSize);
173 trainLabel.set_size(inputLabel.n_rows, trainSize);
174 testLabel.set_size(inputLabel.n_rows, testSize);
178 arma::uvec order = arma::shuffle(
179 arma::linspace<arma::uvec>(0, input.n_cols - 1, input.n_cols));
181 for (arma::uword i : order)
183 typename LabelsType::elem_type label = inputLabel[i];
184 if (testLabelCounts[label] < floor(labelCounts[label] * testRatio))
186 testLabelCounts[label] += 1;
187 testData.col(testIdx) = input.col(i);
188 testLabel[testIdx] = inputLabel[i];
193 trainData.col(trainIdx) = input.col(i);
194 trainLabel[trainIdx] = inputLabel[i];
201 for (arma::uword i = 0; i < input.n_cols; i++)
203 typename LabelsType::elem_type label = inputLabel[i];
204 if (testLabelCounts[label] < floor(labelCounts[label] * testRatio))
206 testLabelCounts[label] += 1;
207 testData.col(testIdx) = input.col(i);
208 testLabel[testIdx] = inputLabel[i];
213 trainData.col(trainIdx) = input.col(i);
214 trainLabel[trainIdx] = inputLabel[i];
254 template<
typename T,
typename LabelsType,
256 void Split(
const arma::Mat<T>& input,
257 const LabelsType& inputLabel,
258 arma::Mat<T>& trainData,
259 arma::Mat<T>& testData,
260 LabelsType& trainLabel,
261 LabelsType& testLabel,
262 const double testRatio,
263 const bool shuffleData =
true)
268 arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(0,
269 input.n_cols - 1, input.n_cols));
270 SplitHelper(input, trainData, testData, testRatio, order);
271 SplitHelper(inputLabel, trainLabel, testLabel, testRatio, order);
275 SplitHelper(input, trainData, testData, testRatio);
276 SplitHelper(inputLabel, trainLabel, testLabel, testRatio);
304 void Split(
const arma::Mat<T>& input,
305 arma::Mat<T>& trainData,
306 arma::Mat<T>& testData,
307 const double testRatio,
308 const bool shuffleData =
true)
312 arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(0,
313 input.n_cols - 1, input.n_cols));
314 SplitHelper(input, trainData, testData, testRatio, order);
318 SplitHelper(input, trainData, testData, testRatio);
350 template<
typename T,
typename LabelsType,
352 std::tuple<arma::Mat<T>, arma::Mat<T>, LabelsType, LabelsType>
354 const LabelsType& inputLabel,
355 const double testRatio,
356 const bool shuffleData =
true,
357 const bool stratifyData =
false)
359 arma::Mat<T> trainData;
360 arma::Mat<T> testData;
361 LabelsType trainLabel;
362 LabelsType testLabel;
367 testLabel, testRatio, shuffleData);
371 Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
372 testRatio, shuffleData);
375 return std::make_tuple(std::move(trainData),
377 std::move(trainLabel),
378 std::move(testLabel));
400 std::tuple<arma::Mat<T>, arma::Mat<T>>
402 const double testRatio,
403 const bool shuffleData =
true)
405 arma::Mat<T> trainData;
406 arma::Mat<T> testData;
407 Split(input, trainData, testData, testRatio, shuffleData);
409 return std::make_tuple(std::move(trainData),
410 std::move(testData));
447 template <
typename FieldType,
typename T,
449 arma::is_Col<typename FieldType::object_type>::value ||
450 arma::is_Mat_only<typename FieldType::object_type>::value>>
452 const arma::field<T>& inputLabel,
453 FieldType& trainData,
454 arma::field<T>& trainLabel,
456 arma::field<T>& testLabel,
457 const double testRatio,
458 const bool shuffleData =
true)
463 arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(0,
464 input.n_cols - 1, input.n_cols));
465 SplitHelper(input, trainData, testData, testRatio, order);
466 SplitHelper(inputLabel, trainLabel, testLabel, testRatio, order);
470 SplitHelper(input, trainData, testData, testRatio);
471 SplitHelper(inputLabel, trainLabel, testLabel, testRatio);
503 template <
class FieldType,
505 arma::is_Col<typename FieldType::object_type>::value ||
506 arma::is_Mat_only<typename FieldType::object_type>::value>>
508 FieldType& trainData,
510 const double testRatio,
511 const bool shuffleData =
true)
515 arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(0,
516 input.n_cols - 1, input.n_cols));
517 SplitHelper(input, trainData, testData, testRatio, order);
521 SplitHelper(input, trainData, testData, testRatio);
552 template <
class FieldType,
typename T,
554 arma::is_Col<typename FieldType::object_type>::value ||
555 arma::is_Mat_only<typename FieldType::object_type>::value>>
556 std::tuple<FieldType, FieldType, arma::field<T>, arma::field<T>>
558 const arma::field<T>& inputLabel,
559 const double testRatio,
560 const bool shuffleData =
true)
564 arma::field<T> trainLabel;
565 arma::field<T> testLabel;
567 Split(input, inputLabel, trainData, trainLabel, testData, testLabel,
568 testRatio, shuffleData);
570 return std::make_tuple(std::move(trainData),
572 std::move(trainLabel),
573 std::move(testLabel));
599 template <
class FieldType,
601 arma::is_Col<typename FieldType::object_type>::value ||
602 arma::is_Mat_only<typename FieldType::object_type>::value>>
603 std::tuple<FieldType, FieldType>
605 const double testRatio,
606 const bool shuffleData =
true)
610 Split(input, trainData, testData, testRatio, shuffleData);
612 return std::make_tuple(std::move(trainData),
613 std::move(testData));
typename enable_if< B, T >::type enable_if_t
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void CheckSameSizes(const DataType &data, const LabelsType &label, const std::string &callerDescription, const std::string &addInfo="labels")
Check whether the given data points and labels have the same size.
void Split(const arma::Mat< T > &input, const LabelsType &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, LabelsType &trainLabel, LabelsType &testLabel, const double testRatio, const bool shuffleData=true)
Given an input dataset and labels, split into a training set and test set.
void StratifiedSplit(const arma::Mat< T > &input, const LabelsType &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, LabelsType &trainLabel, LabelsType &testLabel, const double testRatio, const bool shuffleData=true)
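Given an input dataset and labels, split into a training set and test set using stratified sampling, so that each label contributes the same ratio of its points to the test set.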
void SplitHelper(const InputType &input, InputType &train, InputType &test, const double testRatio, const arma::uvec &order=arma::uvec())
This helper function splits any input data into training and testing parts.