rbm.hpp
Go to the documentation of this file.
1 
11 #ifndef MLPACK_METHODS_ANN_RBM_RBM_HPP
12 #define MLPACK_METHODS_ANN_RBM_RBM_HPP
13 
14 #include <mlpack/core.hpp>
16 
17 namespace mlpack {
18 namespace ann {
19 
33 template<
34  typename InitializationRuleType,
35  typename DataType = arma::mat,
36  typename PolicyType = BinaryRBM
37 >
38 class RBM
39 {
40  public:
42  typedef typename DataType::elem_type ElemType;
43 
60  RBM(arma::Mat<ElemType> predictors,
61  InitializationRuleType initializeRule,
62  const size_t visibleSize,
63  const size_t hiddenSize,
64  const size_t batchSize = 1,
65  const size_t numSteps = 1,
66  const size_t negSteps = 1,
67  const size_t poolSize = 2,
68  const ElemType slabPenalty = 8,
69  const ElemType radius = 1,
70  const bool persistence = false);
71 
72  // Reset the network.
73  template<typename Policy = PolicyType, typename InputType = DataType>
74  typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
75  Reset();
76 
77  // Reset the network.
78  template<typename Policy = PolicyType, typename InputType = DataType>
79  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
80  Reset();
81 
97  template<typename OptimizerType, typename... CallbackType>
98  double Train(OptimizerType& optimizer, CallbackType&&... callbacks);
99 
108  double Evaluate(const arma::Mat<ElemType>& parameters,
109  const size_t i,
110  const size_t batchSize);
111 
119  template<typename Policy = PolicyType, typename InputType = DataType>
120  typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, double>::type
121  FreeEnergy(const arma::Mat<ElemType>& input);
122 
133  template<typename Policy = PolicyType, typename InputType = DataType>
134  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
135  double>::type
136  FreeEnergy(const arma::Mat<ElemType>& input);
137 
144  template<typename Policy = PolicyType, typename InputType = DataType>
145  typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
146  Phase(const InputType& input, DataType& gradient);
147 
154  template<typename Policy = PolicyType, typename InputType = DataType>
155  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
156  Phase(const InputType& input, DataType& gradient);
157 
165  template<typename Policy = PolicyType, typename InputType = DataType>
166  typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
167  SampleHidden(const arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
168 
179  template<typename Policy = PolicyType, typename InputType = DataType>
180  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
181  SampleHidden(const arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
182 
190  template<typename Policy = PolicyType, typename InputType = DataType>
191  typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
192  SampleVisible(arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
193 
204  template<typename Policy = PolicyType, typename InputType = DataType>
205  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
206  SampleVisible(arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
207 
214  template<typename Policy = PolicyType, typename InputType = DataType>
215  typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
216  VisibleMean(InputType& input, DataType& output);
217 
226  template<typename Policy = PolicyType, typename InputType = DataType>
227  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
228  VisibleMean(InputType& input, DataType& output);
229 
236  template<typename Policy = PolicyType, typename InputType = DataType>
237  typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, void>::type
238  HiddenMean(const InputType& input, DataType& output);
239 
250  template<typename Policy = PolicyType, typename InputType = DataType>
251  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
252  HiddenMean(const InputType& input, DataType& output);
253 
262  template<typename Policy = PolicyType, typename InputType = DataType>
263  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
264  SpikeMean(const InputType& visible, DataType& spikeMean);
265 
271  template<typename Policy = PolicyType, typename InputType = DataType>
272  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
273  SampleSpike(InputType& spikeMean, DataType& spike);
274 
284  template<typename Policy = PolicyType, typename InputType = DataType>
285  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
286  SlabMean(const DataType& visible, DataType& spike, DataType& slabMean);
287 
298  template<typename Policy = PolicyType, typename InputType = DataType>
299  typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, void>::type
300  SampleSlab(InputType& slabMean, DataType& slab);
301 
309  void Gibbs(const arma::Mat<ElemType>& input,
310  arma::Mat<ElemType>& output,
311  const size_t steps = SIZE_MAX);
312 
321  void Gradient(const arma::Mat<ElemType>& parameters,
322  const size_t i,
323  arma::Mat<ElemType>& gradient,
324  const size_t batchSize);
325 
330  void Shuffle();
331 
333  size_t NumFunctions() const { return numFunctions; }
334 
336  size_t NumSteps() const { return numSteps; }
337 
339  const arma::Mat<ElemType>& Parameters() const { return parameter; }
341  arma::Mat<ElemType>& Parameters() { return parameter; }
342 
344  arma::Cube<ElemType> const& Weight() const { return weight; }
346  arma::Cube<ElemType>& Weight() { return weight; }
347 
349  DataType const& VisibleBias() const { return visibleBias; }
351  DataType& VisibleBias() { return visibleBias; }
352 
354  DataType const& HiddenBias() const { return hiddenBias; }
356  DataType& HiddenBias() { return hiddenBias; }
357 
359  DataType const& SpikeBias() const { return spikeBias; }
361  DataType& SpikeBias() { return spikeBias; }
362 
364  ElemType const& SlabPenalty() const { return 1.0 / slabPenalty; }
365 
367  DataType const& VisiblePenalty() const { return visiblePenalty; }
369  DataType& VisiblePenalty() { return visiblePenalty; }
370 
372  size_t const& VisibleSize() const { return visibleSize; }
374  size_t const& HiddenSize() const { return hiddenSize; }
376  size_t const& PoolSize() const { return poolSize; }
377 
379  template<typename Archive>
380  void serialize(Archive& ar, const uint32_t version);
381 
382  private:
384  arma::Mat<ElemType> parameter;
386  arma::Mat<ElemType> predictors;
387  // Initializer for initializing the weights of the network.
388  InitializationRuleType initializeRule;
390  arma::Mat<ElemType> state;
392  size_t numFunctions;
394  size_t visibleSize;
396  size_t hiddenSize;
398  size_t batchSize;
400  size_t numSteps;
402  size_t negSteps;
404  size_t poolSize;
406  size_t steps;
408  arma::Cube<ElemType> weight;
410  DataType visibleBias;
412  DataType hiddenBias;
414  DataType preActivation;
416  DataType spikeBias;
418  DataType visiblePenalty;
420  DataType visibleMean;
422  DataType spikeMean;
424  DataType spikeSamples;
426  DataType slabMean;
428  ElemType slabPenalty;
430  ElemType radius;
432  arma::Mat<ElemType> hiddenReconstruction;
434  arma::Mat<ElemType> visibleReconstruction;
436  arma::Mat<ElemType> negativeSamples;
438  arma::Mat<ElemType> negativeGradient;
440  arma::Mat<ElemType> tempNegativeGradient;
442  arma::Mat<ElemType> positiveGradient;
444  arma::Mat<ElemType> gibbsTemporary;
446  bool persistence;
448  bool reset;
449 };
450 
451 } // namespace ann
452 } // namespace mlpack
453 
454 #include "rbm_impl.hpp"
455 #include "spike_slab_rbm_impl.hpp"
456 
457 #endif
DataType const & HiddenBias() const
Return the hidden bias of the network.
Definition: rbm.hpp:354
void Shuffle()
Shuffle the order of function visitation.
DataType & VisibleBias()
Modify the visible bias of the network.
Definition: rbm.hpp:351
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type VisibleMean(InputType &input, DataType &output)
The function calculates the mean for the visible layer.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, double >::type FreeEnergy(const arma::Mat< ElemType > &input)
This function calculates the free energy of the BinaryRBM.
Linear algebra utility functions, generally performed on matrices or vectors.
size_t const & VisibleSize() const
Get the visible size.
Definition: rbm.hpp:372
void Gradient(const arma::Mat< ElemType > &parameters, const size_t i, arma::Mat< ElemType > &gradient, const size_t batchSize)
Calculates the gradients for the RBM network.
DataType & SpikeBias()
Modify the regularizer associated with spike variables.
Definition: rbm.hpp:361
DataType::elem_type ElemType
Definition: rbm.hpp:42
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleVisible(arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the visible layer given the hidden layer using a Bernoulli distribution.
double Evaluate(const arma::Mat< ElemType > &parameters, const size_t i, const size_t batchSize)
Evaluate the RBM network with the given parameters.
DataType & HiddenBias()
Modify the hidden bias of the network.
Definition: rbm.hpp:356
arma::Cube< ElemType > & Weight()
Modify the weights of the network.
Definition: rbm.hpp:346
The implementation of the RBM module.
Definition: rbm.hpp:38
DataType const & SpikeBias() const
Get the regularizer associated with spike variables.
Definition: rbm.hpp:359
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSpike(InputType &spikeMean, DataType &spike)
This function samples the spike variables using a Bernoulli distribution.
void Gibbs(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output, const size_t steps=SIZE_MAX)
This function does the k-step Gibbs Sampling.
DataType const & VisibleBias() const
Return the visible bias of the network.
Definition: rbm.hpp:349
DataType & VisiblePenalty()
Modify the regularizer associated with visible variables.
Definition: rbm.hpp:369
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Phase(const InputType &input, DataType &gradient)
Calculates the gradient of the RBM network on the provided input.
void serialize(Archive &ar, const uint32_t version)
Serialize the model.
arma::Mat< ElemType > & Parameters()
Modify the parameters of the network.
Definition: rbm.hpp:341
const arma::Mat< ElemType > & Parameters() const
Return the parameters of the network.
Definition: rbm.hpp:339
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSlab(InputType &slabMean, DataType &slab)
The function samples from the Normal distribution of P(s|v, h), where the mean is given by: and vari...
ElemType const & SlabPenalty() const
Get the regularizer associated with slab variables (the reciprocal of the stored slab penalty).
Definition: rbm.hpp:364
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen docu...
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type HiddenMean(const InputType &input, DataType &output)
The function calculates the mean for the hidden layer.
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SpikeMean(const InputType &visible, DataType &spikeMean)
The function calculates the mean of the distribution P(h|v), where the mean is given by: ...
arma::Cube< ElemType > const & Weight() const
Get the weights of the network.
Definition: rbm.hpp:344
DataType const & VisiblePenalty() const
Get the regularizer associated with visible variables.
Definition: rbm.hpp:367
size_t NumSteps() const
Return the number of steps of Gibbs Sampling.
Definition: rbm.hpp:336
size_t const & HiddenSize() const
Get the hidden size.
Definition: rbm.hpp:374
size_t const & PoolSize() const
Get the pool size.
Definition: rbm.hpp:376
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SlabMean(const DataType &visible, DataType &spike, DataType &slabMean)
The function calculates the mean of the Normal distribution P(s|v, h), where the mean is given by: ...
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleHidden(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the hidden layer given the visible layer using a Bernoulli distribution.
double Train(OptimizerType &optimizer, CallbackType &&... callbacks)
Train the RBM on the given input data.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Reset()
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
Definition: rbm.hpp:333
RBM(arma::Mat< ElemType > predictors, InitializationRuleType initializeRule, const size_t visibleSize, const size_t hiddenSize, const size_t batchSize=1, const size_t numSteps=1, const size_t negSteps=1, const size_t poolSize=2, const ElemType slabPenalty=8, const ElemType radius=1, const bool persistence=false)
Initialize all the parameters of the network using initializeRule.