positional_encoding.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LAYER_POSITIONAL_ENCODING_HPP
13 #define MLPACK_METHODS_ANN_LAYER_POSITIONAL_ENCODING_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace ann {
19 
33 template <
34  typename InputDataType = arma::mat,
35  typename OutputDataType = arma::mat
36 >
38 {
39  public:
44 
51  PositionalEncoding(const size_t embedDim,
52  const size_t maxSequenceLength);
53 
61  template<typename eT>
62  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
63 
73  template<typename eT>
74  void Backward(const arma::Mat<eT>& /* input */,
75  const arma::Mat<eT>& gy,
76  arma::Mat<eT>& g);
77 
79  InputDataType const& InputParameter() const { return inputParameter; }
81  InputDataType& InputParameter() { return inputParameter; }
82 
84  OutputDataType const& OutputParameter() const { return outputParameter; }
86  OutputDataType& OutputParameter() { return outputParameter; }
87 
89  OutputDataType const& Delta() const { return delta; }
91  OutputDataType& Delta() { return delta; }
92 
94  InputDataType const& Encoding() const { return positionalEncoding; }
95 
96  size_t InputShape() const
97  {
98  return embedDim * maxSequenceLength;
99  }
100 
104  template<typename Archive>
105  void serialize(Archive& ar, const uint32_t /* version */);
106 
107  private:
111  void InitPositionalEncoding();
112 
114  size_t embedDim;
115 
117  size_t maxSequenceLength;
118 
120  InputDataType positionalEncoding;
121 
123  OutputDataType delta;
124 
126  InputDataType inputParameter;
127 
129  OutputDataType outputParameter;
130 }; // class PositionalEncoding
131 
132 } // namespace ann
133 } // namespace mlpack
134 
135 // Include implementation.
136 #include "positional_encoding_impl.hpp"
137 
138 #endif
OutputDataType & OutputParameter()
Modify the output parameter.
Linear algebra utility functions, generally performed on matrices or vectors.
InputDataType const & Encoding() const
Get the positional encoding vector.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Positional Encoding injects some information about the relative or absolute position of the tokens in the sequence.
InputDataType & InputParameter()
Modify the input parameter.
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
OutputDataType const & Delta() const
Get the delta.
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
PositionalEncoding()
Create PositionalEncoding object.
InputDataType const & InputParameter() const
Get the input parameter.
OutputDataType & Delta()
Modify the delta.
OutputDataType const & OutputParameter() const
Get the output parameter.