svd_convolution.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_CONVOLUTION_RULES_SVD_CONVOLUTION_HPP
14 #define MLPACK_METHODS_ANN_CONVOLUTION_RULES_SVD_CONVOLUTION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "border_modes.hpp"
18 #include "fft_convolution.hpp"
19 #include "naive_convolution.hpp"
20 
21 namespace mlpack {
22 namespace ann {
23 
37 template<typename BorderMode = FullConvolution>
39 {
40  public:
41  /*
42  * Perform a convolution (valid or full mode) using singular value
43  * decomposition. By using singular value decomposition of the filter matrix
44  * the convolution can be expressed as a sum of outer products. Each product
45  * can be computed efficiently as convolution with a row and a column vector.
46  * The individual convolutions are computed with the naive implementation
47  * which is fast if the filter is low-dimensional.
48  *
49  * @param input Input used to perform the convolution.
50  * @param filter Filter used to perform the conolution.
51  * @param output Output data that contains the results of the convolution.
52  * @param dW Stride of filter application in the x direction.
53  * @param dH Stride of filter application in the y direction.
54  */
55  template<typename eT>
56  static void Convolution(const arma::Mat<eT>& input,
57  const arma::Mat<eT>& filter,
58  arma::Mat<eT>& output)
59  {
60  // Use the naive convolution in case the filter isn't two dimensional or the
61  // filter is bigger than the input.
62  if (filter.n_rows > input.n_rows || filter.n_cols > input.n_cols ||
63  filter.n_rows == 1 || filter.n_cols == 1)
64  {
65  NaiveConvolution<BorderMode>::Convolution(input, filter, output);
66  }
67  else
68  {
69  arma::Mat<eT> U, V, subOutput;
70  arma::Col<eT> s;
71 
72  arma::svd_econ(U, s, V, filter);
73 
74  // Rank approximation using the singular values calculated with singular
75  // value decomposition of dense filter matrix.
76  const size_t rank = arma::sum(s > (s.n_elem * arma::max(s) *
77  arma::datum::eps));
78 
79  // Test for separability based on the rank of the kernel and take
80  // advantage of the low rank.
81  if (rank * (filter.n_rows + filter.n_cols) < filter.n_elem)
82  {
83  arma::Mat<eT> subFilter = V.unsafe_col(0) * s(0);
84  NaiveConvolution<BorderMode>::Convolution(input, subFilter, subOutput);
85 
86  subOutput = subOutput.t();
87  NaiveConvolution<BorderMode>::Convolution(subOutput, U.unsafe_col(0),
88  output);
89 
90  arma::Mat<eT> temp;
91  for (size_t r = 1; r < rank; r++)
92  {
93  subFilter = V.unsafe_col(r) * s(r);
95  subOutput);
96 
97  subOutput = subOutput.t();
98  NaiveConvolution<BorderMode>::Convolution(subOutput, U.unsafe_col(r),
99  temp);
100  output += temp;
101  }
102 
103  output = output.t();
104  }
105  else
106  {
107  FFTConvolution<BorderMode>::Convolution(input, filter, output);
108  }
109  }
110  }
111 
112  /*
113  * Perform a convolution using 3rd order tensors.
114  *
115  * @param input Input used to perform the convolution.
116  * @param filter Filter used to perform the conolution.
117  * @param output Output data that contains the results of the convolution.
118  * @param dW Stride of filter application in the x direction.
119  * @param dH Stride of filter application in the y direction.
120  */
121  template<typename eT>
122  static void Convolution(const arma::Cube<eT>& input,
123  const arma::Cube<eT>& filter,
124  arma::Cube<eT>& output)
125  {
126  arma::Mat<eT> convOutput;
127  SVDConvolution<BorderMode>::Convolution(input.slice(0), filter.slice(0),
128  convOutput);
129 
130  output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
131  input.n_slices);
132  output.slice(0) = convOutput;
133 
134  for (size_t i = 1; i < input.n_slices; ++i)
135  {
136  SVDConvolution<BorderMode>::Convolution(input.slice(i), filter.slice(i),
137  output.slice(i));
138  }
139  }
140 
141  /*
142  * Perform a convolution using dense matrix as input and a 3rd order tensors
143  * as filter and output.
144  *
145  * @param input Input used to perform the convolution.
146  * @param filter Filter used to perform the conolution.
147  * @param output Output data that contains the results of the convolution.
148  * @param dW Stride of filter application in the x direction.
149  * @param dH Stride of filter application in the y direction.
150  */
151  template<typename eT>
152  static void Convolution(const arma::Mat<eT>& input,
153  const arma::Cube<eT>& filter,
154  arma::Cube<eT>& output)
155  {
156  arma::Mat<eT> convOutput;
157  SVDConvolution<BorderMode>::Convolution(input, filter.slice(0), convOutput);
158 
159  output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
160  filter.n_slices);
161  output.slice(0) = convOutput;
162 
163  for (size_t i = 1; i < filter.n_slices; ++i)
164  {
165  SVDConvolution<BorderMode>::Convolution(input, filter.slice(i),
166  output.slice(i));
167  }
168  }
169 
170  /*
171  * Perform a convolution using a 3rd order tensors as input and output and a
172  * dense matrix as filter.
173  *
174  * @param input Input used to perform the convolution.
175  * @param filter Filter used to perform the conolution.
176  * @param output Output data that contains the results of the convolution.
177  * @param dW Stride of filter application in the x direction.
178  * @param dH Stride of filter application in the y direction.
179  */
180  template<typename eT>
181  static void Convolution(const arma::Cube<eT>& input,
182  const arma::Mat<eT>& filter,
183  arma::Cube<eT>& output)
184  {
185  arma::Mat<eT> convOutput;
186  SVDConvolution<BorderMode>::Convolution(input.slice(0), filter, convOutput);
187 
188  output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
189  input.n_slices);
190  output.slice(0) = convOutput;
191 
192  for (size_t i = 1; i < input.n_slices; ++i)
193  {
194  SVDConvolution<BorderMode>::Convolution(input.slice(i), filter,
195  output.slice(i));
196  }
197  }
198 }; // class SVDConvolution
199 
200 } // namespace ann
201 } // namespace mlpack
202 
203 #endif
static void Convolution(const arma::Cube< eT > &input, const arma::Mat< eT > &filter, arma::Cube< eT > &output)
static void Convolution(const arma::Cube< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output)
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
static void Convolution(const arma::Mat< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output)
Computes the two-dimensional convolution using singular value decomposition.
static std::enable_if< std::is_same< Border, ValidConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output)
static void Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output)
static std::enable_if< std::is_same< Border, ValidConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output, const size_t dW=1, const size_t dH=1, const size_t dilationW=1, const size_t dilationH=1)