11 #ifndef MLPACK_METHODS_ANN_RBM_RBM_HPP 12 #define MLPACK_METHODS_ANN_RBM_RBM_HPP 34 typename InitializationRuleType,
35 typename DataType = arma::mat,
36 typename PolicyType = BinaryRBM
/**
 * Initialize all the parameters of the network using initializeRule.
 *
 * `predictors` and `initializeRule` are taken by value. Every argument from
 * `batchSize` onward is optional; the defaults are batchSize = 1,
 * numSteps = 1, negSteps = 1, poolSize = 2, slabPenalty = 8, radius = 1 and
 * persistence = false.
 *
 * NOTE(review): the exact meaning of each argument is not visible in this
 * listing. poolSize, slabPenalty and radius presumably only matter for the
 * SpikeSlabRBM policy -- confirm against the implementation.
 */
60 RBM(arma::Mat<ElemType> predictors,
61 InitializationRuleType initializeRule,
62 const size_t visibleSize,
63 const size_t hiddenSize,
64 const size_t batchSize = 1,
65 const size_t numSteps = 1,
66 const size_t negSteps = 1,
67 const size_t poolSize = 2,
68 const ElemType slabPenalty = 8,
69 const ElemType radius = 1,
70 const bool persistence =
false);
73 template<
typename Policy = PolicyType,
typename InputType = DataType>
74 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
78 template<
typename Policy = PolicyType,
typename InputType = DataType>
79 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
/**
 * Train the RBM on the given input data.
 *
 * The optimizer is taken by non-const reference, and any number of callbacks
 * may be supplied as forwarding references.
 *
 * NOTE(review): what the returned double represents is not shown in this
 * listing (presumably the final objective value) -- confirm.
 */
97 template<
typename OptimizerType,
typename... CallbackType>
98 double Train(OptimizerType& optimizer, CallbackType&&... callbacks);
/**
 * Evaluate the RBM network with the given parameters.
 *
 * NOTE(review): this listing jumps from line 108 to line 110, and the
 * generated member summary gives the signature as
 * (parameters, const size_t i, batchSize). A `const size_t i` argument
 * appears to have been dropped from the listing between the two arguments
 * below -- confirm against the original header.
 */
108 double Evaluate(
const arma::Mat<ElemType>& parameters,
110 const size_t batchSize);
119 template<
typename Policy = PolicyType,
typename InputType = DataType>
120 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
double>::type
133 template<
typename Policy = PolicyType,
typename InputType = DataType>
134 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
/**
 * Calculates the gradient of the RBM network on the provided input.
 *
 * This overload takes part in overload resolution only when Policy is
 * BinaryRBM (via std::enable_if); its return type is void.
 *
 * `input` is read-only; `gradient` is a non-const reference and is
 * presumably where the computed gradient is stored.
 */
144 template<
typename Policy = PolicyType,
typename InputType = DataType>
145 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
146 Phase(
const InputType& input, DataType& gradient);
/**
 * Phase overload that takes part in overload resolution only when Policy is
 * SpikeSlabRBM (via std::enable_if); its return type is void.
 *
 * The parameter list is identical to the BinaryRBM overload: a read-only
 * `input` and a non-const `gradient` reference.
 *
 * NOTE(review): presumably computes the network gradient on `input` for the
 * spike-and-slab model, like its BinaryRBM counterpart -- confirm.
 */
154 template<
typename Policy = PolicyType,
typename InputType = DataType>
155 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
156 Phase(
const InputType& input, DataType& gradient);
/**
 * This function samples the hidden layer given the visible layer using
 * Bernoulli function.
 *
 * Enabled only when Policy is BinaryRBM (via std::enable_if). The arguments
 * are spelled as arma::Mat<ElemType> directly, so the InputType template
 * parameter does not appear in the parameter list.
 *
 * `input` is read-only; `output` is a non-const reference, presumably
 * receiving the sampled hidden layer.
 */
