highway.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_HIGHWAY_HPP
14 #define MLPACK_METHODS_ANN_LAYER_HIGHWAY_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "../visitor/delete_visitor.hpp"
19 #include "../visitor/delta_visitor.hpp"
20 #include "../visitor/output_height_visitor.hpp"
21 #include "../visitor/output_parameter_visitor.hpp"
22 #include "../visitor/output_width_visitor.hpp"
23 
24 #include "layer_types.hpp"
25 #include "add_merge.hpp"
26 
27 namespace mlpack {
28 namespace ann {
29 
54 template <
55  typename InputDataType = arma::mat,
56  typename OutputDataType = arma::mat,
57  typename... CustomLayers>
58 class Highway
59 {
60  public:
62  Highway();
63 
70  Highway(const size_t inSize, const bool model = true);
71 
73  ~Highway();
74 
78  void Reset();
79 
87  template<typename eT>
88  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
89 
99  template<typename eT>
100  void Backward(const arma::Mat<eT>& /* input */,
101  const arma::Mat<eT>& gy,
102  arma::Mat<eT>& g);
103 
111  template<typename eT>
112  void Gradient(const arma::Mat<eT>& input,
113  const arma::Mat<eT>& error,
114  arma::Mat<eT>& gradient);
115 
121  template <class LayerType, class... Args>
122  void Add(Args... args)
123  {
124  network.push_back(new LayerType(args...));
125  networkOwnerships.push_back(true);
126  }
127 
134  {
135  network.push_back(layer);
136  networkOwnerships.push_back(false);
137  }
138 
140  std::vector<LayerTypes<CustomLayers...> >& Model()
141  {
142  if (model)
143  {
144  return network;
145  }
146 
147  return empty;
148  }
149 
151  OutputDataType const& Parameters() const { return weights; }
153  OutputDataType& Parameters() { return weights; }
154 
156  InputDataType const& InputParameter() const { return inputParameter; }
158  InputDataType& InputParameter() { return inputParameter; }
159 
161  OutputDataType const& OutputParameter() const { return outputParameter; }
163  OutputDataType& OutputParameter() { return outputParameter; }
164 
166  OutputDataType const& Delta() const { return delta; }
168  OutputDataType& Delta() { return delta; }
169 
171  OutputDataType const& Gradient() const { return gradient; }
173  OutputDataType& Gradient() { return gradient; }
174 
176  size_t InSize() const { return inSize; }
177 
179  size_t InputShape() const
180  {
181  return inSize;
182  }
183 
187  template<typename Archive>
188  void serialize(Archive& ar, const uint32_t /* version */);
189 
190  private:
192  size_t inSize;
193 
195  bool model;
196 
198  bool reset;
199 
201  std::vector<LayerTypes<CustomLayers...> > network;
202 
204  std::vector<bool> networkOwnerships;
205 
207  std::vector<LayerTypes<CustomLayers...> > empty;
208 
210  OutputDataType weights;
211 
213  OutputDataType delta;
214 
216  OutputDataType gradient;
217 
219  OutputDataType transformWeight;
220 
222  OutputDataType transformBias;
223 
225  OutputDataType transformGate;
226 
228  OutputDataType transformGateActivation;
229 
231  OutputDataType transformGateError;
232 
234  InputDataType inputParameter;
235 
237  OutputDataType outputParameter;
238 
240  size_t width;
241 
243  size_t height;
244 
246  OutputDataType networkOutput;
247 
249  DeltaVisitor deltaVisitor;
250 
252  OutputParameterVisitor outputParameterVisitor;
253 
255  DeleteVisitor deleteVisitor;
256 
258  OutputWidthVisitor outputWidthVisitor;
259 
261  OutputHeightVisitor outputHeightVisitor;
262 }; // class Highway
263 
264 } // namespace ann
265 } // namespace mlpack
266 
267 // Include implementation.
268 #include "highway_impl.hpp"
269 
270 #endif
DeleteVisitor executes the destructor of the instantiated object.
OutputHeightVisitor exposes the OutputHeight() method of the given module.
Linear algebra utility functions, generally performed on matrices or vectors.
InputDataType & InputParameter()
Modify the input parameter.
Definition: highway.hpp:158
~Highway()
Destroy the Highway object.
void Add(LayerTypes< CustomLayers... > layer)
Add a new module to the model.
Definition: highway.hpp:133
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputDataType const & Delta() const
Get the delta.
Definition: highway.hpp:166
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: highway.hpp:163
OutputDataType & Parameters()
Modify the parameters.
Definition: highway.hpp:153
size_t InputShape() const
Get the shape of the input.
Definition: highway.hpp:179
OutputDataType & Delta()
Modify the delta.
Definition: highway.hpp:168
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: highway.hpp:161
OutputDataType const & Parameters() const
Get the parameters.
Definition: highway.hpp:151
Implementation of the Highway layer.
Definition: highway.hpp:58
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
OutputParameterVisitor exposes the output parameter of the given module.
void Add(Args... args)
Add a new module to the model.
Definition: highway.hpp:122
OutputDataType const & Gradient() const
Get the gradient.
Definition: highway.hpp:171
OutputDataType & Gradient()
Modify the gradient.
Definition: highway.hpp:173
DeltaVisitor exposes the delta parameter of the given module.
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed-backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed-forward pass.
void Reset()
Reset the layer parameter.
boost::variant< AdaptiveMaxPooling< arma::mat, arma::mat > *, AdaptiveMeanPooling< arma::mat, arma::mat > *, Add< arma::mat, arma::mat > *, AddMerge< arma::mat, arma::mat > *, AlphaDropout< arma::mat, arma::mat > *, AtrousConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, BaseLayer< LogisticFunction, arma::mat, arma::mat > *, BaseLayer< IdentityFunction, arma::mat, arma::mat > *, BaseLayer< TanhFunction, arma::mat, arma::mat > *, BaseLayer< SoftplusFunction, arma::mat, arma::mat > *, BaseLayer< RectifierFunction, arma::mat, arma::mat > *, BatchNorm< arma::mat, arma::mat > *, BilinearInterpolation< arma::mat, arma::mat > *, CELU< arma::mat, arma::mat > *, Concat< arma::mat, arma::mat > *, Concatenate< arma::mat, arma::mat > *, ConcatPerformance< NegativeLogLikelihood< arma::mat, arma::mat >, arma::mat, arma::mat > *, Constant< arma::mat, arma::mat > *, Convolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, CReLU< arma::mat, arma::mat > *, DropConnect< arma::mat, arma::mat > *, Dropout< arma::mat, arma::mat > *, ELU< arma::mat, arma::mat > *, FastLSTM< arma::mat, arma::mat > *, GRU< arma::mat, arma::mat > *, HardTanH< arma::mat, arma::mat > *, Join< arma::mat, arma::mat > *, LayerNorm< arma::mat, arma::mat > *, LeakyReLU< arma::mat, arma::mat > *, Linear< arma::mat, arma::mat, NoRegularizer > *, LinearNoBias< arma::mat, arma::mat, NoRegularizer > *, LogSoftMax< arma::mat, arma::mat > *, Lookup< arma::mat, arma::mat > *, LSTM< arma::mat, arma::mat > *, MaxPooling< arma::mat, arma::mat > *, MeanPooling< arma::mat, arma::mat > *, MiniBatchDiscrimination< arma::mat, arma::mat > *, MultiplyConstant< arma::mat, arma::mat > *, MultiplyMerge< arma::mat, arma::mat > *, NegativeLogLikelihood< arma::mat, arma::mat > *, NoisyLinear< arma::mat, arma::mat > *, Padding< arma::mat, 
arma::mat > *, PReLU< arma::mat, arma::mat > *, Sequential< arma::mat, arma::mat, false > *, Sequential< arma::mat, arma::mat, true > *, Softmax< arma::mat, arma::mat > *, TransposedConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, WeightNorm< arma::mat, arma::mat > *, MoreTypes, CustomLayers *... > LayerTypes
OutputWidthVisitor exposes the OutputWidth() method of the given module.
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed-forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
std::vector< LayerTypes< CustomLayers... > > & Model()
Return the modules of the model.
Definition: highway.hpp:140
InputDataType const & InputParameter() const
Get the input parameter.
Definition: highway.hpp:156
Highway()
Create the Highway object.
size_t InSize() const
Get the number of input units.
Definition: highway.hpp:176