// NOTE(review): this text is a Doxygen source-listing extraction. The original
// file's line numbers are fused into each line, and several source lines are
// missing. For example, listing lines 28 and 32-35 are absent; presumably they
// hold the function name, the `inputPoints` parameter, the enable_if guards and
// the opening brace. Do not treat this chunk as compilable as-is.
//
// ShuffleData overload that shuffles with column extraction. This is
// presumably the dense one (not arma::SpMat, not arma::Cube); confirm against
// the missing enable_if parameters. It shuffles the columns of inputPoints and
// inputLabels with one shared random permutation, so every point keeps its
// label.
12 #ifndef MLPACK_CORE_MATH_SHUFFLE_DATA_HPP 13 #define MLPACK_CORE_MATH_SHUFFLE_DATA_HPP 27 template<
typename MatType,
typename LabelsType>
29 const LabelsType& inputLabels,
30 MatType& outputPoints,
31 LabelsType& outputLabels,
// Build 0, 1, ..., n_cols - 1 and shuffle it into a random column ordering.
// NOTE(review): n_cols is an unsigned arma::uword, so n_cols - 1 wraps when
// there are no columns. This is presumably harmless because linspace is asked
// for zero points; confirm. arma::randperm(inputPoints.n_cols) would express
// the same permutation more directly.
36 arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
37 inputPoints.n_cols - 1, inputPoints.n_cols));
// Gather: output column i is input column ordering[i]. Using the same ordering
// for the labels keeps point/label pairs aligned.
39 outputPoints = inputPoints.cols(ordering);
40 outputLabels = inputLabels.cols(ordering);
// ShuffleData overload that rebuilds the output from coordinate lists. It is
// presumably selected for arma::SpMat inputs by enable_if parameters on
// listing lines not shown here; confirm. Columns are not extracted. Instead,
// every nonzero stored in input column c is placed in output column
// ordering[c] (a scatter). The labels are scattered with the same ordering, so
// points and labels stay paired.
// NOTE(review): listing lines 51 and 55-58 (function name, first parameter,
// guards, opening brace) are missing from this extraction.
50 template<
typename MatType,
typename LabelsType>
52 const LabelsType& inputLabels,
53 MatType& outputPoints,
54 LabelsType& outputLabels,
// Random permutation of the column indices 0 .. n_cols - 1.
59 arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
60 inputPoints.n_cols - 1, inputPoints.n_cols));
// Coordinate form of the shuffled matrix:
// - row 0 of `locations` holds the row indices;
// - row 1 holds the remapped column indices;
// - `values` holds the matching nonzero values.
// The declaration of `index` and the advance of `it`/`index` are on listing
// lines that are not shown here.
63 arma::umat locations(2, inputPoints.n_nonzero);
64 arma::Col<typename MatType::elem_type> values(inputPoints.n_nonzero);
65 typename MatType::const_iterator it = inputPoints.begin();
67 while (it != inputPoints.end())
69 locations(0, index) = it.row();
70 locations(1, index) = ordering[it.col()];
71 values(index) = (*it);
// If either input is also its output, build the results in temporaries. The
// input is then not overwritten while it is still being read, and the
// temporaries are moved into the outputs afterwards.
76 if (&inputPoints == &outputPoints || &inputLabels == &outputLabels)
// In Armadillo's batch-insertion constructor, the trailing `true` is the
// sort_locations flag. It is needed here because remapping the columns
// breaks the column-major order in which the iterator produced the locations.
78 MatType newOutputPoints(locations, values, inputPoints.n_rows,
79 inputPoints.n_cols,
true);
// Label c goes to column ordering[c], the same column that received point c.
// NOTE(review): sizing from n_elem and then calling cols() presumes LabelsType
// is a vector type (e.g. arma::Row); confirm.
80 LabelsType newOutputLabels(inputLabels.n_elem);
81 newOutputLabels.cols(ordering) = inputLabels;
83 outputPoints = std::move(newOutputPoints);
84 outputLabels = std::move(newOutputLabels);
// Presumably the non-aliased branch (its `else` is on a listing line that is
// not shown). Here the outputs are written directly.
88 outputPoints = MatType(locations, values, inputPoints.n_rows,
89 inputPoints.n_cols,
true);
90 outputLabels.set_size(inputLabels.n_elem);
91 outputLabels.cols(ordering) = inputLabels;
// ShuffleData overload that shuffles through tube() views. It is presumably
// selected for arma::Cube inputs by enable_if parameters on listing lines not
// shown here; confirm. Column i, taken across all rows and all slices, is
// moved to column ordering[i]. This happens in both the points cube and the
// labels cube.
// NOTE(review): listing lines 103 and 107-111 (function name, first
// parameter, guards, opening brace) are missing from this extraction.
102 template<
typename MatType,
typename LabelsType>
104 const LabelsType& inputLabels,
105 MatType& outputPoints,
106 LabelsType& outputLabels,
// Random permutation of the column indices 0 .. n_cols - 1.
112 arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
113 inputPoints.n_cols - 1, inputPoints.n_cols));
// Writes go through pointers so that an output aliasing its input can be
// redirected to a heap temporary. The input is then not overwritten while it
// is still being read.
// NOTE(review): the raw new/delete is not exception-safe. If set_size() or a
// tube assignment below throws, the temporaries leak. A local MatType /
// LabelsType temporary (or std::unique_ptr) would release them automatically.
117 MatType* outputPointsPtr = &outputPoints;
118 LabelsType* outputLabelsPtr = &outputLabels;
119 if (&inputPoints == &outputPoints)
120 outputPointsPtr =
new MatType();
121 if (&inputLabels == &outputLabels)
122 outputLabelsPtr =
new LabelsType();
// Each output gets the same dimensions as its input.
124 outputPointsPtr->set_size(inputPoints.n_rows, inputPoints.n_cols,
125 inputPoints.n_slices);
126 outputLabelsPtr->set_size(inputLabels.n_rows, inputLabels.n_cols,
127 inputLabels.n_slices);
// tube(0, c, n_rows - 1, c) is column c across all rows and all slices.
// NOTE(review): the loop bound comes from inputPoints' column count, and the
// same bound indexes the labels. So inputLabels is assumed to have as many
// columns as inputPoints; no check for this is visible here.
128 for (
size_t i = 0; i < ordering.n_elem; ++i)
130 outputPointsPtr->tube(0, ordering[i], outputPointsPtr->n_rows - 1,
131 ordering[i]) = inputPoints.tube(0, i, inputPoints.n_rows - 1, i);
132 outputLabelsPtr->tube(0, ordering[i], outputLabelsPtr->n_rows - 1,
133 ordering[i]) = inputLabels.tube(0, i, inputLabels.n_rows - 1, i);
// When a temporary was used, move it into the caller's output and free it.
137 if (&inputPoints == &outputPoints)
139 outputPoints = std::move(*outputPointsPtr);
140 delete outputPointsPtr;
143 if (&inputLabels == &outputLabels)
145 outputLabels = std::move(*outputLabelsPtr);
146 delete outputLabelsPtr;
// Weighted ShuffleData overload that shuffles with column extraction. This is
// presumably the dense one; confirm against the enable_if parameters, which
// are on listing lines not shown here. It shuffles the columns of inputPoints,
// inputLabels and inputWeights with one shared random permutation, so each
// point keeps its label and its weight.
// NOTE(review): listing lines 160 and 166-169 (function name, first
// parameter, guards, opening brace) are missing from this extraction.
159 template<
typename MatType,
typename LabelsType,
typename WeightsType>
161 const LabelsType& inputLabels,
162 const WeightsType& inputWeights,
163 MatType& outputPoints,
164 LabelsType& outputLabels,
165 WeightsType& outputWeights,
// Random permutation of the column indices 0 .. n_cols - 1.
170 arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
171 inputPoints.n_cols - 1, inputPoints.n_cols));
// Gather: output column i is input column ordering[i], for all three outputs.
173 outputPoints = inputPoints.cols(ordering);
174 outputLabels = inputLabels.cols(ordering);
175 outputWeights = inputWeights.cols(ordering);
// Weighted ShuffleData overload that rebuilds the output from coordinate
// lists. It is presumably selected for arma::SpMat inputs by enable_if
// parameters on listing lines not shown here; confirm. Every nonzero stored in
// input column c is placed in output column ordering[c]. Labels and weights
// are scattered with the same ordering, so each point keeps its label and its
// weight.
// NOTE(review): listing lines 188 and 194-197 (function name, first
// parameter, guards, opening brace) are missing from this extraction.
187 template<
typename MatType,
typename LabelsType,
typename WeightsType>
189 const LabelsType& inputLabels,
190 const WeightsType& inputWeights,
191 MatType& outputPoints,
192 LabelsType& outputLabels,
193 WeightsType& outputWeights,
// Random permutation of the column indices 0 .. n_cols - 1.
198 arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
199 inputPoints.n_cols - 1, inputPoints.n_cols));
// Coordinate form of the shuffled matrix:
// - row 0 of `locations` holds the row indices;
// - row 1 holds the remapped column indices;
// - `values` holds the matching nonzero values.
// The declaration of `index` and the advance of `it`/`index` are on listing
// lines that are not shown here.
202 arma::umat locations(2, inputPoints.n_nonzero);
203 arma::Col<typename MatType::elem_type> values(inputPoints.n_nonzero);
204 typename MatType::const_iterator it = inputPoints.begin();
206 while (it != inputPoints.end())
208 locations(0, index) = it.row();
209 locations(1, index) = ordering[it.col()];
210 values(index) = (*it);
// If any input is also its output, build all results in temporaries. The
// input is then not overwritten while it is still being read, and the
// temporaries are moved into the outputs afterwards.
215 if (&inputPoints == &outputPoints || &inputLabels == &outputLabels ||
216 &inputWeights == &outputWeights)
// In Armadillo's batch-insertion constructor, the trailing `true` is the
// sort_locations flag. It is needed here because remapping the columns
// breaks the column-major order in which the iterator produced the locations.
218 MatType newOutputPoints(locations, values, inputPoints.n_rows,
219 inputPoints.n_cols,
true);
// Entry c of the labels and of the weights goes to column ordering[c], the
// same column that received point c.
// NOTE(review): sizing from n_elem and then calling cols() presumes LabelsType
// and WeightsType are vector types (e.g. arma::Row); confirm.
220 LabelsType newOutputLabels(inputLabels.n_elem);
221 WeightsType newOutputWeights(inputWeights.n_elem);
222 newOutputLabels.cols(ordering) = inputLabels;
223 newOutputWeights.cols(ordering) = inputWeights;
225 outputPoints = std::move(newOutputPoints);
226 outputLabels = std::move(newOutputLabels);
227 outputWeights = std::move(newOutputWeights);
// Presumably the non-aliased branch (its `else` is on a listing line that is
// not shown). Here the outputs are written directly.
231 outputPoints = MatType(locations, values, inputPoints.n_rows,
232 inputPoints.n_cols,
true);
233 outputLabels.set_size(inputLabels.n_elem);
234 outputLabels.cols(ordering) = inputLabels;
235 outputWeights.set_size(inputWeights.n_elem);
236 outputWeights.cols(ordering) = inputWeights;
typename enable_if< B, T >::type enable_if_t
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void ShuffleData(const MatType &inputPoints, const LabelsType &inputLabels, MatType &outputPoints, LabelsType &outputLabels, const std::enable_if_t<!arma::is_SpMat< MatType >::value > *=0, const std::enable_if_t<!arma::is_Cube< MatType >::value > *=0)
Shuffle a dataset and associated labels (or responses).