mlpack  1.0.12
em_fit.hpp
Go to the documentation of this file.
1 
15 #ifndef __MLPACK_METHODS_GMM_EM_FIT_HPP
16 #define __MLPACK_METHODS_GMM_EM_FIT_HPP
17 
18 #include <mlpack/core.hpp>
19 
20 // Default clustering mechanism.
22 // Default covariance matrix constraint.
24 
25 namespace mlpack {
26 namespace gmm {
27 
41 template<typename InitialClusteringType = kmeans::KMeans<>,
42  typename CovarianceConstraintPolicy = PositiveDefiniteConstraint>
43 class EMFit
44 {
45  public:
63  EMFit(const size_t maxIterations = 300,
64  const double tolerance = 1e-10,
65  InitialClusteringType clusterer = InitialClusteringType(),
66  CovarianceConstraintPolicy constraint = CovarianceConstraintPolicy());
67 
83  void Estimate(const arma::mat& observations,
84  std::vector<arma::vec>& means,
85  std::vector<arma::mat>& covariances,
86  arma::vec& weights,
87  const bool useInitialModel = false);
88 
106  void Estimate(const arma::mat& observations,
107  const arma::vec& probabilities,
108  std::vector<arma::vec>& means,
109  std::vector<arma::mat>& covariances,
110  arma::vec& weights,
111  const bool useInitialModel = false);
112 
114  const InitialClusteringType& Clusterer() const { return clusterer; }
116  InitialClusteringType& Clusterer() { return clusterer; }
117 
119  const CovarianceConstraintPolicy& Constraint() const { return constraint; }
121  CovarianceConstraintPolicy& Constraint() { return constraint; }
122 
124  size_t MaxIterations() const { return maxIterations; }
126  size_t& MaxIterations() { return maxIterations; }
127 
129  double Tolerance() const { return tolerance; }
131  double& Tolerance() { return tolerance; }
132 
133  private:
144  void InitialClustering(const arma::mat& observations,
145  std::vector<arma::vec>& means,
146  std::vector<arma::mat>& covariances,
147  arma::vec& weights);
148 
159  double LogLikelihood(const arma::mat& data,
160  const std::vector<arma::vec>& means,
161  const std::vector<arma::mat>& covariances,
162  const arma::vec& weights) const;
163 
167  double tolerance;
169  InitialClusteringType clusterer;
171  CovarianceConstraintPolicy constraint;
172 };
173 
174 }; // namespace gmm
175 }; // namespace mlpack
176 
177 // Include implementation.
178 #include "em_fit_impl.hpp"
179 
180 #endif
This class contains methods which can fit a GMM to observations using the EM algorithm.
Definition: em_fit.hpp:43
double & Tolerance()
Modify the tolerance for the convergence of the EM algorithm.
Definition: em_fit.hpp:131
const CovarianceConstraintPolicy & Constraint() const
Get the covariance constraint policy class.
Definition: em_fit.hpp:119
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: load.hpp:23
size_t maxIterations
Maximum iterations of EM algorithm.
Definition: em_fit.hpp:165
CovarianceConstraintPolicy constraint
Object which applies constraints to the covariance matrix.
Definition: em_fit.hpp:171
size_t & MaxIterations()
Modify the maximum number of iterations of the EM algorithm.
Definition: em_fit.hpp:126
size_t MaxIterations() const
Get the maximum number of iterations of the EM algorithm.
Definition: em_fit.hpp:124
InitialClusteringType & Clusterer()
Modify the clusterer.
Definition: em_fit.hpp:116
void Estimate(const arma::mat &observations, std::vector< arma::vec > &means, std::vector< arma::mat > &covariances, arma::vec &weights, const bool useInitialModel=false)
Fit the observations to a Gaussian mixture model (GMM) using the EM algorithm.
EMFit(const size_t maxIterations=300, const double tolerance=1e-10, InitialClusteringType clusterer=InitialClusteringType(), CovarianceConstraintPolicy constraint=CovarianceConstraintPolicy())
Construct the EMFit object, optionally passing an InitialClusteringType object (just in case it needs to store state).
CovarianceConstraintPolicy & Constraint()
Modify the covariance constraint policy class.
Definition: em_fit.hpp:121
double LogLikelihood(const arma::mat &data, const std::vector< arma::vec > &means, const std::vector< arma::mat > &covariances, const arma::vec &weights) const
Calculate the log-likelihood of a model.
double Tolerance() const
Get the tolerance for the convergence of the EM algorithm.
Definition: em_fit.hpp:129
const InitialClusteringType & Clusterer() const
Get the clusterer.
Definition: em_fit.hpp:114
void InitialClustering(const arma::mat &observations, std::vector< arma::vec > &means, std::vector< arma::mat > &covariances, arma::vec &weights)
Run the clusterer, and then turn the cluster assignments into Gaussians.
InitialClusteringType clusterer
Object which will perform the clustering.
Definition: em_fit.hpp:169
double tolerance
Tolerance for convergence of EM.
Definition: em_fit.hpp:167