test_catch_tools.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_TESTS_TEST_CATCH_TOOLS_HPP
13 #define MLPACK_TESTS_TEST_CATCH_TOOLS_HPP
14 
15 #include <mlpack/core.hpp>
16 
17 #include "catch.hpp"
18 
19 // Require the approximation L to be within a relative error of E with respect
20 // to the actual value R, i.e. |R - L| <= E * |R|. Note: R is expanded twice.
21 #define REQUIRE_RELATIVE_ERR(L, R, E) \
22  REQUIRE(std::abs((R) - (L)) <= (E) * std::abs(R))
23 
24 // Check the values of two matrices.
25 inline void CheckMatrices(const arma::mat& a,
26  const arma::mat& b,
27  double tolerance = 1e-5)
28 {
29  REQUIRE(a.n_rows == b.n_rows);
30  REQUIRE(a.n_cols == b.n_cols);
31 
32  for (size_t i = 0; i < a.n_elem; ++i)
33  {
34  if (std::abs(a[i]) < tolerance / 2)
35  REQUIRE(b[i] == Approx(0.0).margin(tolerance / 2));
36  else
37  REQUIRE(a[i] == Approx(b[i]).epsilon(tolerance / 100));
38  }
39 }
40 
41 // Check the values of two unsigned matrices.
42 inline void CheckMatrices(const arma::Mat<size_t>& a,
43  const arma::Mat<size_t>& b)
44 {
45  REQUIRE(a.n_rows == b.n_rows);
46  REQUIRE(a.n_cols == b.n_cols);
47 
48  for (size_t i = 0; i < a.n_elem; ++i)
49  REQUIRE(a[i] == b[i]);
50 }
51 
52 template <typename FieldType,
53  typename = std::enable_if_t<
54  arma::is_arma_type<typename FieldType::object_type>::value>>
55 // Check the values of two field types.
56 inline void CheckFields(const FieldType& a,
57  const FieldType& b)
58 {
59  REQUIRE(a.n_rows == b.n_rows);
60  REQUIRE(a.n_cols == b.n_cols);
61 
62  for (size_t i = 0; i < a.n_slices; ++i)
63  CheckMatrices(a(i), b(i));
64 }
65 
66 // Check the values of two cubes.
67 inline void CheckMatrices(const arma::cube& a,
68  const arma::cube& b,
69  double tolerance = 1e-5)
70 {
71  REQUIRE(a.n_rows == b.n_rows);
72  REQUIRE(a.n_cols == b.n_cols);
73  REQUIRE(a.n_slices == b.n_slices);
74 
75  for (size_t i = 0; i < a.n_elem; ++i)
76  {
77  if (std::abs(a[i]) < tolerance / 2)
78  REQUIRE(b[i] == Approx(0.0).margin(tolerance / 2));
79  else
80  REQUIRE(a[i] == Approx(b[i]).epsilon(tolerance / 100));
81  }
82 }
83 
84 // Check if two matrices are different.
85 inline void CheckMatricesNotEqual(const arma::mat& a,
86  const arma::mat& b,
87  double tolerance = 1e-5)
88 {
89  bool areDifferent = false;
90 
91  // Only check the elements if the dimensions are equal.
92  if (a.n_rows == b.n_rows && a.n_cols == b.n_cols)
93  {
94  for (size_t i = 0; i < a.n_elem; ++i)
95  {
96  if (std::abs(a[i]) < tolerance / 2 &&
97  b[i] > tolerance / 2)
98  {
99  areDifferent = true;
100  break;
101  }
102  else if (std::abs(a[i] - b[i]) > tolerance)
103  {
104  areDifferent = true;
105  break;
106  }
107  }
108  }
109  else
110  areDifferent = true;
111 
112  if (!areDifferent)
113  FAIL("The matrices are equal.");
114 }
115 
116 // Check if two unsigned matrices are different.
117 inline void CheckMatricesNotEqual(const arma::Mat<size_t>& a,
118  const arma::Mat<size_t>& b)
119 {
120  bool areDifferent = false;
121 
122  // Only check the elements if the dimensions are equal.
123  if (a.n_rows == b.n_rows && a.n_cols == b.n_cols)
124  {
125  for (size_t i = 0; i < a.n_elem; ++i)
126  {
127  if (a[i] != b[i])
128  {
129  areDifferent = true;
130  break;
131  }
132  }
133  }
134  else
135  areDifferent = true;
136 
137  if (!areDifferent)
138  FAIL("The matrices are equal.");
139 }
140 
141 // Check if two cubes are different.
142 inline void CheckMatricesNotEqual(const arma::cube& a,
143  const arma::cube& b,
144  double tolerance = 1e-5)
145 {
146  bool areDifferent = false;
147 
148  // Only check the elements if the dimensions are equal.
149  if (a.n_rows == b.n_rows && a.n_cols == b.n_cols &&
150  a.n_slices == b.n_slices)
151  {
152  for (size_t i = 0; i < a.n_elem; ++i)
153  {
154  if (std::abs(a[i]) < tolerance / 2 &&
155  b[i] > tolerance / 2)
156  {
157  areDifferent = true;
158  break;
159  }
160  else if (std::abs(a[i] - b[i]) > tolerance)
161  {
162  areDifferent = true;
163  break;
164  }
165  }
166  }
167  else
168  areDifferent = true;
169 
170  if (!areDifferent)
171  FAIL("The matrices are equal.");
172 }
173 
// Filter a typeinfo string to generate unique filenames for serialization
// tests.
//
// The input is scanned from its end towards its beginning and only
// alphanumeric characters are kept, until at most 32 characters have been
// collected. The characters are appended in scan order, so the result holds
// the trailing alphanumeric characters of the input in reverse order.
//
// inputString  String to filter (e.g. a type name).
// Returns the filtered name, at most 32 characters long; empty if the input
// contains no alphanumeric character.
inline std::string FilterFileName(const std::string& inputString)
{
  constexpr std::string::size_type maxLength = 32;

  std::string fileName;
  fileName.reserve(maxLength);
  for (auto it = inputString.rbegin();
       it != inputString.rend() && fileName.size() != maxLength; ++it)
  {
    // std::isalnum() has undefined behaviour for negative arguments other than
    // EOF, so plain char values must be converted to unsigned char first.
    if (std::isalnum(static_cast<unsigned char>(*it)))
      fileName.push_back(*it);
  }

  return fileName;
}
188 
189 #endif
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:70
void CheckMatrices(const arma::mat &a, const arma::mat &b, double tolerance=1e-5)
void CheckFields(const FieldType &a, const FieldType &b)
std::string FilterFileName(const std::string &inputString)
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen docu...
void CheckMatricesNotEqual(const arma::mat &a, const arma::mat &b, double tolerance=1e-5)