MLPACK  1.0.8
em_fit.hpp
Go to the documentation of this file.
1 
23 #ifndef __MLPACK_METHODS_GMM_EM_FIT_HPP
24 #define __MLPACK_METHODS_GMM_EM_FIT_HPP
25 
26 #include <mlpack/core.hpp>
27 
28 // Default clustering mechanism.
30 // Default covariance matrix constraint.
32 
33 namespace mlpack {
34 namespace gmm {
35 
49 template<typename InitialClusteringType = kmeans::KMeans<>,
50  typename CovarianceConstraintPolicy = PositiveDefiniteConstraint>
51 class EMFit
52 {
53  public:
71  EMFit(const size_t maxIterations = 300,
72  const double tolerance = 1e-10,
73  InitialClusteringType clusterer = InitialClusteringType(),
74  CovarianceConstraintPolicy constraint = CovarianceConstraintPolicy());
75 
91  void Estimate(const arma::mat& observations,
92  std::vector<arma::vec>& means,
93  std::vector<arma::mat>& covariances,
94  arma::vec& weights,
95  const bool useInitialModel = false);
96 
114  void Estimate(const arma::mat& observations,
115  const arma::vec& probabilities,
116  std::vector<arma::vec>& means,
117  std::vector<arma::mat>& covariances,
118  arma::vec& weights,
119  const bool useInitialModel = false);
120 
122  const InitialClusteringType& Clusterer() const { return clusterer; }
124  InitialClusteringType& Clusterer() { return clusterer; }
125 
127  const CovarianceConstraintPolicy& Constraint() const { return constraint; }
129  CovarianceConstraintPolicy& Constraint() { return constraint; }
130 
132  size_t MaxIterations() const { return maxIterations; }
134  size_t& MaxIterations() { return maxIterations; }
135 
137  double Tolerance() const { return tolerance; }
139  double& Tolerance() { return tolerance; }
140 
141  private:
152  void InitialClustering(const arma::mat& observations,
153  std::vector<arma::vec>& means,
154  std::vector<arma::mat>& covariances,
155  arma::vec& weights);
156 
167  double LogLikelihood(const arma::mat& data,
168  const std::vector<arma::vec>& means,
169  const std::vector<arma::mat>& covariances,
170  const arma::vec& weights) const;
171 
175  double tolerance;
177  InitialClusteringType clusterer;
179  CovarianceConstraintPolicy constraint;
180 };
181 
182 }; // namespace gmm
183 }; // namespace mlpack
184 
185 // Include implementation.
186 #include "em_fit_impl.hpp"
187 
188 #endif
This class contains methods which can fit a GMM to observations using the EM algorithm.
Definition: em_fit.hpp:51
double & Tolerance()
Modify the tolerance for the convergence of the EM algorithm.
Definition: em_fit.hpp:139
const CovarianceConstraintPolicy & Constraint() const
Get the covariance constraint policy class.
Definition: em_fit.hpp:127
size_t maxIterations
Maximum iterations of EM algorithm.
Definition: em_fit.hpp:173
CovarianceConstraintPolicy constraint
Object which applies constraints to the covariance matrix.
Definition: em_fit.hpp:179
size_t & MaxIterations()
Modify the maximum number of iterations of the EM algorithm.
Definition: em_fit.hpp:134
size_t MaxIterations() const
Get the maximum number of iterations of the EM algorithm.
Definition: em_fit.hpp:132
InitialClusteringType & Clusterer()
Modify the clusterer.
Definition: em_fit.hpp:124
void Estimate(const arma::mat &observations, std::vector< arma::vec > &means, std::vector< arma::mat > &covariances, arma::vec &weights, const bool useInitialModel=false)
Fit the observations to a Gaussian mixture model (GMM) using the EM algorithm.
EMFit(const size_t maxIterations=300, const double tolerance=1e-10, InitialClusteringType clusterer=InitialClusteringType(), CovarianceConstraintPolicy constraint=CovarianceConstraintPolicy())
Construct the EMFit object, optionally passing an InitialClusteringType object (just in case it needs to store state).
CovarianceConstraintPolicy & Constraint()
Modify the covariance constraint policy class.
Definition: em_fit.hpp:129
double LogLikelihood(const arma::mat &data, const std::vector< arma::vec > &means, const std::vector< arma::mat > &covariances, const arma::vec &weights) const
Calculate the log-likelihood of a model.
double Tolerance() const
Get the tolerance for the convergence of the EM algorithm.
Definition: em_fit.hpp:137
const InitialClusteringType & Clusterer() const
Get the clusterer.
Definition: em_fit.hpp:122
void InitialClustering(const arma::mat &observations, std::vector< arma::vec > &means, std::vector< arma::mat > &covariances, arma::vec &weights)
Run the clusterer, and then turn the cluster assignments into Gaussians.
InitialClusteringType clusterer
Object which will perform the clustering.
Definition: em_fit.hpp:177
double tolerance
Tolerance for convergence of EM.
Definition: em_fit.hpp:175