split_data.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_CORE_DATA_SPLIT_DATA_HPP
14 #define MLPACK_CORE_DATA_SPLIT_DATA_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace data {
20 
26 template<typename InputType>
27 void SplitHelper(const InputType& input,
28  InputType& train,
29  InputType& test,
30  const double testRatio,
31  const arma::uvec& order = arma::uvec())
32 {
33  const size_t testSize = static_cast<size_t>(input.n_cols * testRatio);
34  const size_t trainSize = input.n_cols - testSize;
35 
36  // Initialising the sizes of outputs if not already initialized.
37  train.set_size(input.n_rows, trainSize);
38  test.set_size(input.n_rows, testSize);
39 
40  // Shuffling and spliting simultaneously.
41  if (!order.is_empty())
42  {
43  if (trainSize > 0)
44  {
45  for (size_t i = 0; i < trainSize; ++i)
46  train.col(i) = input.col(order(i));
47  }
48  if (trainSize < input.n_cols)
49  {
50  for (size_t i = trainSize; i < input.n_cols; ++i)
51  test.col(i - trainSize) = input.col(order(i));
52  }
53  }
54  // Spliting only.
55  else
56  {
57  if (trainSize > 0)
58  train = input.cols(0, trainSize - 1);
59 
60  if (trainSize < input.n_cols)
61  test = input.cols(trainSize, input.n_cols - 1);
62  }
63 }
64 
101 template<typename T, typename LabelsType,
103 void StratifiedSplit(const arma::Mat<T>& input,
104  const LabelsType& inputLabel,
105  arma::Mat<T>& trainData,
106  arma::Mat<T>& testData,
107  LabelsType& trainLabel,
108  LabelsType& testLabel,
109  const double testRatio,
110  const bool shuffleData = true)
111 {
145  const bool typeCheck = (arma::is_Row<LabelsType>::value)
146  || (arma::is_Col<LabelsType>::value);
147  if (!typeCheck)
148  throw std::runtime_error("data::Split(): when stratified sampling is done, "
149  "labels must have type `arma::Row<>`!");
150  util::CheckSameSizes(input, inputLabel, "data::Split()");
151  size_t trainIdx = 0;
152  size_t testIdx = 0;
153  size_t trainSize = 0;
154  size_t testSize = 0;
155  arma::uvec labelCounts;
156  arma::uvec testLabelCounts;
157  typename LabelsType::elem_type maxLabel = inputLabel.max();
158 
159  labelCounts.zeros(maxLabel+1);
160  testLabelCounts.zeros(maxLabel+1);
161 
162  for (typename LabelsType::elem_type label : inputLabel)
163  ++labelCounts[label];
164 
165  for (arma::uword labelCount : labelCounts)
166  {
167  testSize += floor(labelCount * testRatio);
168  trainSize += labelCount - floor(labelCount * testRatio);
169  }
170 
171  trainData.set_size(input.n_rows, trainSize);
172  testData.set_size(input.n_rows, testSize);
173  trainLabel.set_size(inputLabel.n_rows, trainSize);
174  testLabel.set_size(inputLabel.n_rows, testSize);
175 
176  if (shuffleData)
177  {
178  arma::uvec order = arma::shuffle(
179  arma::linspace<arma::uvec>(0, input.n_cols - 1, input.n_cols));
180 
181  for (arma::uword i : order)
182  {
183  typename LabelsType::elem_type label = inputLabel[i];
184  if (testLabelCounts[label] < floor(labelCounts[label] * testRatio))
185  {
186  testLabelCounts[label] += 1;
187  testData.col(testIdx) = input.col(i);
188  testLabel[testIdx] = inputLabel[i];
189  testIdx += 1;
190  }
191  else
192  {
193  trainData.col(trainIdx) = input.col(i);
194  trainLabel[trainIdx] = inputLabel[i];
195  trainIdx += 1;
196  }
197  }
198  }
199  else
200  {
201  for (arma::uword i = 0; i < input.n_cols; i++)
202  {
203  typename LabelsType::elem_type label = inputLabel[i];
204  if (testLabelCounts[label] < floor(labelCounts[label] * testRatio))
205  {
206  testLabelCounts[label] += 1;
207  testData.col(testIdx) = input.col(i);
208  testLabel[testIdx] = inputLabel[i];
209  testIdx += 1;
210  }
211  else
212  {
213  trainData.col(trainIdx) = input.col(i);
214  trainLabel[trainIdx] = inputLabel[i];
215  trainIdx += 1;
216  }
217  }
218  }
219 }
220 
254 template<typename T, typename LabelsType,
256 void Split(const arma::Mat<T>& input,
257  const LabelsType& inputLabel,
258  arma::Mat<T>& trainData,
259  arma::Mat<T>& testData,
260  LabelsType& trainLabel,
261  LabelsType& testLabel,
262  const double testRatio,
263  const bool shuffleData = true)
264 {
265  util::CheckSameSizes(input, inputLabel, "data::Split()");
266  if (shuffleData)
267  {
268  arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(0,
269  input.n_cols - 1, input.n_cols));
270  SplitHelper(input, trainData, testData, testRatio, order);
271  SplitHelper(inputLabel, trainLabel, testLabel, testRatio, order);
272  }
273  else
274  {
275  SplitHelper(input, trainData, testData, testRatio);
276  SplitHelper(inputLabel, trainLabel, testLabel, testRatio);
277  }
278 }
279 
303 template<typename T>
304 void Split(const arma::Mat<T>& input,
305  arma::Mat<T>& trainData,
306  arma::Mat<T>& testData,
307  const double testRatio,
308  const bool shuffleData = true)
309 {
310  if (shuffleData)
311  {
312  arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(0,
313  input.n_cols - 1, input.n_cols));
314  SplitHelper(input, trainData, testData, testRatio, order);
315  }
316  else
317  {
318  SplitHelper(input, trainData, testData, testRatio);
319  }
320 }
321 
350 template<typename T, typename LabelsType,
352 std::tuple<arma::Mat<T>, arma::Mat<T>, LabelsType, LabelsType>
353 Split(const arma::Mat<T>& input,
354  const LabelsType& inputLabel,
355  const double testRatio,
356  const bool shuffleData = true,
357  const bool stratifyData = false)
358 {
359  arma::Mat<T> trainData;
360  arma::Mat<T> testData;
361  LabelsType trainLabel;
362  LabelsType testLabel;
363 
364  if (stratifyData)
365  {
366  StratifiedSplit(input, inputLabel, trainData, testData, trainLabel,
367  testLabel, testRatio, shuffleData);
368  }
369  else
370  {
371  Split(input, inputLabel, trainData, testData, trainLabel, testLabel,
372  testRatio, shuffleData);
373  }
374 
375  return std::make_tuple(std::move(trainData),
376  std::move(testData),
377  std::move(trainLabel),
378  std::move(testLabel));
379 }
380 
399 template<typename T>
400 std::tuple<arma::Mat<T>, arma::Mat<T>>
401 Split(const arma::Mat<T>& input,
402  const double testRatio,
403  const bool shuffleData = true)
404 {
405  arma::Mat<T> trainData;
406  arma::Mat<T> testData;
407  Split(input, trainData, testData, testRatio, shuffleData);
408 
409  return std::make_tuple(std::move(trainData),
410  std::move(testData));
411 }
412 
447 template <typename FieldType, typename T,
448  typename = std::enable_if_t<
449  arma::is_Col<typename FieldType::object_type>::value ||
450  arma::is_Mat_only<typename FieldType::object_type>::value>>
451 void Split(const FieldType& input,
452  const arma::field<T>& inputLabel,
453  FieldType& trainData,
454  arma::field<T>& trainLabel,
455  FieldType& testData,
456  arma::field<T>& testLabel,
457  const double testRatio,
458  const bool shuffleData = true)
459 {
460  util::CheckSameSizes(input, inputLabel, "data::Split()");
461  if (shuffleData)
462  {
463  arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(0,
464  input.n_cols - 1, input.n_cols));
465  SplitHelper(input, trainData, testData, testRatio, order);
466  SplitHelper(inputLabel, trainLabel, testLabel, testRatio, order);
467  }
468  else
469  {
470  SplitHelper(input, trainData, testData, testRatio);
471  SplitHelper(inputLabel, trainLabel, testLabel, testRatio);
472  }
473 }
474 
503 template <class FieldType,
504  class = std::enable_if_t<
505  arma::is_Col<typename FieldType::object_type>::value ||
506  arma::is_Mat_only<typename FieldType::object_type>::value>>
507 void Split(const FieldType& input,
508  FieldType& trainData,
509  FieldType& testData,
510  const double testRatio,
511  const bool shuffleData = true)
512 {
513  if (shuffleData)
514  {
515  arma::uvec order = arma::shuffle(arma::linspace<arma::uvec>(0,
516  input.n_cols - 1, input.n_cols));
517  SplitHelper(input, trainData, testData, testRatio, order);
518  }
519  else
520  {
521  SplitHelper(input, trainData, testData, testRatio);
522  }
523 }
524 
552 template <class FieldType, typename T,
553  class = std::enable_if_t<
554  arma::is_Col<typename FieldType::object_type>::value ||
555  arma::is_Mat_only<typename FieldType::object_type>::value>>
556 std::tuple<FieldType, FieldType, arma::field<T>, arma::field<T>>
557 Split(const FieldType& input,
558  const arma::field<T>& inputLabel,
559  const double testRatio,
560  const bool shuffleData = true)
561 {
562  FieldType trainData;
563  FieldType testData;
564  arma::field<T> trainLabel;
565  arma::field<T> testLabel;
566 
567  Split(input, inputLabel, trainData, trainLabel, testData, testLabel,
568  testRatio, shuffleData);
569 
570  return std::make_tuple(std::move(trainData),
571  std::move(testData),
572  std::move(trainLabel),
573  std::move(testLabel));
574 }
575 
599 template <class FieldType,
600  class = std::enable_if_t<
601  arma::is_Col<typename FieldType::object_type>::value ||
602  arma::is_Mat_only<typename FieldType::object_type>::value>>
603 std::tuple<FieldType, FieldType>
604 Split(const FieldType& input,
605  const double testRatio,
606  const bool shuffleData = true)
607 {
608  FieldType trainData;
609  FieldType testData;
610  Split(input, trainData, testData, testRatio, shuffleData);
611 
612  return std::make_tuple(std::move(trainData),
613  std::move(testData));
614 }
615 
616 } // namespace data
617 } // namespace mlpack
618 
619 #endif
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:70
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void CheckSameSizes(const DataType &data, const LabelsType &label, const std::string &callerDescription, const std::string &addInfo="labels")
Check whether the given data points and labels have the same size.
Definition: size_checks.hpp:31
void Split(const arma::Mat< T > &input, const LabelsType &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, LabelsType &trainLabel, LabelsType &testLabel, const double testRatio, const bool shuffleData=true)
Given an input dataset and labels, split into a training set and test set.
Definition: split_data.hpp:256
void StratifiedSplit(const arma::Mat< T > &input, const LabelsType &inputLabel, arma::Mat< T > &trainData, arma::Mat< T > &testData, LabelsType &trainLabel, LabelsType &testLabel, const double testRatio, const bool shuffleData=true)
Given an input dataset and labels, stratify into a training set and test set.
Definition: split_data.hpp:103
void SplitHelper(const InputType &input, InputType &train, InputType &test, const double testRatio, const arma::uvec &order=arma::uvec())
This helper function splits any input data into training and testing parts.
Definition: split_data.hpp:27