From 7752db8b82fccd5f837ae2b6607107f152f766c0 Mon Sep 17 00:00:00 2001
From: Jaques Grobler <jaquesgrobler@gmail.com>
Date: Thu, 5 Apr 2012 16:19:35 +0200
Subject: [PATCH] DPGMM API updated, along with plot_gmm_sin example

---
 examples/mixture/plot_gmm_sin.py |  4 ++--
 sklearn/mixture/dpgmm.py         | 34 +++++++++++++++++++++++---------
 2 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/examples/mixture/plot_gmm_sin.py b/examples/mixture/plot_gmm_sin.py
index cae66038a2..6e727e11df 100644
--- a/examples/mixture/plot_gmm_sin.py
+++ b/examples/mixture/plot_gmm_sin.py
@@ -43,8 +43,8 @@ color_iter = itertools.cycle(['r', 'g', 'b', 'c', 'm'])
 for i, (clf, title) in enumerate([
         (mixture.GMM(n_components=10, covariance_type='full', n_iter=100), \
          "Expectation-maximization"),
-        (mixture.DPGMM(n_components=10, covariance_type='full', alpha=0.01),
-         "Dirichlet Process,alpha=0.01"),
+        (mixture.DPGMM(n_components=10, covariance_type='full', alpha=0.01,
+                       n_iter=100), "Dirichlet Process,alpha=0.01"),
         (mixture.DPGMM(n_components=10, covariance_type='diag', alpha=100.),
          "Dirichlet Process,alpha=100.")
         ]):
diff --git a/sklearn/mixture/dpgmm.py b/sklearn/mixture/dpgmm.py
index 3efc139243..a66f978aeb 100644
--- a/sklearn/mixture/dpgmm.py
+++ b/sklearn/mixture/dpgmm.py
@@ -10,12 +10,14 @@ Dirichlet Process Gaussian Mixture Models"""
 #
 
 import numpy as np
+import warnings
 from scipy.special import digamma as _digamma, gammaln as _gammaln
 from scipy import linalg
 from scipy.spatial.distance import cdist
 
 from ..utils import check_random_state
 from ..utils.extmath import norm
+from ..utils import deprecated
 from .. import cluster
 from .gmm import GMM
 
@@ -185,12 +187,14 @@ class DPGMM(GMM):
 
     def __init__(self, n_components=1, covariance_type='diag', alpha=1.0,
                  random_state=None, thresh=1e-2, verbose=False,
-                 min_covar=None):
+                 min_covar=None, n_iter=10, params='wmc', init_params='wmc'):
         self.alpha = alpha
         self.verbose = verbose
         super(DPGMM, self).__init__(n_components, covariance_type,
                                     random_state=random_state,
-                                    thresh=thresh, min_covar=min_covar)
+                                    thresh=thresh, min_covar=min_covar,
+                                    n_iter=n_iter, params=params,
+                                    init_params=init_params)
 
     def _get_precisions(self):
         """Return precisions as a full matrix."""
@@ -456,7 +460,7 @@ class DPGMM(GMM):
 
         return c + self._logprior(z)
 
-    def fit(self, X, n_iter=10, params='wmc', init_params='wmc'):
+    def fit(self, X, **kwargs):
         """Estimate model parameters with the variational
         algorithm.
 
@@ -486,6 +490,18 @@
            'm' for means, and 'c' for covars. Defaults to 'wmc'.
""" self.random_state = check_random_state(self.random_state) + if kwargs: + warnings.warn("Setting parameters in the 'fit' method is deprecated" + "Set it on initialization instead.", + DeprecationWarning) + # initialisations for in case the user still adds parameters to fit + # so things don't break + if 'n_iter' in kwargs: + self.n_iter = kwargs['n_iter'] + if 'params' in kwargs: + self.params = kwargs['params'] + if 'init_params' in kwargs: + self.init_params = kwargs['init_params'] ## initialization step X = np.asarray(X) @@ -499,18 +515,18 @@ class DPGMM(GMM): self._initial_bound = - 0.5 * n_features * np.log(2 * np.pi) self._initial_bound -= np.log(2 * np.pi * np.e) - if (init_params != '') or not hasattr(self, 'gamma_'): + if (self.init_params != '') or not hasattr(self, 'gamma_'): self._initialize_gamma() - if 'm' in init_params or not hasattr(self, 'means_'): + if 'm' in self.init_params or not hasattr(self, 'means_'): self.means_ = cluster.KMeans( k=self.n_components, random_state=self.random_state).fit(X).cluster_centers_[::-1] - if 'w' in init_params or not hasattr(self, 'weights_'): + if 'w' in self.init_params or not hasattr(self, 'weights_'): self.weights_ = np.tile(1.0 / self.n_components, self.n_components) - if 'c' in init_params or not hasattr(self, 'precs_'): + if 'c' in self.init_params or not hasattr(self, 'precs_'): if self._covariance_type == 'spherical': self.dof_ = np.ones(self.n_components) self.scale_ = np.ones(self.n_components) @@ -553,7 +569,7 @@ class DPGMM(GMM): logprob = [] # reset self.converged_ to False self.converged_ = False - for i in xrange(n_iter): + for i in xrange(self.n_iter): # Expectation step curr_logprob, z = self.eval(X) logprob.append(curr_logprob.sum() + self._logprior(z)) @@ -564,7 +580,7 @@ class DPGMM(GMM): break # Maximization step - self._do_mstep(X, z, params) + self._do_mstep(X, z, self.params) return self -- GitLab