get_param.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_BINDINGS_CLI_GET_PARAM_HPP
13 #define MLPACK_BINDINGS_CLI_GET_PARAM_HPP
14 
#include <mlpack/prereqs.hpp>
#include "parameter_type.hpp"

#include <memory>
17 
18 namespace mlpack {
19 namespace bindings {
20 namespace cli {
21 
28 template<typename T>
30  util::ParamData& d,
31  const typename std::enable_if<!arma::is_arma_type<T>::value>::type* = 0,
32  const typename std::enable_if<!data::HasSerialize<T>::value>::type* = 0,
33  const typename std::enable_if<!std::is_same<T,
34  std::tuple<mlpack::data::DatasetInfo, arma::mat>>::value>::type* = 0)
35 {
36  // No mapping is needed, so just cast it directly.
37  return *boost::any_cast<T>(&d.value);
38 }
39 
45 template<typename T>
47  util::ParamData& d,
48  const typename std::enable_if<arma::is_arma_type<T>::value>::type* = 0)
49 {
50  // If the matrix is an input matrix, we have to load the matrix. 'value'
51  // contains the filename. It's possible we could load empty matrices many
52  // times, but I am not bothered by that---it shouldn't be something that
53  // happens.
54  typedef std::tuple<T, typename ParameterType<T>::type> TupleType;
55  TupleType& tuple = *boost::any_cast<TupleType>(&d.value);
56  const std::string& value = std::get<0>(std::get<1>(tuple));
57  T& matrix = std::get<0>(tuple);
58  size_t& n_rows = std::get<1>(std::get<1>(tuple));
59  size_t& n_cols = std::get<2>(std::get<1>(tuple));
60  if (d.input && !d.loaded)
61  {
62  // Call correct data::Load() function.
63  if (arma::is_Row<T>::value || arma::is_Col<T>::value)
64  data::Load(value, matrix, true);
65  else
66  data::Load(value, matrix, true, !d.noTranspose);
67  n_rows = matrix.n_rows;
68  n_cols = matrix.n_cols;
69  d.loaded = true;
70  }
71 
72  return matrix;
73 }
74 
80 template<typename T>
82  util::ParamData& d,
83  const typename std::enable_if<std::is_same<T,
84  std::tuple<mlpack::data::DatasetInfo, arma::mat>>::value>::type* = 0)
85 {
86  // If this is an input parameter, we need to load both the matrix and the
87  // dataset info.
88  typedef std::tuple<T, std::tuple<std::string, size_t, size_t>> TupleType;
89  TupleType* tuple = boost::any_cast<TupleType>(&d.value);
90  const std::string& value = std::get<0>(std::get<1>(*tuple));
91  T& t = std::get<0>(*tuple);
92  size_t& n_rows = std::get<1>(std::get<1>(*tuple));
93  size_t& n_cols = std::get<2>(std::get<1>(*tuple));
94  if (d.input && !d.loaded)
95  {
96  data::Load(value, std::get<1>(t), std::get<0>(t), true, !d.noTranspose);
97  n_rows = std::get<1>(t).n_rows;
98  n_cols = std::get<1>(t).n_cols;
99  d.loaded = true;
100  }
101 
102  return t;
103 }
104 
110 template<typename T>
112  util::ParamData& d,
113  const typename std::enable_if<!arma::is_arma_type<T>::value>::type* = 0,
114  const typename std::enable_if<data::HasSerialize<T>::value>::type* = 0)
115 {
116  // If the model is an input model, we have to load it from file. 'value'
117  // contains the filename.
118  typedef std::tuple<T*, std::string> TupleType;
119  TupleType* tuple = boost::any_cast<TupleType>(&d.value);
120  const std::string& value = std::get<1>(*tuple);
121  if (d.input && !d.loaded)
122  {
123  T* model = new T();
124  data::Load(value, "model", *model, true);
125  d.loaded = true;
126  std::get<0>(*tuple) = model;
127  }
128  return std::get<0>(*tuple);
129 }
130 
139 template<typename T>
140 void GetParam(util::ParamData& d, const void* /* input */, void* output)
141 {
142  // Cast to the correct type.
143  *((T**) output) = &GetParam<typename std::remove_pointer<T>::type>(d);
144 }
145 
146 } // namespace cli
147 } // namespace bindings
148 } // namespace mlpack
149 
150 #endif
boost::any value
The actual value that is held.
Definition: param_data.hpp:79
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
bool input
True if this option is an input option (otherwise, it is output).
Definition: param_data.hpp:73
T & GetParam(util::ParamData &d, const typename std::enable_if<!arma::is_arma_type< T >::value >::type *=0, const typename std::enable_if<!data::HasSerialize< T >::value >::type *=0, const typename std::enable_if<!std::is_same< T, std::tuple< mlpack::data::DatasetInfo, arma::mat >>::value >::type *=0)
This overload is called when no special handling is needed to retrieve the parameter's value; the stored value is returned directly.
Definition: get_param.hpp:29
This structure holds all of the information about a single parameter, including its value (held type-erased in the boost::any member `value`).
Definition: param_data.hpp:52
bool loaded
If this is an input parameter that needs extra loading, this indicates whether or not it has been loaded yet.
Definition: param_data.hpp:76
bool Load(const std::string &filename, arma::Mat< eT > &matrix, const bool fatal=false, const bool transpose=true, const arma::file_type inputLoadType=arma::auto_detect)
Loads a matrix from file, guessing the filetype from the extension.
bool noTranspose
True if this is a matrix that should not be transposed.
Definition: param_data.hpp:69