165 template<
typename Policy = PolicyType,
typename InputType = DataType>
166 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
167 SampleHidden(
const arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
/**
 * SampleHidden overload enabled only when Policy is SpikeSlabRBM (via
 * std::enable_if); its return type is void.
 *
 * The parameter list matches the BinaryRBM overload: a read-only `input`
 * matrix and a non-const `output` matrix.
 *
 * NOTE(review): presumably samples the hidden layer from the visible layer
 * for the spike-and-slab model -- confirm against the implementation.
 */
179 template<
typename Policy = PolicyType,
typename InputType = DataType>
180 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
181 SampleHidden(
const arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
/**
 * This function samples the visible layer given the hidden layer using
 * Bernoulli function.
 *
 * Enabled only when Policy is BinaryRBM (via std::enable_if). Unlike
 * SampleHidden, `input` is declared as a non-const reference here; `output`
 * is also a non-const reference, presumably receiving the sampled layer.
 */
190 template<
typename Policy = PolicyType,
typename InputType = DataType>
191 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
192 SampleVisible(arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
/**
 * SampleVisible overload enabled only when Policy is SpikeSlabRBM (via
 * std::enable_if); its return type is void.
 *
 * Both `input` and `output` are non-const references, as in the BinaryRBM
 * overload.
 *
 * NOTE(review): presumably samples the visible layer from the hidden layer
 * for the spike-and-slab model -- confirm against the implementation.
 */
204 template<
typename Policy = PolicyType,
typename InputType = DataType>
205 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
206 SampleVisible(arma::Mat<ElemType>& input, arma::Mat<ElemType>& output);
214 template<
typename Policy = PolicyType,
typename InputType = DataType>
215 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
226 template<
typename Policy = PolicyType,
typename InputType = DataType>
227 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
/**
 * The function calculates the mean for the hidden layer.
 *
 * Enabled only when Policy is BinaryRBM (via std::enable_if); its return
 * type is void. `input` is read-only; `output` is a non-const reference,
 * presumably receiving the computed mean.
 */
236 template<
typename Policy = PolicyType,
typename InputType = DataType>
237 typename std::enable_if<std::is_same<Policy, BinaryRBM>::value,
void>::type
238 HiddenMean(
const InputType& input, DataType& output);
/**
 * HiddenMean overload enabled only when Policy is SpikeSlabRBM (via
 * std::enable_if); its return type is void.
 *
 * Same parameter list as the BinaryRBM overload: a read-only `input` and a
 * non-const `output` reference.
 *
 * NOTE(review): presumably computes the hidden-layer mean for the
 * spike-and-slab model -- confirm against the implementation.
 */
250 template<
typename Policy = PolicyType,
typename InputType = DataType>
251 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
252 HiddenMean(
const InputType& input, DataType& output);
/**
 * The function calculates the mean of the distribution P(h|v).
 *
 * Enabled only when Policy is SpikeSlabRBM (via std::enable_if); its return
 * type is void. `visible` is read-only; `spikeMean` is a non-const
 * reference, presumably receiving the computed mean.
 *
 * NOTE(review): the formula for the mean is elided in this listing -- see
 * the original header.
 */
262 template<
typename Policy = PolicyType,
typename InputType = DataType>
263 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
264 SpikeMean(
const InputType& visible, DataType& spikeMean);
/**
 * The function samples the spike function using Bernoulli distribution.
 *
 * Enabled only when Policy is SpikeSlabRBM (via std::enable_if); its return
 * type is void. Both `spikeMean` and `spike` are non-const references;
 * `spike` is presumably the sampled output.
 */
271 template<
typename Policy = PolicyType,
typename InputType = DataType>
272 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
273 SampleSpike(InputType& spikeMean, DataType& spike);
/**
 * The function calculates the mean of Normal distribution of P(s|v, h).
 *
 * Enabled only when Policy is SpikeSlabRBM (via std::enable_if); its return
 * type is void. All three arguments are spelled as DataType, so the
 * InputType template parameter does not appear in the parameter list.
 * `visible` is read-only; `spike` and `slabMean` are non-const references,
 * with `slabMean` presumably receiving the result.
 *
 * NOTE(review): the formula for the mean is elided in this listing -- see
 * the original header.
 */
284 template<
typename Policy = PolicyType,
typename InputType = DataType>
285 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
286 SlabMean(
const DataType& visible, DataType& spike, DataType& slabMean);
/**
 * The function samples from the Normal distribution of P(s|v, h).
 *
 * Enabled only when Policy is SpikeSlabRBM (via std::enable_if); its return
 * type is void. Both `slabMean` and `slab` are non-const references; `slab`
 * is presumably the sampled output.
 *
 * NOTE(review): the mean and variance formulas are elided in this listing --
 * see the original header.
 */
298 template<
typename Policy = PolicyType,
typename InputType = DataType>
299 typename std::enable_if<std::is_same<Policy, SpikeSlabRBM>::value,
void>::type
300 SampleSlab(InputType& slabMean, DataType& slab);
/**
 * This function does the k-step Gibbs Sampling.
 *
 * `input` is read-only and `output` is a non-const reference, presumably
 * receiving the sample produced after the final step.
 *
 * NOTE(review): `steps` defaults to SIZE_MAX, which looks like a sentinel
 * for "use the configured number of steps" -- confirm against the
 * implementation.
 */
309 void Gibbs(
const arma::Mat<ElemType>& input,
310 arma::Mat<ElemType>& output,
311 const size_t steps = SIZE_MAX);
/**
 * Calculates the gradients for the RBM network.
 *
 * `parameters` is read-only; `gradient` is a non-const reference and is
 * presumably where the result is written.
 *
 * NOTE(review): this listing jumps from line 321 to line 323, and the
 * generated member summary gives the signature as
 * (parameters, const size_t i, gradient, batchSize). A `const size_t i`
 * argument appears to have been dropped from the listing after
 * `parameters` -- confirm against the original header.
 */
321 void Gradient(
const arma::Mat<ElemType>& parameters,
323 arma::Mat<ElemType>& gradient,
324 const size_t batchSize);
//! Return the parameters of the network (read-only reference to the
//! `parameter` member).
339 const arma::Mat<ElemType>&
Parameters()
const {
return parameter; }
//! Get the weights of the network (read-only reference to the `weight`
//! cube).
344 arma::Cube<ElemType>
const&
Weight()
const {
return weight; }
//! Modify the weights of the network (mutable reference to the `weight`
//! cube).
346 arma::Cube<ElemType>&
Weight() {
return weight; }
//! Get the pool size (read-only reference to the `poolSize` member).
376 size_t const&
PoolSize()
const {
return poolSize; }
/**
 * Serialize the model.
 *
 * Templated on the archive type; `ar` is taken by non-const reference and
 * `version` is a 32-bit unsigned version number.
 * NOTE(review): which members are archived is not visible in this listing.
 */
379 template<
typename Archive>
380 void serialize(Archive& ar,
const uint32_t version);
// Data members of the RBM. The listing's own line numbers skip values in
// this range, so some members are not shown here.
// NOTE(review): purposes marked "presumably" below are inferred from member
// names only -- confirm against the original header.

//! Parameter matrix returned by Parameters().
384 arma::Mat<ElemType> parameter;
//! Same name and type as the `predictors` constructor argument; presumably
//! the stored training data.
386 arma::Mat<ElemType> predictors;
//! Same name and type as the `initializeRule` constructor argument;
//! presumably the stored initialization rule.
388 InitializationRuleType initializeRule;
390 arma::Mat<ElemType> state;
//! Weight cube returned by both Weight() accessors.
408 arma::Cube<ElemType> weight;
410 DataType visibleBias;
414 DataType preActivation;
418 DataType visiblePenalty;
420 DataType visibleMean;
424 DataType spikeSamples;
//! Same name and type as the `slabPenalty` constructor argument.
428 ElemType slabPenalty;
// Presumably scratch buffers for sampling and gradient computation.
432 arma::Mat<ElemType> hiddenReconstruction;
434 arma::Mat<ElemType> visibleReconstruction;
436 arma::Mat<ElemType> negativeSamples;
438 arma::Mat<ElemType> negativeGradient;
440 arma::Mat<ElemType> tempNegativeGradient;
442 arma::Mat<ElemType> positiveGradient;
444 arma::Mat<ElemType> gibbsTemporary;
454 #include "rbm_impl.hpp" 455 #include "spike_slab_rbm_impl.hpp" DataType const & HiddenBias() const
Return the hidden bias of the network.
void Shuffle()
Shuffle the order of function visitation.
DataType & VisibleBias()
Modify the visible bias of the network.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type VisibleMean(InputType &input, DataType &output)
The function calculates the mean for the visible layer.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, double >::type FreeEnergy(const arma::Mat< ElemType > &input)
This function calculates the free energy of the BinaryRBM.
Linear algebra utility functions, generally performed on matrices or vectors.
size_t const & VisibleSize() const
Get the visible size.
void Gradient(const arma::Mat< ElemType > &parameters, const size_t i, arma::Mat< ElemType > &gradient, const size_t batchSize)
Calculates the gradients for the RBM network.
DataType & SpikeBias()
Modify the regularizer associated with spike variables.
DataType::elem_type ElemType
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleVisible(arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the visible layer given the hidden layer using Bernoulli function.
double Evaluate(const arma::Mat< ElemType > &parameters, const size_t i, const size_t batchSize)
Evaluate the RBM network with the given parameters.
DataType & HiddenBias()
Modify the hidden bias of the network.
arma::Cube< ElemType > & Weight()
Modify the weights of the network.
The implementation of the RBM module.
DataType const & SpikeBias() const
Get the regularizer associated with spike variables.
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSpike(InputType &spikeMean, DataType &spike)
The function samples the spike function using Bernoulli distribution.
void Gibbs(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output, const size_t steps=SIZE_MAX)
This function does the k-step Gibbs Sampling.
DataType const & VisibleBias() const
Return the visible bias of the network.
DataType & VisiblePenalty()
Modify the regularizer associated with visible variables.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Phase(const InputType &input, DataType &gradient)
Calculates the gradient of the RBM network on the provided input.
void serialize(Archive &ar, const uint32_t version)
Serialize the model.
arma::Mat< ElemType > & Parameters()
Modify the parameters of the network.
const arma::Mat< ElemType > & Parameters() const
Return the parameters of the network.
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SampleSlab(InputType &slabMean, DataType &slab)
The function samples from the Normal distribution of P(s|v, h), where the mean is given by: and vari...
ElemType const & SlabPenalty() const
Get the regularizer associated with slab variables.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen docu...
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type HiddenMean(const InputType &input, DataType &output)
The function calculates the mean for the hidden layer.
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SpikeMean(const InputType &visible, DataType &spikeMean)
The function calculates the mean of the distribution P(h|v), where mean is given by: ...
arma::Cube< ElemType > const & Weight() const
Get the weights of the network.
DataType const & VisiblePenalty() const
Get the regularizer associated with visible variables.
size_t NumSteps() const
Return the number of steps of Gibbs Sampling.
size_t const & HiddenSize() const
Get the hidden size.
size_t const & PoolSize() const
Get the pool size.
std::enable_if< std::is_same< Policy, SpikeSlabRBM >::value, void >::type SlabMean(const DataType &visible, DataType &spike, DataType &slabMean)
The function calculates the mean of Normal distribution of P(s|v, h), where the mean is given by: ...
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type SampleHidden(const arma::Mat< ElemType > &input, arma::Mat< ElemType > &output)
This function samples the hidden layer given the visible layer using Bernoulli function.
double Train(OptimizerType &optimizer, CallbackType &&... callbacks)
Train the RBM on the given input data.
std::enable_if< std::is_same< Policy, BinaryRBM >::value, void >::type Reset()
size_t NumFunctions() const
Return the number of separable functions (the number of predictor points).
RBM(arma::Mat< ElemType > predictors, InitializationRuleType initializeRule, const size_t visibleSize, const size_t hiddenSize, const size_t batchSize=1, const size_t numSteps=1, const size_t negSteps=1, const size_t poolSize=2, const ElemType slabPenalty=8, const ElemType radius=1, const bool persistence=false)
Initialize all the parameters of the network using initializeRule.