serialization.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_TESTS_SERIALIZATION_CATCH_HPP
13 #define MLPACK_TESTS_SERIALIZATION_CATCH_HPP
14 
15 #include <mlpack/core.hpp>
16 
17 #include "test_catch_tools.hpp"
18 #include "catch.hpp"
19 
20 namespace mlpack {
21 
22 // Test function for loading and saving Armadillo objects.
23 template<typename CubeType,
24  typename IArchiveType,
25  typename OArchiveType>
26 void TestArmadilloSerialization(arma::Cube<CubeType>& x)
27 {
28  // First save it.
29  // Use type_info name to get unique file name for serialization test files.
30  std::string fileName = FilterFileName(typeid(IArchiveType).name());
31  std::ofstream ofs(fileName, std::ios::binary);
32 
33  {
34  OArchiveType o(ofs);
35  o(CEREAL_NVP(x));
36  }
37 
38  ofs.close();
39 
40  // Now load it.
41  arma::Cube<CubeType> orig(x);
42  std::ifstream ifs(fileName, std::ios::binary);
43 
44  {
45  IArchiveType i(ifs);
46  i(CEREAL_NVP(x));
47  }
48  ifs.close();
49 
50  remove(fileName.c_str());
51 
52  REQUIRE(x.n_rows == orig.n_rows);
53  REQUIRE(x.n_cols == orig.n_cols);
54  REQUIRE(x.n_elem_slice == orig.n_elem_slice);
55  REQUIRE(x.n_slices == orig.n_slices);
56  REQUIRE(x.n_elem == orig.n_elem);
57 
58  for (size_t slice = 0; slice != x.n_slices; ++slice)
59  {
60  const auto& origSlice = orig.slice(slice);
61  const auto& xSlice = x.slice(slice);
62  for (size_t i = 0; i < x.n_cols; ++i)
63  {
64  for (size_t j = 0; j < x.n_rows; ++j)
65  {
66  if (double(origSlice(j, i)) == 0.0)
67  REQUIRE(double(xSlice(j, i)) == Approx(0.0).margin(1e-8 / 100));
68  else
69  REQUIRE(double(origSlice(j, i)) ==
70  Approx(double(xSlice(j, i))).epsilon(1e-8 / 100));
71  }
72  }
73  }
74 }
75 
76 // Test all serialization strategies.
77 template<typename CubeType>
78 void TestAllArmadilloSerialization(arma::Cube<CubeType>& x)
79 {
80  TestArmadilloSerialization<CubeType, cereal::XMLInputArchive,
81  cereal::XMLOutputArchive>(x);
82  TestArmadilloSerialization<CubeType, cereal::JSONInputArchive,
83  cereal::JSONOutputArchive>(x);
84  TestArmadilloSerialization<CubeType, cereal::BinaryInputArchive,
85  cereal::BinaryOutputArchive>(x);
86 }
87 
88 // Test function for loading and saving Armadillo objects.
89 template<typename MatType,
90  typename IArchiveType,
91  typename OArchiveType>
93 {
94  // First save it.
95  std::string fileName = FilterFileName(typeid(IArchiveType).name());
96  std::ofstream ofs(fileName, std::ios::binary);
97 
98  {
99  OArchiveType o(ofs);
100  o(CEREAL_NVP(x));
101  }
102 
103  ofs.close();
104 
105  // Now load it.
106  MatType orig(x);
107  std::ifstream ifs(fileName, std::ios::binary);
108 
109  {
110  IArchiveType i(ifs);
111  i(CEREAL_NVP(x));
112  }
113  ifs.close();
114 
115  remove(fileName.c_str());
116 
117  REQUIRE(x.n_rows == orig.n_rows);
118  REQUIRE(x.n_cols == orig.n_cols);
119  REQUIRE(x.n_elem == orig.n_elem);
120 
121  for (size_t i = 0; i < x.n_cols; ++i)
122  for (size_t j = 0; j < x.n_rows; ++j)
123  if (double(orig(j, i)) == 0.0)
124  REQUIRE(double(x(j, i)) == Approx(0.0).margin(1e-8 / 100));
125  else
126  REQUIRE(double(orig(j, i)) ==
127  Approx(double(x(j, i))).epsilon(1e-8 / 100));
128 }
129 
130 // Test all serialization strategies.
131 template<typename MatType>
133 {
134  TestArmadilloSerialization<MatType, cereal::XMLInputArchive,
135  cereal::XMLOutputArchive>(x);
136  TestArmadilloSerialization<MatType, cereal::JSONInputArchive,
137  cereal::JSONOutputArchive>(x);
138  TestArmadilloSerialization<MatType, cereal::BinaryInputArchive,
139  cereal::BinaryOutputArchive>(x);
140 }
141 
142 // Save and load an mlpack object.
143 // The re-loaded copy is placed in 'newT'.
144 template<typename T, typename IArchiveType, typename OArchiveType>
145 void SerializeObject(T& t, T& newT)
146 {
147  std::string fileName = FilterFileName(typeid(T).name());
148  std::ofstream ofs(fileName, std::ios::binary);
149 
150  {
151  OArchiveType o(ofs);
152 
153  T& x(t);
154  o(CEREAL_NVP(x));
155  }
156  ofs.close();
157 
158  std::ifstream ifs(fileName, std::ios::binary);
159 
160  {
161  IArchiveType i(ifs);
162  T& x(newT);
163  i(CEREAL_NVP(x));
164  }
165  ifs.close();
166 
167  remove(fileName.c_str());
168 }
169 
170 // Test mlpack serialization with all three archive types.
171 template<typename T>
172 void SerializeObjectAll(T& t, T& xmlT, T& jsonT, T& binaryT)
173 {
174  SerializeObject<T, cereal::XMLInputArchive,
175  cereal::XMLOutputArchive>(t, xmlT);
176  SerializeObject<T, cereal::JSONInputArchive,
177  cereal::JSONOutputArchive>(t, jsonT);
178  SerializeObject<T, cereal::BinaryInputArchive,
179  cereal::BinaryOutputArchive>(t, binaryT);
180 }
181 
182 // Save and load a non-default-constructible mlpack object.
183 template<typename T, typename IArchiveType, typename OArchiveType>
184 void SerializePointerObject(T* t, T*& newT)
185 {
186  std::string fileName = FilterFileName(typeid(T).name());
187  std::ofstream ofs(fileName, std::ios::binary);
188 
189  {
190  OArchiveType o(ofs);
191  o(CEREAL_POINTER(t));
192  }
193  ofs.close();
194 
195  std::ifstream ifs(fileName, std::ios::binary);
196 
197  {
198  IArchiveType i(ifs);
199  i(CEREAL_POINTER(newT));
200  }
201  ifs.close();
202  remove(fileName.c_str());
203 }
204 
205 template<typename T>
206 void SerializePointerObjectAll(T* t, T*& xmlT, T*& jsonT, T*& binaryT)
207 {
208  SerializePointerObject<T, cereal::JSONInputArchive,
209  cereal::JSONOutputArchive>(t, jsonT);
210  SerializePointerObject<T, cereal::BinaryInputArchive,
211  cereal::BinaryOutputArchive>(t, binaryT);
212  SerializePointerObject<T, cereal::XMLInputArchive,
213  cereal::XMLOutputArchive>(t, xmlT);
214 }
215 
216 // Utility function to check the equality of two Armadillo matrices.
217 void CheckMatrices(const arma::mat& x,
218  const arma::mat& xmlX,
219  const arma::mat& jsonX,
220  const arma::mat& binaryX);
221 
222 void CheckMatrices(const arma::Mat<size_t>& x,
223  const arma::Mat<size_t>& xmlX,
224  const arma::Mat<size_t>& jsonX,
225  const arma::Mat<size_t>& binaryX);
226 
227 void CheckMatrices(const arma::cube& x,
228  const arma::cube& xmlX,
229  const arma::cube& jsonX,
230  const arma::cube& binaryX);
231 
232 } // namespace mlpack
233 
234 #endif
void SerializePointerObject(T *t, T *&newT)
void CheckMatrices(const arma::mat &x, const arma::mat &xmlX, const arma::mat &jsonX, const arma::mat &binaryX)
Linear algebra utility functions, generally performed on matrices or vectors.
void TestArmadilloSerialization(arma::Cube< CubeType > &x)
void SerializePointerObjectAll(T *t, T *&xmlT, T *&jsonT, T *&binaryT)
std::string FilterFileName(const std::string &inputString)
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
void SerializeObject(T &t, T &newT)
void TestAllArmadilloSerialization(arma::Cube< CubeType > &x)
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
void SerializeObjectAll(T &t, T &xmlT, T &jsonT, T &binaryT)