fft_convolution.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_CONVOLUTION_RULES_FFT_CONVOLUTION_HPP
14 #define MLPACK_METHODS_ANN_CONVOLUTION_RULES_FFT_CONVOLUTION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "border_modes.hpp"
18 
19 namespace mlpack {
20 namespace ann {
21 
36 template<typename BorderMode = FullConvolution, const bool padLastDim = false>
38 {
39  public:
40  /*
41  * Perform a convolution through fft (valid mode). This method only supports
42  * input which is even on the last dimension. In case of an odd input width, a
43  * user can manually pad the input or specify the padLastDim parameter which
44  * takes care of the padding. The filter instead can have any size. When using
45  * the valid mode the filter has to be smaller than the input.
46  *
47  * @param input Input used to perform the convolution.
48  * @param filter Filter used to perform the convolution.
49  * @param output Output data that contains the results of the convolution.
50  */
51  template<typename eT, typename Border = BorderMode>
52  static typename std::enable_if<
53  std::is_same<Border, ValidConvolution>::value, void>::type
54  Convolution(const arma::Mat<eT>& input,
55  const arma::Mat<eT>& filter,
56  arma::Mat<eT>& output)
57  {
58  arma::Mat<eT> inputPadded = input;
59  arma::Mat<eT> filterPadded = filter;
60 
61  if (padLastDim)
62  inputPadded.resize(inputPadded.n_rows, inputPadded.n_cols + 1);
63 
64  // Pad filter and input to the output shape.
65  filterPadded.resize(inputPadded.n_rows, inputPadded.n_cols);
66 
67  arma::Mat<eT> temp = arma::real(ifft2(arma::fft2(inputPadded) % arma::fft2(
68  filterPadded)));
69 
70  // Extract the region of interest. We don't need to handle the padLastDim in
71  // a special way we just cut it out from the output matrix.
72  output = temp.submat(filter.n_rows - 1, filter.n_cols - 1,
73  input.n_rows - 1, input.n_cols - 1);
74  }
75 
76  /*
77  * Perform a convolution through fft (full mode). This method only supports
78  * input which is even on the last dimension. In case of an odd input width, a
79  * user can manually pad the input or specify the padLastDim parameter which
80  * takes care of the padding. The filter instead can have any size.
81  *
82  * @param input Input used to perform the convolution.
83  * @param filter Filter used to perform the convolution.
84  * @param output Output data that contains the results of the convolution.
85  */
86  template<typename eT, typename Border = BorderMode>
87  static typename std::enable_if<
88  std::is_same<Border, FullConvolution>::value, void>::type
89  Convolution(const arma::Mat<eT>& input,
90  const arma::Mat<eT>& filter,
91  arma::Mat<eT>& output)
92  {
93  // In case of the full convolution outputRows and outputCols doesn't
94  // represent the true output size when the padLastDim parameter is set,
95  // instead it's the working size.
96  const size_t outputRows = input.n_rows + 2 * (filter.n_rows - 1);
97  size_t outputCols = input.n_cols + 2 * (filter.n_cols - 1);
98 
99  if (padLastDim)
100  outputCols++;
101 
102  // Pad filter and input to the working output shape.
103  arma::Mat<eT> inputPadded = arma::zeros<arma::Mat<eT> >(outputRows,
104  outputCols);
105  inputPadded.submat(filter.n_rows - 1, filter.n_cols - 1,
106  filter.n_rows - 1 + input.n_rows - 1,
107  filter.n_cols - 1 + input.n_cols - 1) = input;
108 
109  arma::Mat<eT> filterPadded = filter;
110  filterPadded.resize(outputRows, outputCols);
111 
112  // Perform FFT and IFFT
113  arma::Mat<eT> temp = arma::real(ifft2(arma::fft2(inputPadded) % arma::fft2(
114  filterPadded)));
115 
116  // Extract the region of interest. We don't need to handle the padLastDim
117  // parameter in a special way we just cut it out from the output matrix.
118  output = temp.submat(filter.n_rows - 1, filter.n_cols - 1,
119  2 * (filter.n_rows - 1) + input.n_rows - 1,
120  2 * (filter.n_cols - 1) + input.n_cols - 1);
121  }
122 
123  /*
124  * Perform a convolution through fft using 3rd order tensors. This method only
125  * supports input which is even on the last dimension. In case of an odd input
126  * width, a user can manually pad the input or specify the padLastDim
127  * parameter which takes care of the padding. The filter instead can have any
128  * size.
129  *
130  * @param input Input used to perform the convolution.
131  * @param filter Filter used to perform the convolution.
132  * @param output Output data that contains the results of the convolution.
133  */
134  template<typename eT>
135  static void Convolution(const arma::Cube<eT>& input,
136  const arma::Cube<eT>& filter,
137  arma::Cube<eT>& output)
138  {
139  arma::Mat<eT> convOutput;
140  FFTConvolution<BorderMode>::Convolution(input.slice(0), filter.slice(0),
141  convOutput);
142 
143  output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
144  input.n_slices);
145  output.slice(0) = convOutput;
146 
147  for (size_t i = 1; i < input.n_slices; ++i)
148  {
149  FFTConvolution<BorderMode>::Convolution(input.slice(i), filter.slice(i),
150  output.slice(i));
151  }
152  }
153 
154  /*
155  * Perform a convolution through fft using dense matrix as input and a 3rd
156  * order tensors as filter and output. This method only supports input which
157  * is even on the last dimension. In case of an odd input width, a user can
158  * manually pad the input or specify the padLastDim parameter which takes care
159  * of the padding. The filter instead can have any size.
160  *
161  * @param input Input used to perform the convolution.
162  * @param filter Filter used to perform the convolution.
163  * @param output Output data that contains the results of the convolution.
164  */
165  template<typename eT>
166  static void Convolution(const arma::Mat<eT>& input,
167  const arma::Cube<eT>& filter,
168  arma::Cube<eT>& output)
169  {
170  arma::Mat<eT> convOutput;
171  FFTConvolution<BorderMode>::Convolution(input, filter.slice(0),
172  convOutput);
173 
174  output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
175  filter.n_slices);
176  output.slice(0) = convOutput;
177 
178  for (size_t i = 1; i < filter.n_slices; ++i)
179  {
180  FFTConvolution<BorderMode>::Convolution(input, filter.slice(i),
181  output.slice(i));
182  }
183  }
184 
185  /*
186  * Perform a convolution using a 3rd order tensors as input and output and a
187  * dense matrix as filter.
188  *
189  * @param input Input used to perform the convolution.
190  * @param filter Filter used to perform the convolution.
191  * @param output Output data that contains the results of the convolution.
192  */
193  template<typename eT>
194  static void Convolution(const arma::Cube<eT>& input,
195  const arma::Mat<eT>& filter,
196  arma::Cube<eT>& output)
197  {
198  arma::Mat<eT> convOutput;
199  FFTConvolution<BorderMode>::Convolution(input.slice(0), filter,
200  convOutput);
201 
202  output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
203  input.n_slices);
204  output.slice(0) = convOutput;
205 
206  for (size_t i = 1; i < input.n_slices; ++i)
207  {
208  FFTConvolution<BorderMode>::Convolution(input.slice(i), filter,
209  output.slice(i));
210  }
211  }
212 }; // class FFTConvolution
213 
214 } // namespace ann
215 } // namespace mlpack
216 
217 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
static void Convolution(const arma::Mat< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output)
The core includes that mlpack expects; standard C++ includes and Armadillo.
static void Convolution(const arma::Cube< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output)
static std::enable_if< std::is_same< Border, ValidConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output)
static void Convolution(const arma::Cube< eT > &input, const arma::Mat< eT > &filter, arma::Cube< eT > &output)
Computes the two-dimensional convolution through fft.
static std::enable_if< std::is_same< Border, FullConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output)