// Include guard for this header, followed by the opening of a template
// parameter list with two parameters:
//   BorderMode - a type parameter, defaulting to FullConvolution.
//   padLastDim - a compile-time bool, defaulting to false.
// NOTE(review): the declaration these parameters belong to is not visible in
// this excerpt, and no visible line reads padLastDim - confirm its use
// against the full header.
13 #ifndef MLPACK_METHODS_ANN_CONVOLUTION_RULES_FFT_CONVOLUTION_HPP 14 #define MLPACK_METHODS_ANN_CONVOLUTION_RULES_FFT_CONVOLUTION_HPP 36 template<
typename BorderMode = FullConvolution, const
bool padLastDim = false>
// Matrix/matrix convolution computed through 2-D FFTs. The std::enable_if
// condition below makes this overload available only when Border is
// ValidConvolution; Border defaults to the BorderMode template parameter.
//
// filter - the filter matrix (read-only).
// output - receives the extracted result.
// NOTE(review): this excerpt has dropped lines - the function name, the
// `input` parameter declaration, the braces and the tail of the fft2
// expression are not visible. Confirm details against the full header.
51 template<
typename eT,
typename Border = BorderMode>
52 static typename std::enable_if<
53 std::is_same<Border, ValidConvolution>::value,
void>::type
55 const arma::Mat<eT>& filter,
56 arma::Mat<eT>& output)
// Copy both operands; the copies are the objects that get resized below.
58 arma::Mat<eT> inputPadded = input;
59 arma::Mat<eT> filterPadded = filter;
// Widen the input copy by one column.
// NOTE(review): presumably guarded by a padLastDim check on a line missing
// from this excerpt - verify.
62 inputPadded.resize(inputPadded.n_rows, inputPadded.n_cols + 1);
// Give the filter copy the same shape as the input copy so the two
// transforms below can be multiplied element-wise.
65 filterPadded.resize(inputPadded.n_rows, inputPadded.n_cols);
// Element-wise product (%) of the 2-D FFTs, inverse-transformed with ifft2,
// keeping only the real part. The second fft2 argument is cut off in this
// excerpt.
67 arma::Mat<eT> temp = arma::real(ifft2(arma::fft2(inputPadded) % arma::fft2(
// Copy out the sub-matrix spanning (filter.n_rows - 1, filter.n_cols - 1)
// through (input.n_rows - 1, input.n_cols - 1), i.e.
// (input.n_rows - filter.n_rows + 1) x (input.n_cols - filter.n_cols + 1)
// elements.
72 output = temp.submat(filter.n_rows - 1, filter.n_cols - 1,
73 input.n_rows - 1, input.n_cols - 1);
// Matrix/matrix convolution computed through 2-D FFTs. The std::enable_if
// condition below makes this overload available only when Border is
// FullConvolution; Border defaults to the BorderMode template parameter.
//
// filter - the filter matrix (read-only).
// output - receives the extracted result.
// NOTE(review): this excerpt has dropped lines - the function name, the
// `input` parameter declaration, the braces and the tails of two expressions
// are not visible. Confirm details against the full header.
86 template<
typename eT,
typename Border = BorderMode>
87 static typename std::enable_if<
88 std::is_same<Border, FullConvolution>::value,
void>::type
90 const arma::Mat<eT>& filter,
91 arma::Mat<eT>& output)
// Padded working size: the input extent plus (filter extent - 1) on each
// side, per dimension.
// NOTE(review): outputCols is non-const, unlike outputRows; presumably a
// line missing from this excerpt adjusts it - verify.
96 const size_t outputRows = input.n_rows + 2 * (filter.n_rows - 1);
97 size_t outputCols = input.n_cols + 2 * (filter.n_cols - 1);
// Zero-initialised working matrix; the remaining constructor argument(s)
// are cut off in this excerpt.
103 arma::Mat<eT> inputPadded = arma::zeros<arma::Mat<eT> >(outputRows,
// Place the input inside the padded matrix, offset by
// (filter.n_rows - 1, filter.n_cols - 1) from the top-left corner.
105 inputPadded.submat(filter.n_rows - 1, filter.n_cols - 1,
106 filter.n_rows - 1 + input.n_rows - 1,
107 filter.n_cols - 1 + input.n_cols - 1) = input;
// Copy the filter and resize the copy to the padded working size.
109 arma::Mat<eT> filterPadded = filter;
110 filterPadded.resize(outputRows, outputCols);
// Element-wise product (%) of the 2-D FFTs, inverse-transformed with ifft2,
// keeping only the real part. The second fft2 argument is cut off in this
// excerpt.
113 arma::Mat<eT> temp = arma::real(ifft2(arma::fft2(inputPadded) % arma::fft2(
// Copy out the sub-matrix starting at (filter.n_rows - 1, filter.n_cols - 1);
// the chosen end indices give it
// (input.n_rows + filter.n_rows - 1) x (input.n_cols + filter.n_cols - 1)
// elements.
118 output = temp.submat(filter.n_rows - 1, filter.n_cols - 1,
119 2 * (filter.n_rows - 1) + input.n_rows - 1,
120 2 * (filter.n_cols - 1) + input.n_cols - 1);
// Cube overload taking a cube filter and writing a cube output; the loop
// below iterates over the slices of `input`.
// NOTE(review): presumably the overload whose signature appears elsewhere in
// this listing with (Cube input, Cube filter, Cube output). The function
// name, the `input` declaration, the call that fills convOutput and the loop
// body are all missing from this excerpt - confirm against the full header.
134 template<
typename eT>
136 const arma::Cube<eT>& filter,
137 arma::Cube<eT>& output)
// Scratch matrix holding a single slice's result.
139 arma::Mat<eT> convOutput;
// Allocate the output cube with convOutput's row and column counts; the
// remaining constructor argument(s) are cut off in this excerpt.
143 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
// Store the first result in slice 0.
145 output.slice(0) = convOutput;
// Visit the remaining input slices, starting at index 1 because slice 0 is
// already stored above.
147 for (
size_t i = 1; i < input.n_slices; ++i)
// Cube overload taking a cube filter and writing a cube output; the loop
// below iterates over the slices of `filter`.
// NOTE(review): presumably the overload whose signature appears elsewhere in
// this listing with (Mat input, Cube filter, Cube output). The function
// name, the `input` declaration, the call that fills convOutput and the loop
// body are all missing from this excerpt - confirm against the full header.
165 template<
typename eT>
167 const arma::Cube<eT>& filter,
168 arma::Cube<eT>& output)
// Scratch matrix holding a single slice's result.
170 arma::Mat<eT> convOutput;
// Allocate the output cube with convOutput's row and column counts; the
// remaining constructor argument(s) are cut off in this excerpt.
174 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
// Store the first result in slice 0.
176 output.slice(0) = convOutput;
// Visit the remaining filter slices, starting at index 1 because slice 0 is
// already stored above.
178 for (
size_t i = 1; i < filter.n_slices; ++i)
// Cube overload taking a matrix filter and writing a cube output; the loop
// below iterates over the slices of `input`.
// NOTE(review): presumably the overload whose signature appears elsewhere in
// this listing with (Cube input, Mat filter, Cube output). The function
// name, the `input` declaration, the call that fills convOutput and the loop
// body are all missing from this excerpt - confirm against the full header.
193 template<
typename eT>
195 const arma::Mat<eT>& filter,
196 arma::Cube<eT>& output)
// Scratch matrix holding a single slice's result.
198 arma::Mat<eT> convOutput;
// Allocate the output cube with convOutput's row and column counts; the
// remaining constructor argument(s) are cut off in this excerpt.
202 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
// Store the first result in slice 0.
204 output.slice(0) = convOutput;
// Visit the remaining input slices, starting at index 1 because slice 0 is
// already stored above.
206 for (
size_t i = 1; i < input.n_slices; ++i)
Linear algebra utility functions, generally performed on matrices or vectors.
static void Convolution(const arma::Mat< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output)
The core includes that mlpack expects; standard C++ includes and Armadillo.
static void Convolution(const arma::Cube< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output)
static std::enable_if< std::is_same< Border, ValidConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output)
static void Convolution(const arma::Cube< eT > &input, const arma::Mat< eT > &filter, arma::Cube< eT > &output)
Computes the two-dimensional convolution through fft.
static std::enable_if< std::is_same< Border, FullConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output)