lookup.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_LOOKUP_HPP
14 #define MLPACK_METHODS_ANN_LAYER_LOOKUP_HPP
15 
16 #include <mlpack/prereqs.hpp>
18 
19 namespace mlpack {
20 namespace ann /* Artificial Neural Network. */ {
21 
37 template <
38  typename InputDataType = arma::mat,
39  typename OutputDataType = arma::mat
40 >
41 class Lookup
42 {
43  public:
50  Lookup(const size_t vocabSize = 0, const size_t embeddingSize = 0);
51 
59  template<typename eT>
60  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
61 
71  template<typename eT>
72  void Backward(const arma::Mat<eT>& /* input */,
73  const arma::Mat<eT>& gy,
74  arma::Mat<eT>& g);
75 
83  template<typename eT>
84  void Gradient(const arma::Mat<eT>& input,
85  const arma::Mat<eT>& error,
86  arma::Mat<eT>& gradient);
87 
89  OutputDataType const& Parameters() const { return weights; }
91  OutputDataType& Parameters() { return weights; }
92 
94  OutputDataType const& OutputParameter() const { return outputParameter; }
96  OutputDataType& OutputParameter() { return outputParameter; }
97 
99  OutputDataType const& Delta() const { return delta; }
101  OutputDataType& Delta() { return delta; }
102 
104  OutputDataType const& Gradient() const { return gradient; }
106  OutputDataType& Gradient() { return gradient; }
107 
109  size_t VocabSize() const { return vocabSize; }
110 
112  size_t EmbeddingSize() const { return embeddingSize; }
113 
117  template<typename Archive>
118  void serialize(Archive& ar, const uint32_t /* version */);
119 
120  private:
122  size_t vocabSize;
123 
125  size_t embeddingSize;
126 
128  OutputDataType weights;
129 
131  OutputDataType delta;
132 
134  OutputDataType gradient;
135 
137  OutputDataType outputParameter;
138 }; // class Lookup
139 
140 // Alias for using as embedding layer.
141 template<typename MatType = arma::mat>
143 
144 } // namespace ann
145 } // namespace mlpack
146 
147 // Include implementation.
148 #include "lookup_impl.hpp"
149 
150 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
OutputDataType const & Gradient() const
Get the gradient.
Definition: lookup.hpp:104
size_t VocabSize() const
Get the size of the vocabulary.
Definition: lookup.hpp:109
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputDataType & Parameters()
Modify the parameters.
Definition: lookup.hpp:91
Lookup(const size_t vocabSize=0, const size_t embeddingSize=0)
Create the Lookup object using the specified vocabulary and embedding size.
OutputDataType const & Delta() const
Get the delta.
Definition: lookup.hpp:99
size_t EmbeddingSize() const
Get the length of each embedding vector.
Definition: lookup.hpp:112
The Lookup class stores word embeddings and retrieves them using tokens.
Definition: lookup.hpp:41
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: lookup.hpp:96
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
OutputDataType const & Parameters() const
Get the parameters.
Definition: lookup.hpp:89
OutputDataType & Gradient()
Modify the gradient.
Definition: lookup.hpp:106
OutputDataType & Delta()
Modify the delta.
Definition: lookup.hpp:101
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: lookup.hpp:94