bernoulli_distribution.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_DISTRIBUTIONS_BERNOULLI_DISTRIBUTION_HPP
13 #define MLPACK_METHODS_ANN_DISTRIBUTIONS_BERNOULLI_DISTRIBUTION_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 #include "../activation_functions/logistic_function.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
33 template <typename DataType = arma::mat>
35 {
36  public:
42 
63  BernoulliDistribution(const DataType& param,
64  const bool applyLogistic = true,
65  const double eps = 1e-10);
66 
72  double Probability(const DataType& observation) const
73  {
74  return std::exp(LogProbability(observation));
75  }
76 
82  double LogProbability(const DataType& observation) const;
83 
91  void LogProbBackward(const DataType& observation, DataType& output) const;
92 
99  DataType Sample() const;
100 
102  const DataType& Probability() const { return probability; }
103 
105  DataType& Probability() { return probability; }
106 
108  const DataType& Logits() const { return logits; }
109 
111  DataType& Logits() { return logits; }
112 
116  template<typename Archive>
117  void serialize(Archive& ar, const uint32_t /* version */)
118  {
119  // We just need to serialize each of the members.
120  ar(CEREAL_NVP(probability));
121  ar(CEREAL_NVP(logits));
122  ar(CEREAL_NVP(applyLogistic));
123  ar(CEREAL_NVP(eps));
124  }
125 
126  private:
128  DataType probability;
129 
132  DataType logits;
133 
135  bool applyLogistic;
136 
138  double eps;
139 }; // class BernoulliDistribution
140 
141 } // namespace ann
142 } // namespace mlpack
143 
144 // Include implementation.
145 #include "bernoulli_distribution_impl.hpp"
146 
147 #endif
const DataType & Probability() const
Return the probability matrix.
void LogProbBackward(const DataType &observation, DataType &output) const
Stores the gradient of the log probabilities of the observations in the output matrix.
Linear algebra utility functions, generally performed on matrices or vectors.
void serialize(Archive &ar, const uint32_t)
Serialize the distribution.
The core includes that mlpack expects; standard C++ includes and Armadillo.
double LogProbability(const DataType &observation) const
Return the log probability of the given matrix of observations (a single scalar).
const DataType & Logits() const
Return the logits matrix.
Multiple independent Bernoulli distributions.
double Probability(const DataType &observation) const
Return the probability of the given matrix of observations as a single scalar, computed as exp(LogProbability(observation)).
DataType & Logits()
Return a modifiable reference to the logits (pre-probability) matrix.
DataType & Probability()
Return a modifiable reference to the probability matrix.
BernoulliDistribution()
Default constructor, which creates a Bernoulli distribution with zero dimension.
DataType Sample() const
Return a matrix of randomly generated samples according to the probability distributions defined by this class.