load_csv.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_DATA_LOAD_CSV_HPP
13 #define MLPACK_CORE_DATA_LOAD_CSV_HPP
14 
15 #include <boost/spirit/include/qi.hpp>
16 #include <boost/algorithm/string/trim.hpp>
17 
18 #include <mlpack/core.hpp>
19 #include <mlpack/core/util/log.hpp>
20 
21 #include <set>
22 #include <string>
23 
24 #include "extension.hpp"
25 #include "format.hpp"
26 #include "dataset_mapper.hpp"
27 
28 namespace mlpack {
29 namespace data {
30 
36 class LoadCSV
37 {
38  public:
43  LoadCSV(const std::string& file);
44 
54  template<typename T, typename PolicyType>
55  void Load(arma::Mat<T> &inout,
57  const bool transpose = true)
58  {
59  CheckOpen();
60 
61  if (transpose)
62  TransposeParse(inout, infoSet);
63  else
64  NonTransposeParse(inout, infoSet);
65  }
66 
77  template<typename T, typename MapPolicy>
78  void GetMatrixSize(size_t& rows, size_t& cols, DatasetMapper<MapPolicy>& info)
79  {
80  using namespace boost::spirit;
81 
82  // Take a pass through the file. If the DatasetMapper policy requires it,
83  // we will pass everything string through MapString(). This might be useful
84  // if, e.g., the MapPolicy needs to find which dimensions are numeric or
85  // categorical.
86 
87  // Reset to the start of the file.
88  inFile.clear();
89  inFile.seekg(0, std::ios::beg);
90  rows = 0;
91  cols = 0;
92 
93  // First, count the number of rows in the file (this is the dimensionality).
94  std::string line;
95  while (std::getline(inFile, line))
96  {
97  ++rows;
98  }
99 
100  // Reset the DatasetInfo object, if needed.
101  if (info.Dimensionality() == 0)
102  {
103  info.SetDimensionality(rows);
104  }
105  else if (info.Dimensionality() != rows)
106  {
107  std::ostringstream oss;
108  oss << "data::LoadCSV(): given DatasetInfo has dimensionality "
109  << info.Dimensionality() << ", but data has dimensionality "
110  << rows;
111  throw std::invalid_argument(oss.str());
112  }
113 
114  // Now, jump back to the beginning of the file.
115  inFile.clear();
116  inFile.seekg(0, std::ios::beg);
117  rows = 0;
118 
119  while (std::getline(inFile, line))
120  {
121  ++rows;
122  // Remove whitespace from either side.
123  boost::trim(line);
124 
125  if (rows == 1)
126  {
127  // Extract the number of columns.
128  auto findColSize = [&cols](iter_type) { ++cols; };
129  qi::parse(line.begin(), line.end(),
130  stringRule[findColSize] % delimiterRule);
131  }
132 
133  // I guess this is technically a second pass, but that's ok... still the
134  // same idea...
135  if (MapPolicy::NeedsFirstPass)
136  {
137  // In this case we must pass everything we parse to the MapPolicy.
138  auto firstPassMap = [&](const iter_type& iter)
139  {
140  std::string str(iter.begin(), iter.end());
141  boost::trim(str);
142 
143  info.template MapFirstPass<T>(std::move(str), rows - 1);
144  };
145 
146  // Now parse the line.
147  qi::parse(line.begin(), line.end(),
148  stringRule[firstPassMap] % delimiterRule);
149  }
150  }
151  }
152 
163  template<typename T, typename MapPolicy>
164  void GetTransposeMatrixSize(size_t& rows,
165  size_t& cols,
167  {
168  using namespace boost::spirit;
169 
170  // Take a pass through the file. If the DatasetMapper policy requires it,
171  // we will pass everything string through MapString(). This might be useful
172  // if, e.g., the MapPolicy needs to find which dimensions are numeric or
173  // categorical.
174 
175  // Reset to the start of the file.
176  inFile.clear();
177  inFile.seekg(0, std::ios::beg);
178  rows = 0;
179  cols = 0;
180 
181  std::string line;
182  while (std::getline(inFile, line))
183  {
184  ++cols;
185  // Remove whitespace from either side.
186  boost::trim(line);
187 
188  if (cols == 1)
189  {
190  // Extract the number of dimensions.
191  auto findRowSize = [&rows](iter_type) { ++rows; };
192  qi::parse(line.begin(), line.end(),
193  stringRule[findRowSize] % delimiterRule);
194 
195  // Reset the DatasetInfo object, if needed.
196  if (info.Dimensionality() == 0)
197  {
198  info.SetDimensionality(rows);
199  }
200  else if (info.Dimensionality() != rows)
201  {
202  std::ostringstream oss;
203  oss << "data::LoadCSV(): given DatasetInfo has dimensionality "
204  << info.Dimensionality() << ", but data has dimensionality "
205  << rows;
206  throw std::invalid_argument(oss.str());
207  }
208  }
209 
210  // If we need to do a first pass for the DatasetMapper, do it.
211  if (MapPolicy::NeedsFirstPass)
212  {
213  size_t dim = 0;
214 
215  // In this case we must pass everything we parse to the MapPolicy.
216  auto firstPassMap = [&](const iter_type& iter)
217  {
218  std::string str(iter.begin(), iter.end());
219  boost::trim(str);
220 
221  info.template MapFirstPass<T>(std::move(str), dim++);
222  };
223 
224  // Now parse the line.
225  qi::parse(line.begin(), line.end(),
226  stringRule[firstPassMap] % delimiterRule);
227  }
228  }
229  }
230 
231  private:
232  using iter_type = boost::iterator_range<std::string::iterator>;
233 
238  void CheckOpen();
239 
246  template<typename T, typename PolicyType>
247  void NonTransposeParse(arma::Mat<T>& inout,
248  DatasetMapper<PolicyType>& infoSet)
249  {
250  using namespace boost::spirit;
251 
252  // Get the size of the matrix.
253  size_t rows, cols;
254  GetMatrixSize<T>(rows, cols, infoSet);
255 
256  // Set up output matrix.
257  inout.set_size(rows, cols);
258  size_t row = 0;
259  size_t col = 0;
260 
261  // Reset file position.
262  std::string line;
263  inFile.clear();
264  inFile.seekg(0, std::ios::beg);
265 
266  auto setCharClass = [&](iter_type const &iter)
267  {
268  std::string str(iter.begin(), iter.end());
269  if (str == "\t")
270  {
271  str.clear();
272  }
273  boost::trim(str);
274 
275  inout(row, col++) = infoSet.template MapString<T>(std::move(str), row);
276  };
277 
278  while (std::getline(inFile, line))
279  {
280  // Remove whitespace from either side.
281  boost::trim(line);
282 
283  // Parse the numbers from a line (ex: 1,2,3,4); if the parser finds a
284  // number it will execute the setNum function.
285  const bool canParse = qi::parse(line.begin(), line.end(),
286  stringRule[setCharClass] % delimiterRule);
287 
288  // Make sure we got the right number of rows.
289  if (col != cols)
290  {
291  std::ostringstream oss;
292  oss << "LoadCSV::NonTransposeParse(): wrong number of dimensions ("
293  << col << ") on line " << row << "; should be " << cols
294  << " dimensions.";
295  throw std::runtime_error(oss.str());
296  }
297 
298  if (!canParse)
299  {
300  std::ostringstream oss;
301  oss << "LoadCSV::NonTransposeParse(): parsing error on line " << col
302  << "!";
303  throw std::runtime_error(oss.str());
304  }
305 
306  ++row; col = 0;
307  }
308  }
309 
316  template<typename T, typename PolicyType>
317  void TransposeParse(arma::Mat<T>& inout, DatasetMapper<PolicyType>& infoSet)
318  {
319  using namespace boost::spirit;
320 
321  // Get matrix size. This also initializes infoSet correctly.
322  size_t rows, cols;
323  GetTransposeMatrixSize<T>(rows, cols, infoSet);
324 
325  // Set the matrix size.
326  inout.set_size(rows, cols);
327 
328  // Initialize auxiliary variables.
329  size_t row = 0;
330  size_t col = 0;
331  std::string line;
332  inFile.clear();
333  inFile.seekg(0, std::ios::beg);
334 
339  auto parseString = [&](iter_type const &iter)
340  {
341  // All parsed values must be mapped.
342  std::string str(iter.begin(), iter.end());
343  boost::trim(str);
344 
345  inout(row, col) = infoSet.template MapString<T>(std::move(str), row);
346  ++row;
347  };
348 
349  while (std::getline(inFile, line))
350  {
351  // Remove whitespace from either side.
352  boost::trim(line);
353 
354  // Reset the row we are looking at. (Remember this is transposed.)
355  row = 0;
356 
357  // Now use boost::spirit to parse the characters of the line;
358  // parseString() will be called when a token is detected.
359  const bool canParse = qi::parse(line.begin(), line.end(),
360  stringRule[parseString] % delimiterRule);
361 
362  // Make sure we got the right number of rows.
363  if (row != rows)
364  {
365  std::ostringstream oss;
366  oss << "LoadCSV::TransposeParse(): wrong number of dimensions (" << row
367  << ") on line " << col << "; should be " << rows << " dimensions.";
368  throw std::runtime_error(oss.str());
369  }
370 
371  if (!canParse)
372  {
373  std::ostringstream oss;
374  oss << "LoadCSV::TransposeParse(): parsing error on line " << col
375  << "!";
376  throw std::runtime_error(oss.str());
377  }
378 
379  // Increment the column index.
380  ++col;
381  }
382  }
383 
385  boost::spirit::qi::rule<std::string::iterator, iter_type()> stringRule;
387  boost::spirit::qi::rule<std::string::iterator, iter_type()> delimiterRule;
388 
390  std::string extension;
392  std::string filename;
394  std::ifstream inFile;
395 };
396 
397 } // namespace data
398 } // namespace mlpack
399 
400 #endif
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
Load the csv file. This class uses boost::spirit to implement the parser; please refer to the following lin...
Definition: load_csv.hpp:36
void Load(arma::Mat< T > &inout, DatasetMapper< PolicyType > &infoSet, const bool transpose=true)
Load the file into the given matrix with the given DatasetMapper object.
Definition: load_csv.hpp:55
Linear algebra utility functions, generally performed on matrices or vectors.
void GetTransposeMatrixSize(size_t &rows, size_t &cols, DatasetMapper< MapPolicy > &info)
Peek at the file to determine the number of rows and columns in the matrix, assuming a transposed mat...
Definition: load_csv.hpp:164
LoadCSV(const std::string &file)
Construct the LoadCSV object on the given file.
void GetMatrixSize(size_t &rows, size_t &cols, DatasetMapper< MapPolicy > &info)
Peek at the file to determine the number of rows and columns in the matrix, assuming a non-transposed...
Definition: load_csv.hpp:78
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen docu...
void SetDimensionality(const size_t dimensionality)
Set the dimensionality of an existing DatasetMapper object.
size_t Dimensionality() const
Get the dimensionality of the DatasetMapper object (that is, how many dimensions it has information f...