11 #ifndef MLPACK_METHODS_ANN_RBM_RBM_HPP    12 #define MLPACK_METHODS_ANN_RBM_RBM_HPP    34   typename InitializationRuleType,
    35   typename DataType = arma::mat,
    36   typename PolicyType = BinaryRBM
    /**
     * Initialize all the parameters of the network using initializeRule.
     *
     * @param predictors Training data (taken by value).
     * @param initializeRule Rule used to initialize the network parameters.
     * @param visibleSize Size of the visible layer.
     * @param hiddenSize Size of the hidden layer.
     * @param batchSize Batch size (default 1).
     * @param numSteps Number of Gibbs sampling steps, presumably the value
     *     reported by NumSteps() (default 1).
     * @param negSteps Defaults to 1. NOTE(review): meaning is defined in the
     *     implementation, not visible here.
     * @param poolSize Pool size (default 2).
     * @param slabPenalty Regularizer associated with slab variables
     *     (default 8).
     * @param radius Defaults to 1. NOTE(review): semantics not visible here.
     * @param persistence Defaults to false; presumably toggles persistent
     *     chains between updates -- TODO confirm against the implementation.
     */
    60   RBM(arma::Mat<ElemType> predictors,
    61       InitializationRuleType initializeRule,
    62       const size_t visibleSize,
    63       const size_t hiddenSize,
    64       const size_t batchSize = 1,
    65       const size_t numSteps = 1,
    66       const size_t negSteps = 1,
    67       const size_t poolSize = 2,
    68       const ElemType slabPenalty = 8,
    69       const ElemType radius = 1,
    70       const bool persistence = 
false);
    73   template<
typename Policy = PolicyType, 
typename InputType = DataType>
    74   typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, 
void>::type
    78   template<
typename Policy = PolicyType, 
typename InputType = DataType>
    79   typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, 
void>::type
    /**
     * Train the RBM on the given input data.
     *
     * @tparam OptimizerType Type of the optimizer used for training.
     * @tparam CallbackType Types of the callbacks supplied to training.
     * @param optimizer Optimizer instance to use.
     * @param callbacks Callbacks used during training.
     * @return A double; presumably the final objective value -- TODO confirm.
     */
    97   template<
typename OptimizerType, 
typename... CallbackType>
    98   double Train(OptimizerType& optimizer, CallbackType&&... callbacks);
   /**
    * Evaluate the RBM network with the given parameters.
    *
    * @param parameters Parameters of the network to evaluate with.
    * @param batchSize Number of points to evaluate.
    *
    * NOTE(review): this listing jumps from source line 108 to 110, and the
    * Doxygen signature for Evaluate lists a `const size_t i` parameter
    * between `parameters` and `batchSize`. The extraction appears to have
    * dropped line 109; verify against the real header.
    */
   108   double Evaluate(
const arma::Mat<ElemType>& parameters,
   110                   const size_t batchSize);
   119   template<
typename Policy = PolicyType, 
typename InputType = DataType>
   120   typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, 
double>::type
   133   template<
typename Policy = PolicyType, 
typename InputType = DataType>
   134   typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
   /**
    * Calculates the gradient of the RBM network on the provided input.
    * Enabled only when Policy is BinaryRBM (via std::enable_if).
    *
    * @param input Input data.
    * @param gradient Receives the computed gradient.
    */
   144   template<
typename Policy = PolicyType, 
typename InputType = DataType>
   145   typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, 
void>::type
   146   Phase(
const InputType& input, DataType& gradient);
   /**
    * SpikeSlabRBM overload of Phase, enabled only when Policy is
    * SpikeSlabRBM (via std::enable_if). Presumably computes the gradient of
    * the spike-and-slab RBM on the provided input -- TODO confirm.
    *
    * @param input Input data.
    * @param gradient Receives the computed gradient.
    */
   154   template<
typename Policy = PolicyType, 
typename InputType = DataType>
   155   typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, 
void>::type
   156   Phase(
const InputType& input, DataType& gradient);
   /**
    * Samples the hidden layer given the visible layer using a Bernoulli
    * function. Enabled only when Policy is BinaryRBM.
    *
    * @param input Visible-layer values.
    * @param output Receives the sampled hidden layer.
    */
   165   template<
