14 #ifndef MLPACK_METHODS_GMM_EM_FIT_HPP 15 #define MLPACK_METHODS_GMM_EM_FIT_HPP 42 template<
typename InitialClusteringType = kmeans::KMeans<>,
43 typename CovarianceConstra
intPolicy = PositiveDefiniteConstra
int,
44 typename Distribution = distribution::GaussianDistribution>
64 EMFit(
const size_t maxIterations = 300,
65 const double tolerance = 1e-10,
66 InitialClusteringType clusterer = InitialClusteringType(),
67 CovarianceConstraintPolicy constraint = CovarianceConstraintPolicy());
83 void Estimate(
const arma::mat& observations,
84 std::vector<Distribution>& dists,
86 const bool useInitialModel =
false);
104 void Estimate(
const arma::mat& observations,
105 const arma::vec& probabilities,
106 std::vector<Distribution>& dists,
108 const bool useInitialModel =
false);
111 const InitialClusteringType&
Clusterer()
const {
return clusterer; }
113 InitialClusteringType&
Clusterer() {
return clusterer; }
116 const CovarianceConstraintPolicy&
Constraint()
const {
return constraint; }
118 CovarianceConstraintPolicy&
Constraint() {
return constraint; }
131 template<
typename Archive>
132 void serialize(Archive& ar,
const uint32_t version);
145 void InitialClustering(
146 const arma::mat& observations,
147 std::vector<Distribution>& dists,
160 double LogLikelihood(
161 const arma::mat& data,
162 const std::vector<Distribution>& dists,
163 const arma::vec& weights)
const;
175 void ArmadilloGMMWrapper(
176 const arma::mat& observations,
177 std::vector<Distribution>& dists,
179 const bool useInitialModel);
182 size_t maxIterations;
186 InitialClusteringType clusterer;
188 CovarianceConstraintPolicy constraint;
195 #include "em_fit_impl.hpp" This class contains methods which can fit a GMM to observations using the EM algorithm.
Linear algebra utility functions, generally performed on matrices or vectors.
size_t MaxIterations() const
Get the maximum number of iterations of the EM algorithm.
const CovarianceConstraintPolicy & Constraint() const
Get the covariance constraint policy class.
The core includes that mlpack expects: the standard C++ headers and Armadillo.
void Estimate(const arma::mat &observations, std::vector< Distribution > &dists, arma::vec &weights, const bool useInitialModel=false)
Fit the observations to a Gaussian mixture model (GMM) using the EM algorithm.
const InitialClusteringType & Clusterer() const
Get the clusterer.
double Tolerance() const
Get the tolerance for the convergence of the EM algorithm.
double & Tolerance()
Modify the tolerance for the convergence of the EM algorithm.
size_t & MaxIterations()
Modify the maximum number of iterations of the EM algorithm.
InitialClusteringType & Clusterer()
Modify the clusterer.
EMFit(const size_t maxIterations=300, const double tolerance=1e-10, InitialClusteringType clusterer=InitialClusteringType(), CovarianceConstraintPolicy constraint=CovarianceConstraintPolicy())
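Construct the EMFit object with the given maximum number of iterations and convergence tolerance, optionally passing an InitialClusteringType object (in case the clusterer needs to be configured or carry state) and a CovarianceConstraintPolicy object.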
CovarianceConstraintPolicy & Constraint()
Modify the covariance constraint policy class.
void serialize(Archive &ar, const uint32_t version)
Serialize the fitter.