From 7752db8b82fccd5f837ae2b6607107f152f766c0 Mon Sep 17 00:00:00 2001
From: Jaques Grobler <jaquesgrobler@gmail.com>
Date: Thu, 5 Apr 2012 16:19:35 +0200
Subject: [PATCH] DPGMM API updated, along with plot_gmm_sin example

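Move the per-fit parameters n_iter, params and init_params from
DPGMM.fit() into __init__, so they are configured on the estimator
itself. Passing them to fit() still works for backwards compatibility,
but now emits a DeprecationWarning and copies the values onto the
estimator. The plot_gmm_sin example is updated to pass n_iter to the
DPGMM constructor.

A minimal sketch of the API change (X stands in for any training data):

    # deprecated: parameters passed at fit time
    clf = mixture.DPGMM(n_components=10, covariance_type='full',
                        alpha=0.01)
    clf.fit(X, n_iter=100)

    # preferred: parameters set on initialization
    clf = mixture.DPGMM(n_components=10, covariance_type='full',
                        alpha=0.01, n_iter=100)
    clf.fit(X)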
---
 examples/mixture/plot_gmm_sin.py |  4 ++--
 sklearn/mixture/dpgmm.py         | 33 ++++++++++++++++++++++++---------
 2 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/examples/mixture/plot_gmm_sin.py b/examples/mixture/plot_gmm_sin.py
index cae66038a2..6e727e11df 100644
--- a/examples/mixture/plot_gmm_sin.py
+++ b/examples/mixture/plot_gmm_sin.py
@@ -43,8 +43,8 @@ color_iter = itertools.cycle(['r', 'g', 'b', 'c', 'm'])
 for i, (clf, title) in enumerate([
         (mixture.GMM(n_components=10, covariance_type='full', n_iter=100), \
              "Expectation-maximization"),
-        (mixture.DPGMM(n_components=10, covariance_type='full', alpha=0.01),
-         "Dirichlet Process,alpha=0.01"),
+        (mixture.DPGMM(n_components=10, covariance_type='full', alpha=0.01,
+                       n_iter=100), "Dirichlet Process,alpha=0.01"),
         (mixture.DPGMM(n_components=10, covariance_type='diag', alpha=100.),
          "Dirichlet Process,alpha=100.")
         ]):
diff --git a/sklearn/mixture/dpgmm.py b/sklearn/mixture/dpgmm.py
index 3efc139243..a66f978aeb 100644
--- a/sklearn/mixture/dpgmm.py
+++ b/sklearn/mixture/dpgmm.py
@@ -10,12 +10,13 @@ Dirichlet Process Gaussian Mixture Models"""
 #
 
 import numpy as np
+import warnings
 from scipy.special import digamma as _digamma, gammaln as _gammaln
 from scipy import linalg
 from scipy.spatial.distance import cdist
 
 from ..utils import check_random_state
 from ..utils.extmath import norm
 from .. import cluster
 from .gmm import GMM
 
@@ -185,12 +187,14 @@ class DPGMM(GMM):
 
     def __init__(self, n_components=1, covariance_type='diag', alpha=1.0,
                  random_state=None, thresh=1e-2, verbose=False,
-                 min_covar=None):
+                 min_covar=None, n_iter=10, params='wmc', init_params='wmc'):
         self.alpha = alpha
         self.verbose = verbose
         super(DPGMM, self).__init__(n_components, covariance_type,
                                     random_state=random_state,
-                                    thresh=thresh, min_covar=min_covar)
+                                    thresh=thresh, min_covar=min_covar,
+                                    n_iter=n_iter, params=params,
+                                    init_params=init_params)
 
     def _get_precisions(self):
         """Return precisions as a full matrix."""
@@ -456,7 +460,7 @@ class DPGMM(GMM):
 
         return c + self._logprior(z)
 
-    def fit(self, X, n_iter=10, params='wmc', init_params='wmc'):
+    def fit(self, X, **kwargs):
         """Estimate model parameters with the variational
         algorithm.
 
@@ -486,6 +490,18 @@ class DPGMM(GMM):
             'm' for means, and 'c' for covars.  Defaults to 'wmc'.
         """
         self.random_state = check_random_state(self.random_state)
+        if kwargs:
+            warnings.warn("Setting parameters in the 'fit' method is deprecated"
+                     "Set it on initialization instead.",
+                     DeprecationWarning)
+            # for backwards compatibility, honour parameters that are
+            # still passed to fit() so that existing user code does not break
+            if 'n_iter' in kwargs:
+                self.n_iter = kwargs['n_iter']
+            if 'params' in kwargs:
+                self.params = kwargs['params']
+            if 'init_params' in kwargs:
+                self.init_params = kwargs['init_params']
 
         ## initialization step
         X = np.asarray(X)
@@ -499,18 +515,18 @@ class DPGMM(GMM):
         self._initial_bound = - 0.5 * n_features * np.log(2 * np.pi)
         self._initial_bound -= np.log(2 * np.pi * np.e)
 
-        if (init_params != '') or not hasattr(self, 'gamma_'):
+        if (self.init_params != '') or not hasattr(self, 'gamma_'):
             self._initialize_gamma()
 
-        if 'm' in init_params or not hasattr(self, 'means_'):
+        if 'm' in self.init_params or not hasattr(self, 'means_'):
             self.means_ = cluster.KMeans(
                 k=self.n_components,
                 random_state=self.random_state).fit(X).cluster_centers_[::-1]
 
-        if 'w' in init_params or not hasattr(self, 'weights_'):
+        if 'w' in self.init_params or not hasattr(self, 'weights_'):
             self.weights_ = np.tile(1.0 / self.n_components, self.n_components)
 
-        if 'c' in init_params or not hasattr(self, 'precs_'):
+        if 'c' in self.init_params or not hasattr(self, 'precs_'):
             if self._covariance_type == 'spherical':
                 self.dof_ = np.ones(self.n_components)
                 self.scale_ = np.ones(self.n_components)
@@ -553,7 +569,7 @@ class DPGMM(GMM):
         logprob = []
         # reset self.converged_ to False
         self.converged_ = False
-        for i in xrange(n_iter):
+        for i in xrange(self.n_iter):
             # Expectation step
             curr_logprob, z = self.eval(X)
             logprob.append(curr_logprob.sum() + self._logprior(z))
@@ -564,7 +580,7 @@ class DPGMM(GMM):
                 break
 
             # Maximization step
-            self._do_mstep(X, z, params)
+            self._do_mstep(X, z, self.params)
 
         return self
 
-- 
GitLab