// Include guard for this header, followed by the template head of the class
// that holds the Convolution() overloads (the class-name line and the
// matching #endif are not visible in this excerpt). BorderMode is used as the
// default for each overload's Border template parameter, so it selects between
// the ValidConvolution and FullConvolution matrix overloads; it defaults to
// FullConvolution.
13 #ifndef MLPACK_METHODS_ANN_CONVOLUTION_RULES_NAIVE_CONVOLUTION_HPP 14 #define MLPACK_METHODS_ANN_CONVOLUTION_RULES_NAIVE_CONVOLUTION_HPP 34 template<
typename BorderMode = FullConvolution>
49 template<
typename eT,
typename Border = BorderMode>
50 static typename std::enable_if<
51 std::is_same<Border, ValidConvolution>::value,
void>::type
53 const arma::Mat<eT>& filter,
54 arma::Mat<eT>& output,
57 const size_t dilationW = 1,
58 const size_t dilationH = 1)
60 output = arma::zeros<arma::Mat<eT> >(
61 (input.n_rows - (filter.n_rows - 1) * dilationW - 1) / dW + 1,
62 (input.n_cols - (filter.n_cols - 1) * dilationH - 1) / dH + 1);
66 eT* outputPtr = output.memptr();
68 for (
size_t j = 0; j < output.n_cols; ++j)
70 for (
size_t i = 0; i < output.n_rows; ++i, outputPtr++)
72 const eT* kernelPtr = filter.memptr();
73 for (
size_t kj = 0; kj < filter.n_cols; ++kj)
75 const eT* inputPtr = input.colptr(kj * dilationW + j * dW) + i * dH;
76 for (
size_t ki = 0; ki < filter.n_rows; ++ki, ++kernelPtr,
77 inputPtr += dilationH)
78 *outputPtr += *kernelPtr * (*inputPtr);
// Full-mode two-dimensional convolution of a matrix input with a matrix
// filter; this overload is enabled only when Border is FullConvolution.
// Parameters (per the listed signature): input, filter, output (result),
// dW / dH (strides), dilationW / dilationH (dilations, default 1).
//
// Visible strategy: build a zero matrix inputPadded with outputRows rows,
// copy input into it offset by (filter.n_rows - 1) * dilationW rows and
// (filter.n_cols - 1) * dilationH columns, then finish with a call that
// passes output, unit strides (1, 1) and the dilations.
// NOTE(review): the head of that final call is not visible in this excerpt;
// it presumably runs the ValidConvolution overload on inputPadded — confirm.
// NOTE(review): in the visible lines dW and dH only affect outputRows /
// outputCols; they are not forwarded to the final call.
95 template<
typename eT,
typename Border = BorderMode>
96 static typename std::enable_if<
97 std::is_same<Border, FullConvolution>::value,
void>::type
99 const arma::Mat<eT>& filter,
100 arma::Mat<eT>& output,
103 const size_t dilationW = 1,
104 const size_t dilationH = 1)
// Row count of the padded input: grows with dW and with (filter.n_rows - 1);
// the rest of this expression is not visible here.
106 size_t outputRows = (input.n_rows - 1) * dW + 2 * (filter.n_rows - 1)
// Column count of the padded input, built the same way with dH (rest not
// visible).
108 size_t outputCols = (input.n_cols - 1) * dH + 2 * (filter.n_cols - 1)
// Checks residues modulo dW to adjust outputRows. The end of the condition and
// the loop body are not visible; presumably this makes the padded size
// compatible with the stride. TODO confirm.
111 for (
size_t i = 0; i < dW; ++i)
113 if (((((i + outputRows - 2 * (filter.n_rows - 1) * dilationW - 1) % dW)
// Same kind of adjustment for outputCols modulo dH (body not visible).
119 for (
size_t i = 0; i < dH; ++i)
121 if (((((i + outputCols - 2 * (filter.n_cols - 1) * dilationH - 1) % dH)
// Zero-initialised buffer; everything outside the copied region is the zero
// border. NOTE(review): the column-count argument is not visible; presumably
// outputCols.
129 arma::Mat<eT> inputPadded = arma::zeros<arma::Mat<eT> >(outputRows,
// Copy input into the interior: the first (filter.n_rows - 1) * dilationW rows
// and the first (filter.n_cols - 1) * dilationH columns stay zero.
131 inputPadded.submat((filter.n_rows - 1) * dilationW, (filter.n_cols - 1)
132 * dilationH, (filter.n_rows - 1) * dilationW + input.n_rows - 1,
133 (filter.n_cols - 1) * dilationH + input.n_cols - 1) = input;
// Final call: unit strides and the same dilations (call head not visible).
136 output, 1, 1, dilationW, dilationH);
// Convolution of a cube input with a cube filter, producing a cube output.
// Parameters (per the listed signature): input, filter, output, dW / dH
// (strides), dilationW / dilationH (dilations, default 1).
//
// Slice 0 is convolved into the temporary convOutput first because its
// n_rows / n_cols are what size the output cube; slices 1 .. input.n_slices - 1
// are then written straight into output.slice(i).
// NOTE(review): the per-slice convolution calls (and the slice-count argument
// of the Cube constructor) are not visible in this excerpt; they presumably
// pair input.slice(i) with filter.slice(i), in which case filter must have at
// least input.n_slices slices — no such check is visible here.
150 template<
typename eT>
152 const arma::Cube<eT>& filter,
153 arma::Cube<eT>& output,
156 const size_t dilationW = 1,
157 const size_t dilationH = 1)
159 arma::Mat<eT> convOutput;
161 convOutput, dW, dH, dilationW, dilationH);
// Allocate output using the spatial size of the first slice's result.
163 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
165 output.slice(0) = convOutput;
// Remaining slices are written directly into output, with no temporary.
167 for (
size_t i = 1; i < input.n_slices; ++i)
170 output.slice(i), dW, dH, dilationW, dilationH);
// Convolution of a single matrix input with every slice of a cube filter,
// producing one output slice per filter slice. NOTE(review): the input
// parameter line is not visible; by elimination against the listed
// signatures this is the (const arma::Mat<eT>& input, const arma::Cube<eT>&
// filter, arma::Cube<eT>& output, dW, dH, dilationW, dilationH) overload.
//
// Slice 0 is convolved into the temporary convOutput first because its
// n_rows / n_cols size the output cube; slices 1 .. filter.n_slices - 1 are
// then written straight into output.slice(i).
// NOTE(review): the per-slice convolution calls and the slice-count argument
// of the Cube constructor are not visible here; presumably the same input is
// paired with filter.slice(i) each time — confirm.
186 template<
typename eT>
188 const arma::Cube<eT>& filter,
189 arma::Cube<eT>& output,
192 const size_t dilationW = 1,
193 const size_t dilationH = 1)
195 arma::Mat<eT> convOutput;
197 convOutput, dW, dH, dilationW, dilationH);
// Allocate output using the spatial size of the first slice's result.
199 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
201 output.slice(0) = convOutput;
// Iterate over the filter's slices (not the input's): one output per filter.
203 for (
size_t i = 1; i < filter.n_slices; ++i)
206 output.slice(i), dW, dH, dilationW, dilationH);
// Convolution of every slice of a cube input with a single matrix filter,
// producing one output slice per input slice. Parameters (per the listed
// signature): input (cube), filter (matrix), output (cube), dW / dH
// (strides), dilationW / dilationH (dilations, default 1).
//
// Slice 0 is convolved into the temporary convOutput first because its
// n_rows / n_cols size the output cube; slices 1 .. input.n_slices - 1 are
// then written straight into output.slice(i).
// NOTE(review): the per-slice convolution calls and the slice-count argument
// of the Cube constructor are not visible here; presumably input.slice(i) is
// paired with the same filter each time — confirm.
222 template<
typename eT>
224 const arma::Mat<eT>& filter,
225 arma::Cube<eT>& output,
228 const size_t dilationW = 1,
229 const size_t dilationH = 1)
231 arma::Mat<eT> convOutput;
233 convOutput, dW, dH, dilationW, dilationH);
// Allocate output using the spatial size of the first slice's result.
235 output = arma::Cube<eT>(convOutput.n_rows, convOutput.n_cols,
237 output.slice(0) = convOutput;
// Remaining input slices are written directly into output, with no temporary.
239 for (
size_t i = 1; i < input.n_slices; ++i)
242 output.slice(i), dW, dH, dilationW, dilationH);
static std::enable_if< std::is_same< Border, FullConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output, const size_t dW=1, const size_t dH=1, const size_t dilationW=1, const size_t dilationH=1)
static void Convolution(const arma::Cube< eT > &input, const arma::Mat< eT > &filter, arma::Cube< eT > &output, const size_t dW=1, const size_t dH=1, const size_t dilationW=1, const size_t dilationH=1)
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
static void Convolution(const arma::Mat< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output, const size_t dW=1, const size_t dH=1, const size_t dilationW=1, const size_t dilationH=1)
static void Convolution(const arma::Cube< eT > &input, const arma::Cube< eT > &filter, arma::Cube< eT > &output, const size_t dW=1, const size_t dH=1, const size_t dilationW=1, const size_t dilationH=1)
Computes the two-dimensional convolution.
static std::enable_if< std::is_same< Border, ValidConvolution >::value, void >::type Convolution(const arma::Mat< eT > &input, const arma::Mat< eT > &filter, arma::Mat< eT > &output, const size_t dW=1, const size_t dH=1, const size_t dilationW=1, const size_t dilationH=1)