typename Policy = PolicyType, 
typename InputType = DataType>
   166   typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, 
void>::type
   167   SampleHidden(
const arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
   /**
    * SpikeSlabRBM overload of SampleHidden, enabled only when Policy is
    * SpikeSlabRBM. NOTE(review): the sampling distribution for the
    * spike-and-slab case is defined in the implementation, not visible here.
    *
    * @param input Visible-layer values.
    * @param output Receives the sampled hidden layer.
    */
   179   template<
typename Policy = PolicyType, 
typename InputType = DataType>
   180   typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, 
void>::type
   181   SampleHidden(
const arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
   /**
    * Samples the visible layer given the hidden layer using a Bernoulli
    * function. Enabled only when Policy is BinaryRBM.
    *
    * @param input Hidden-layer values (taken by non-const reference).
    * @param output Receives the sampled visible layer.
    */
   190   template<
typename Policy = PolicyType, 
typename InputType = DataType>
   191   typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, 
void>::type
   192   SampleVisible(arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
   /**
    * SpikeSlabRBM overload of SampleVisible, enabled only when Policy is
    * SpikeSlabRBM. NOTE(review): the sampling distribution for the
    * spike-and-slab case is defined in the implementation, not visible here.
    *
    * @param input Hidden-layer values (taken by non-const reference).
    * @param output Receives the sampled visible layer.
    */
   204   template<
typename Policy = PolicyType, 
typename InputType = DataType>
   205   typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, 
void>::type
   206   SampleVisible(arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
   214   template<
typename Policy = PolicyType, 
typename InputType = DataType>
   215   typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, 
void>::type
   226   template<
typename Policy = PolicyType, 
typename InputType = DataType>
   227   typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, 
void>::type
   /**
    * Calculates the mean for the hidden layer. Enabled only when Policy is
    * BinaryRBM.
    *
    * @param input Visible-layer input.
    * @param output Receives the hidden-layer mean.
    */
   236   template<
typename Policy = PolicyType, 
typename InputType = DataType>
   237   typename std::enable_if<std::is_same<Policy, BinaryRBM>::value, 
void>::type
   238   HiddenMean(
const InputType& input, DataType& output);
   /**
    * SpikeSlabRBM overload of HiddenMean, enabled only when Policy is
    * SpikeSlabRBM. Presumably calculates the hidden-layer mean for the
    * spike-and-slab model -- TODO confirm against the implementation.
    *
    * @param input Visible-layer input.
    * @param output Receives the hidden-layer mean.
    */
   250   template<
typename Policy = PolicyType, 
typename InputType = DataType>
   251   typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, 
void>::type
   252   HiddenMean(
const InputType& input, DataType& output);
   /**
    * Calculates the mean of the distribution P(h|v) for the spike
    * variables. Enabled only when Policy is SpikeSlabRBM.
    *
    * @param visible Visible-layer values.
    * @param spikeMean Receives the spike mean.
    */
   262   template<
typename Policy = PolicyType, 
typename InputType = DataType>
   263   typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, 
void>::type
   264   SpikeMean(
const InputType& visible, DataType& spikeMean);
   /**
    * Samples the spike variables using a Bernoulli distribution. Enabled
    * only when Policy is SpikeSlabRBM.
    *
    * @param spikeMean Spike mean, presumably produced by SpikeMean().
    * @param spike Receives the sampled spike values.
    */
   271   template<
typename Policy = PolicyType, 
typename InputType = DataType>
   272   typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, 
void>::type
   273   SampleSpike(InputType& spikeMean, DataType& spike);
   /**
    * Calculates the mean of the Normal distribution P(s|v, h) for the slab
    * variables. Enabled only when Policy is SpikeSlabRBM.
    *
    * @param visible Visible-layer values.
    * @param spike Spike values (taken by non-const reference).
    * @param slabMean Receives the slab mean.
    */
   284   template<
typename Policy = PolicyType, 
typename InputType = DataType>
   285   typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, 
void>::type
   286   SlabMean(
const DataType& visible, DataType& spike, DataType& slabMean);
   /**
    * Samples the slab variables from the Normal distribution P(s|v, h).
    * Enabled only when Policy is SpikeSlabRBM.
    *
    * @param slabMean Slab mean, presumably produced by SlabMean().
    * @param slab Receives the sampled slab values.
    */
   298   template<
typename Policy = PolicyType, 
typename InputType = DataType>
   299   typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value, 
void>::type
   300   SampleSlab(InputType& slabMean, DataType& slab);
   /**
    * Performs k-step Gibbs sampling.
    *
    * @param input Starting values for the chain.
    * @param output Receives the sampled result.
    * @param steps Number of steps (default SIZE_MAX). NOTE(review): SIZE_MAX
    *     presumably acts as a sentinel for "use NumSteps()"; confirm in the
    *     implementation.
    */
   309   void Gibbs(
const arma::Mat<ElemType>& input,
   310              arma::Mat<ElemType>& output,
   311              const size_t steps = SIZE_MAX);
   /**
    * Calculates the gradients for the RBM network.
    *
    * @param parameters Current parameters of the network.
    * @param gradient Receives the computed gradient.
    * @param batchSize Number of points to use.
    *
    * NOTE(review): this listing skips source line 322, and the Doxygen
    * signature for Gradient lists a `const size_t i` parameter between
    * `parameters` and `gradient`. Verify against the real header.
    */
   321   void Gradient(
const arma::Mat<ElemType>& parameters,
   323                 arma::Mat<ElemType>& gradient,
   324                 const size_t batchSize);
   //! Return the parameters of the network (read-only view of `parameter`).
   339   const arma::Mat<ElemType>& 
Parameters()
 const { 
return parameter; }
   //! Get the weights of the network (read-only view of `weight`).
   344   arma::Cube<ElemType> 
const& 
Weight()
 const { 
return weight; }
   //! Modify the weights of the network (mutable reference to `weight`).
   346   arma::Cube<ElemType>& 
Weight() { 
return weight; }
   //! Get the pool size (const reference to `poolSize`).
   376   size_t const& 
PoolSize()
 const { 
return poolSize; }
   /**
    * Serialize the model.
    *
    * @param ar Archive to serialize to or from.
    * @param version Serialization version.
    */
   379   template<
typename Archive>
   380   void serialize(Archive& ar, 
const uint32_t version);
   //! Parameters of the network (exposed through Parameters()).
   384   arma::Mat<ElemType> parameter;
   //! Predictor data passed to the constructor.
   386   arma::Mat<ElemType> predictors;
   //! Rule used to initialize the network parameters.
   388   InitializationRuleType initializeRule;
   //! NOTE(review): purpose not visible in this listing.
   390   arma::Mat<ElemType> state;
   //! Weights of the network (exposed through Weight()).
   408   arma::Cube<ElemType> weight;
   //! Visible bias of the network.
   410   DataType visibleBias;
   //! NOTE(review): presumably a scratch buffer for pre-activations.
   414   DataType preActivation;
   //! Regularizer associated with visible variables.
   418   DataType visiblePenalty;
   //! NOTE(review): presumably holds the visible-layer mean.
   420   DataType visibleMean;
   //! NOTE(review): presumably holds sampled spike values.
   424   DataType spikeSamples;
   //! Regularizer associated with slab variables.
   428   ElemType slabPenalty;
   // NOTE(review): the members below look like working buffers for
   // sampling and gradient computation; confirm usage in the implementation.
   432   arma::Mat<ElemType> hiddenReconstruction;
   434   arma::Mat<ElemType> visibleReconstruction;
   436   arma::Mat<ElemType> negativeSamples;
   438   arma::Mat<ElemType> negativeGradient;
   440   arma::Mat<ElemType> tempNegativeGradient;
   442   arma::Mat<ElemType> positiveGradient;
   444   arma::Mat<ElemType> gibbsTemporary;
   454 #include "rbm_impl.hpp"   455 #include "spike_slab_rbm_impl.hpp" DataType const  & HiddenBias() const
Return the hidden bias of the network. 
 
void Shuffle()
Shuffle the order of function visitation. 
 
DataType & VisibleBias()
Modify the visible bias of the network. 
 
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type VisibleMean(InputType &input, DataType &output)
The function calculates the mean for the visible layer. 
 
std::enable_if< std::is_same< Policy, BinaryRBM >::value, double >::type FreeEnergy(const arma::Mat< ElemType > &input)
This function calculates the free energy of the BinaryRBM. 
 
Linear algebra utility functions, generally performed on matrices or vectors. 
 
size_t const  & VisibleSize() const
Get the visible size. 
 
void Gradient(const arma::Mat< ElemType > &parameters, const size_t i, arma::Mat< ElemType > &gradient, const size_t batchSize)
Calculates the gradients for the RBM network. 
 
DataType & SpikeBias()
Modify the bias associated with spike variables. 
 
DataType::elem_type ElemType
The element type stored in DataType (for example, `double` for the default `arma::mat`). 
 
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleVisible(arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the visible layer given the hidden layer using Bernoulli function. 
 
double Evaluate(const arma::Mat< ElemType > &parameters, const size_t i, const size_t batchSize)
Evaluate the RBM network with the given parameters. 
 
DataType & HiddenBias()
Modify the hidden bias of the network. 
 
arma::Cube< ElemType > & Weight()
Modify the weights of the network. 
 
The implementation of the RBM module. 
 
DataType const  & SpikeBias() const
Get the bias associated with spike variables. 
 
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSpike(InputType &spikeMean, DataType &spike)
The function samples the spike function using Bernoulli distribution. 
 
void Gibbs(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output, const size_t steps=SIZE_MAX)
This function does the k-step Gibbs Sampling. 
 
DataType const  & VisibleBias() const
Return the visible bias of the network. 
 
DataType & VisiblePenalty()
Modify the regularizer associated with visible variables. 
 
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Phase(const InputType &input, DataType &gradient)
Calculates the gradient of the RBM network on the provided input. 
 
void serialize(Archive &ar, const uint32_t version)
Serialize the model. 
 
arma::Mat< ElemType > & Parameters()
Modify the parameters of the network. 
 
const arma::Mat< ElemType > & Parameters() const
Return the parameters of the network. 
 
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSlab(InputType &slabMean, DataType &slab)
The function samples from the Normal distribution of $P(s|v, h)$, where the mean is given by $h_i \alpha^{-1} W_i^T v$ and the variance is given by $\alpha^{-1}$, with $\alpha$ the slab penalty. 
 
ElemType const  & SlabPenalty() const
Get the regularizer associated with slab variables. 
 
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation. 
 
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type HiddenMean(const InputType &input, DataType &output)
The function calculates the mean for the hidden layer. 
 
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SpikeMean(const InputType &visible, DataType &spikeMean)
The function calculates the mean of the distribution $P(h|v)$, where the mean is given by $\mathrm{sigm}\left(\tfrac{1}{2} v^T W_i \alpha^{-1} W_i^T v + b_i\right)$. 
 
arma::Cube< ElemType > const  & Weight() const
Get the weights of the network. 
 
DataType const  & VisiblePenalty() const
Get the regularizer associated with visible variables. 
 
size_t NumSteps() const
Return the number of steps of Gibbs Sampling. 
 
size_t const  & HiddenSize() const
Get the hidden size. 
 
size_t const  & PoolSize() const
Get the pool size. 
 
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SlabMean(const DataType &visible, DataType &spike, DataType &slabMean)
The function calculates the mean of the Normal distribution of $P(s|v, h)$, where the mean is given by $h_i \alpha^{-1} W_i^T v$, with $\alpha$ the slab penalty. 
 
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleHidden(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the hidden layer given the visible layer using Bernoulli function. 
 
double Train(OptimizerType &optimizer, CallbackType &&... callbacks)
Train the RBM on the given input data. 
 
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Reset()
 
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points). 
 
RBM(arma::Mat< ElemType > predictors, InitializationRuleType initializeRule, const size_t visibleSize, const size_t hiddenSize, const size_t batchSize=1, const size_t numSteps=1, const size_t negSteps=1, const size_t poolSize=2, const ElemType slabPenalty=8, const ElemType radius=1, const bool persistence=false)
Initialize all the parameters of the network using initializeRule.