em_fit.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_GMM_EM_FIT_HPP
15 #define MLPACK_METHODS_GMM_EM_FIT_HPP
16 
17 #include <mlpack/prereqs.hpp>
20 
21 // Default clustering mechanism.
23 // Default covariance matrix constraint.
25 
26 namespace mlpack {
27 namespace gmm {
28 
// This class contains methods which can fit a GMM (Gaussian mixture model) to
// observations using the EM algorithm.
//
// Template parameters (defaults are those visible in the declaration below):
//  - InitialClusteringType: type of the clusterer object stored in this class;
//    defaults to kmeans::KMeans<>.
//  - CovarianceConstraintPolicy: type of the covariance constraint policy
//    object stored in this class; defaults to PositiveDefiniteConstraint.
//  - Distribution: element type of the `dists` vectors handled by the fitting
//    methods; defaults to distribution::GaussianDistribution.
42 template<typename InitialClusteringType = kmeans::KMeans<>,
43  typename CovarianceConstraintPolicy = PositiveDefiniteConstraint,
44  typename Distribution = distribution::GaussianDistribution>
45 class EMFit
46 {
47  public:
    // Construct the EMFit object, optionally passing an InitialClusteringType
    // object and a CovarianceConstraintPolicy object (both are taken by value
    // and default-constructed when omitted).
    //
    // maxIterations: maximum number of iterations of the EM algorithm
    //     (default 300).
    // tolerance: tolerance for the convergence of the EM algorithm
    //     (default 1e-10).
    // clusterer: object of the InitialClusteringType template type.
    // constraint: object of the CovarianceConstraintPolicy template type.
64  EMFit(const size_t maxIterations = 300,
65  const double tolerance = 1e-10,
66  InitialClusteringType clusterer = InitialClusteringType(),
67  CovarianceConstraintPolicy constraint = CovarianceConstraintPolicy());
68 
    // Fit the observations to a Gaussian mixture model (GMM) using the EM
    // algorithm.
    //
    // observations: the data to fit (read-only).
    // dists: vector of Distribution objects, taken by non-const reference.
    // weights: vector of weights, taken by non-const reference.
    // useInitialModel: defaults to false.
    //
    // NOTE(review): presumably `dists` and `weights` receive the fitted model
    // and, when useInitialModel is true, also supply the starting point for
    // EM -- confirm against em_fit_impl.hpp.
83  void Estimate(const arma::mat& observations,
84  std::vector<Distribution>& dists,
85  arma::vec& weights,
86  const bool useInitialModel = false);
87 
    // Overload of Estimate() that additionally takes a read-only
    // `probabilities` vector; the remaining parameters match the overload
    // above.
    //
    // NOTE(review): presumably `probabilities` holds one probability per
    // observation -- confirm against em_fit_impl.hpp.
104  void Estimate(const arma::mat& observations,
105  const arma::vec& probabilities,
106  std::vector<Distribution>& dists,
107  arma::vec& weights,
108  const bool useInitialModel = false);
109 
    // Get the clusterer.
111  const InitialClusteringType& Clusterer() const { return clusterer; }
    // Modify the clusterer.
113  InitialClusteringType& Clusterer() { return clusterer; }
114 
    // Get the covariance constraint policy class.
116  const CovarianceConstraintPolicy& Constraint() const { return constraint; }
    // Modify the covariance constraint policy class.
118  CovarianceConstraintPolicy& Constraint() { return constraint; }
119 
    // Get the maximum number of iterations of the EM algorithm.
121  size_t MaxIterations() const { return maxIterations; }
    // Modify the maximum number of iterations of the EM algorithm.
123  size_t& MaxIterations() { return maxIterations; }
124 
    // Get the tolerance for the convergence of the EM algorithm.
126  double Tolerance() const { return tolerance; }
    // Modify the tolerance for the convergence of the EM algorithm.
128  double& Tolerance() { return tolerance; }
129 
    // Serialize the fitter.
131  template<typename Archive>
132  void serialize(Archive& ar, const uint32_t version);
133 
134  private:
    // Private helper taking the observations plus mutable `dists` and
    // `weights`.
    // NOTE(review): judging by the name, this presumably produces the initial
    // model using the stored clusterer -- the definition is not visible in
    // this header; confirm against em_fit_impl.hpp.
145  void InitialClustering(
146  const arma::mat& observations,
147  std::vector<Distribution>& dists,
148  arma::vec& weights);
149 
    // Const private helper returning a double computed from the data, the
    // distributions and the weights; none of its arguments are modified.
    // NOTE(review): judging by the name, presumably the log-likelihood of the
    // data under the given mixture -- confirm against em_fit_impl.hpp.
160  double LogLikelihood(
161  const arma::mat& data,
162  const std::vector<Distribution>& dists,
163  const arma::vec& weights) const;
164 
    // Private helper with the same parameter list as the first Estimate()
    // overload, except that useInitialModel has no default here.
    // NOTE(review): judging by the name, presumably delegates fitting to
    // Armadillo's own GMM implementation -- confirm against em_fit_impl.hpp.
175  void ArmadilloGMMWrapper(
176  const arma::mat& observations,
177  std::vector<Distribution>& dists,
178  arma::vec& weights,
179  const bool useInitialModel);
180 
    // Maximum number of iterations of the EM algorithm; see MaxIterations().
182  size_t maxIterations;
    // Tolerance for the convergence of the EM algorithm; see Tolerance().
184  double tolerance;
    // Clusterer object; see Clusterer().
186  InitialClusteringType clusterer;
    // Covariance constraint policy object; see Constraint().
188  CovarianceConstraintPolicy constraint;
189 };
190 
191 } // namespace gmm
192 } // namespace mlpack
193 
194 // Include implementation.
195 #include "em_fit_impl.hpp"
196 
197 #endif
This class contains methods which can fit a GMM to observations using the EM algorithm.
Definition: em_fit.hpp:45
Linear algebra utility functions, generally performed on matrices or vectors.
size_t MaxIterations() const
Get the maximum number of iterations of the EM algorithm.
Definition: em_fit.hpp:121
const CovarianceConstraintPolicy & Constraint() const
Get the covariance constraint policy class.
Definition: em_fit.hpp:116
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Estimate(const arma::mat &observations, std::vector< Distribution > &dists, arma::vec &weights, const bool useInitialModel=false)
Fit the observations to a Gaussian mixture model (GMM) using the EM algorithm.
const InitialClusteringType & Clusterer() const
Get the clusterer.
Definition: em_fit.hpp:111
double Tolerance() const
Get the tolerance for the convergence of the EM algorithm.
Definition: em_fit.hpp:126
double & Tolerance()
Modify the tolerance for the convergence of the EM algorithm.
Definition: em_fit.hpp:128
size_t & MaxIterations()
Modify the maximum number of iterations of the EM algorithm.
Definition: em_fit.hpp:123
InitialClusteringType & Clusterer()
Modify the clusterer.
Definition: em_fit.hpp:113
EMFit(const size_t maxIterations=300, const double tolerance=1e-10, InitialClusteringType clusterer=InitialClusteringType(), CovarianceConstraintPolicy constraint=CovarianceConstraintPolicy())
Construct the EMFit object, optionally passing an InitialClusteringType object (just in case it needs to store state).
CovarianceConstraintPolicy & Constraint()
Modify the covariance constraint policy class.
Definition: em_fit.hpp:118
void serialize(Archive &ar, const uint32_t version)
Serialize the fitter.