print_class_defn.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_BINDINGS_PYTHON_PRINT_CLASS_DEFN_HPP
13 #define MLPACK_BINDINGS_PYTHON_PRINT_CLASS_DEFN_HPP
14 
15 #include "strip_type.hpp"
16 
17 namespace mlpack {
18 namespace bindings {
19 namespace python {
20 
25 template<typename T>
27  util::ParamData& /* d */,
28  const typename std::enable_if<!arma::is_arma_type<T>::value>::type* = 0,
29  const typename std::enable_if<!data::HasSerialize<T>::value>::type* = 0)
30 {
31  // Do nothing.
32 }
33 
37 template<typename T>
39  util::ParamData& /* d */,
40  const typename std::enable_if<arma::is_arma_type<T>::value>::type* = 0)
41 {
42  // Do nothing.
43 }
44 
48 template<typename T>
50  util::ParamData& d,
51  const typename std::enable_if<!arma::is_arma_type<T>::value>::type* = 0,
52  const typename std::enable_if<data::HasSerialize<T>::value>::type* = 0)
53 {
54  // First, we have to parse the type. If we have something like, e.g.,
55  // 'LogisticRegression<>', we must convert this to 'LogisticRegression[].'
56  std::string strippedType, printedType, defaultsType;
57  StripType(d.cppType, strippedType, printedType, defaultsType);
58 
99  std::cout << "cdef class " << strippedType << "Type:" << std::endl;
100  std::cout << " cdef " << printedType << "* modelptr" << std::endl;
101  std::cout << " cdef public dict scrubbed_params" << std::endl;
102  std::cout << std::endl;
103  std::cout << " def __cinit__(self):" << std::endl;
104  std::cout << " self.modelptr = new " << printedType << "()" << std::endl;
105  std::cout << " self.scrubbed_params = dict()" << std::endl;
106  std::cout << std::endl;
107  std::cout << " def __dealloc__(self):" << std::endl;
108  std::cout << " del self.modelptr" << std::endl;
109  std::cout << std::endl;
110  std::cout << " def __getstate__(self):" << std::endl;
111  std::cout << " return SerializeOut(self.modelptr, \"" << printedType
112  << "\")" << std::endl;
113  std::cout << std::endl;
114  std::cout << " def __setstate__(self, state):" << std::endl;
115  std::cout << " SerializeIn(self.modelptr, state, \"" << printedType
116  << "\")" << std::endl;
117  std::cout << std::endl;
118  std::cout << " def __reduce_ex__(self, version):" << std::endl;
119  std::cout << " return (self.__class__, (), self.__getstate__())"
120  << std::endl;
121  std::cout << std::endl;
122  std::cout << " def _get_cpp_params(self):" << std::endl;
123  std::cout << " return SerializeOutJSON(self.modelptr, \"" << printedType
124  << "\")" << std::endl;
125  std::cout << std::endl;
126  std::cout << " def _set_cpp_params(self, state):" << std::endl;
127  std::cout << " SerializeInJSON(self.modelptr, state, \"" << printedType
128  << "\")" << std::endl;
129  std::cout << std::endl;
130  std::cout << " def get_cpp_params(self, return_str=False):" << std::endl;
131  std::cout << " params = self._get_cpp_params()" << std::endl;
132  std::cout << " return process_params_out(self, params, "
133  << "return_str=return_str)" << std::endl;
134  std::cout << std::endl;
135  std::cout << " def set_cpp_params(self, params_dic):" << std::endl;
136  std::cout << " params_str = process_params_in(self, params_dic)"
137  << std::endl;
138  std::cout << " self._set_cpp_params(params_str.encode(\"utf-8\"))"
139  << std::endl;
140  std::cout << std::endl;
141 }
142 
151 template<typename T>
153  const void* /* input */,
154  void* /* output */)
155 {
156  PrintClassDefn<typename std::remove_pointer<T>::type>(d);
157 }
158 
159 } // namespace python
160 } // namespace bindings
161 } // namespace mlpack
162 
163 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
python
Definition: CMakeLists.txt:7
This structure holds all of the information about a single parameter, including its value (which is s...
Definition: param_data.hpp:52
void PrintClassDefn(util::ParamData &, const typename std::enable_if<!arma::is_arma_type< T >::value >::type *=0, const typename std::enable_if<!data::HasSerialize< T >::value >::type *=0)
Non-serializable models don't require any special definitions, so this prints nothing.
void StripType(const std::string &inputType, std::string &strippedType, std::string &printedType, std::string &defaultsType)
Given an input type like, e.g., "LogisticRegression<>", return three types that can be used in Python...
Definition: strip_type.hpp:28
std::string cppType
The true name of the type, as it would be written in C++.
Definition: param_data.hpp:81