diff --git a/examples/mixture/plot_gmm.py b/examples/mixture/plot_gmm.py
index b0ddabdbf628807b7d4c9148acc9eeb85e22e376..eeebe2fa0835a386b567f88b28cd99732379ab35 100644
--- a/examples/mixture/plot_gmm.py
+++ b/examples/mixture/plot_gmm.py
@@ -52,7 +52,7 @@ for i, (clf, title) in enumerate([(gmm, 'GMM'),
                                   (dpgmm, 'Dirichlet Process GMM')]):
     splot = pl.subplot(2, 1, 1 + i)
     Y_ = clf.predict(X)
-    for i, (mean, covar, color) in enumerate(zip(clf.means, clf.covars,
+    for i, (mean, covar, color) in enumerate(zip(clf.means_, clf.covars_,
                                                  color_iter)):
         v, w = linalg.eigh(covar)
         u = w[0] / linalg.norm(w[0])
diff --git a/sklearn/mixture/dpgmm.py b/sklearn/mixture/dpgmm.py
index 9e27987c405d23550fc6dfbea5da47920bea851e..aad139def85cd3ddb12752ea09dc885e5915499a 100644
--- a/sklearn/mixture/dpgmm.py
+++ b/sklearn/mixture/dpgmm.py
@@ -103,6 +103,7 @@ def _bound_state_log_lik(X, initial_bound, precs, means, covariance_type):
             bound[:, k] -= 0.5 * precs[k] * (((X - means[k]) ** 2).sum(axis=-1)
                                              + n_features)
     elif covariance_type == 'tied':
+        # FIXME: suboptimal
         sqrt_cov = linalg.cholesky(precs)
         means = np.dot(means, sqrt_cov.T)
         X = np.dot(X, sqrt_cov.T)
@@ -110,12 +111,6 @@
     elif covariance_type == 'full':
         for k in xrange(n_components):
             bound[:, k] -= 0.5 * _sym_quad_form(X, means[k], precs[k])
-            #d = X - means[k]
-            ## not: choleksy is useless here
-            #sqrt_cov = linalg.cholesky(precs[k])
-            #d = np.dot(d, sqrt_cov.T)
-            #d **= 2
-            #bound[:, k] -= 0.5 * d.sum(axis=-1)
     return bound
 
 
@@ -170,13 +165,13 @@ class DPGMM(GMM):
     n_components : int
         Number of mixture components.
 
-    weights : array, shape (`n_components`,)
+    weights_ : array, shape (`n_components`,)
         Mixing weights for each mixture component.
 
-    means : array, shape (`n_components`, `n_features`)
+    means_ : array, shape (`n_components`, `n_features`)
         Mean parameters for each mixture component.
 
-    precisions : array
+    precisions_ : array
         Precision (inverse covariance) parameters for each mixture
         component. The shape depends on `covariance_type`::
             (`n_components`, 'n_features')                if 'spherical',
@@ -215,7 +210,8 @@
         elif self.covariance_type == 'tied':
             return [self.precs_] * self.n_components
         elif self.covariance_type == 'spherical':
-            return [np.eye(self.means.shape[1]) * f for f in self.precs_]
+            # fixme: should not require self.means_ to be defined
+            return [np.eye(self.means_.shape[1]) * f for f in self.precs_]
 
     def _get_covars(self):
         return [linalg.pinv(c) for c in self._get_precisions()]
@@ -521,15 +517,16 @@
         if init_params != '':
             self._initialize_gamma()
 
-        if 'm' in init_params or not hasattr(self, 'means'):
+        if 'm' in init_params or not hasattr(self, 'means_'):
             self.means_ = cluster.KMeans(
                 k=self.n_components,
                 random_state=self.random_state).fit(X).cluster_centers_[::-1]
 
-        if 'w' in init_params or not hasattr(self, 'weights'):
-            self.weights = np.tile(1.0 / self.n_components, self.n_components)
+        if 'w' in init_params or not hasattr(self, 'weights_'):
+            self._set_weights(np.tile(1.0 / self.n_components,
+                                      self.n_components))
 
-        if 'c' in init_params or not hasattr(self, 'covars'):
+        if 'c' in init_params or not hasattr(self, 'covars_'):
             if self.covariance_type == 'spherical':
                 self.dof_ = np.ones(self.n_components)
                 self.scale_ = np.ones(self.n_components)
@@ -629,13 +626,13 @@ class VBGMM(DPGMM):
     n_components : int (read-only)
         Number of mixture components.
 
-    weights : array, shape (`n_components`,)
+    weights_ : array, shape (`n_components`,)
         Mixing weights for each mixture component.
 
-    means : array, shape (`n_components`, `n_features`)
+    means_ : array, shape (`n_components`, `n_features`)
         Mean parameters for each mixture component.
 
-    precisions : array
+    precisions_ : array
         Precision (inverse covariance) parameters for each mixture
         component. The shape depends on `covariance_type`:
             (`n_components`, 'n_features')                if 'spherical',
diff --git a/sklearn/mixture/gmm.py b/sklearn/mixture/gmm.py
index 75c8dbc0018516dee53d59e9a6b5c0327948a62e..5388976323f61efcc77c43d51c1ffad9539ed23e 100644
--- a/sklearn/mixture/gmm.py
+++ b/sklearn/mixture/gmm.py
@@ -136,15 +136,14 @@ class GMM(BaseEstimator):
     Attributes
     ----------
-<<<<<<< HEAD
     covariance_type : string (read-only)
         String describing the type of covariance parameters used by the
         GMM. Must be one of 'spherical', 'tied', 'diag', 'full'.
 
-    weights : array, shape (`n_components`,)
-        Mixing weights for each mixture component.
-    means : array, shape (`n_components`, `n_features`)
+    log_weights_ : array, shape (`n_components`,)
+        log of mixing weights for each mixture component.
+    means_ : array, shape (`n_components`, `n_features`)
         Mean parameters for each mixture component.
-    covars : array
+    covars_ : array
         Covariance parameters for each mixture component.
         The shape depends on `covariance_type`:
             (n_components,)                        if 'spherical',
@@ -155,13 +154,6 @@
         True when convergence was reached in fit(), False otherwise.
 
-    weights : property - this string will be replaced
-
-    means : property - this string will be replaced
-
-    cvtype : property - this string will be replaced
-
-    covars : property - this string will be replaced
 
     See Also
@@ -190,10 +182,10 @@
     GMM(covariance_type='diag', n_components=2)
     >>> np.round(g.weights, 2)
     array([ 0.75,  0.25])
-    >>> np.round(g.means, 2)
+    >>> np.round(g.means_, 2)
    array([[ 10.05],
           [  0.06]])
-    >>> np.round(g.covars, 2) #doctest: +SKIP
+    >>> np.round(g.covars_, 2) #doctest: +SKIP
    array([[[ 1.02]],
           [[ 0.96]]])
     >>> g.predict([[0], [2], [9], [10]])
@@ -220,7 +212,8 @@
         if not covariance_type in ['spherical', 'tied', 'diag', 'full']:
             raise ValueError('bad covariance_type: ' + str(covariance_type))
 
-        self.weights = np.ones(self.n_components) / self.n_components
+        self.log_weights_ = - np.ones(self.n_components) * \
+            np.log(self.n_components)
 
         # flag to indicate exit status of fit() method: converged (True) or
         # n_iter reached (False)
@@ -254,12 +247,11 @@
         return [np.diag(cov) for cov in self.covars_]
 
     def _set_covars(self, covars):
+        """Provide values for covariance"""
         covars = np.asarray(covars)
         _validate_covars(covars, self._covariance_type, self.n_components)
         self.covars_ = covars
 
-    covars = property(_get_covars, _set_covars)
-
     def _get_means(self):
         """Mean parameters for each mixture component
         array, shape ``(n_states, n_features)``."""
@@ -273,8 +265,6 @@
                              '(n_components, n_features)')
         self.means_ = means.copy()
 
-    means = property(_get_means, _set_means)
-
     def __repr__(self):
         return "GMM(covariance_type='%s', n_components=%s)" % \
             (self._covariance_type, self.n_components)
@@ -290,11 +280,8 @@
             raise ValueError('weights must have length n_components')
         if not np.allclose(np.sum(weights), 1.0):
             raise ValueError('weights must sum to 1.0')
-
         self.log_weights_ = np.log(np.asarray(weights).copy())
 
-    weights = property(_get_weights, _set_weights)
-
     def eval(self, X):
         """Evaluate the model on data
 
@@ -420,10 +407,10 @@
         if random_state is None:
             random_state = self.random_state
         random_state = check_random_state(random_state)
-        weight_pdf = self.weights
+        weight_pdf = np.exp(self.log_weights_)
         weight_cdf = np.cumsum(weight_pdf)
 
-        X = np.empty((n_samples, self.means.shape[1]))
+        X = np.empty((n_samples, self.means_.shape[1]))
         rand = random_state.rand(n_samples)
         # decide which component to use for each sample
         comps = weight_cdf.searchsorted(rand)
@@ -490,16 +477,17 @@
         max_log_prob = - np.infty
         if n_init < 1:
             raise ValueError('GMM estimation requires at least one run')
+
         for _ in range(n_init):
-            if 'm' in init_params or not hasattr(self, 'means'):
+            if 'm' in init_params or not hasattr(self, 'means_'):
                 self.means_ = cluster.KMeans(
                     k=self.n_components).fit(X).cluster_centers_
 
-            if 'w' in init_params or not hasattr(self, 'weights'):
-                self.weights = np.tile(1.0 / self.n_components,
-                                       self.n_components)
+            if 'w' in init_params or not hasattr(self, 'weights_'):
+                self._set_weights(np.tile(1.0 / self.n_components,
+                                          self.n_components))
 
-            if 'c' in init_params or not hasattr(self, 'covars'):
+            if 'c' in init_params or not hasattr(self, 'covars_'):
                 cv = np.cov(X.T) + self.min_covar * np.eye(X.shape[1])
                 if not cv.shape:
                     cv.shape = (1, 1)
@@ -507,7 +495,7 @@
                     _distribute_covar_matrix_to_match_covariance_type(
                         cv, self._covariance_type, self.n_components)
 
-            # EM algorithm
+            # EM algorithms
             log_likelihood = []
             # reset self.converged_ to False
             self.converged_ = False
@@ -529,13 +517,13 @@
             if n_iter:
                 if log_likelihood[-1] > max_log_prob:
                     max_log_prob = log_likelihood[-1]
-                    best_params = {'weights': self.weights,
+                    best_params = {'weights': self._get_weights(),
                                    'means': self.means_,
                                    'covars': self.covars_}
 
         if n_iter:
             self.covars_ = best_params['covars']
             self.means_ = best_params['means']
-            self.weights = best_params['weights']
+            self._set_weights(best_params['weights'])
         return self
 
@@ -543,9 +531,7 @@ def _do_mstep(self, X, responsibilities, params, min_covar=0):
         """
         weights = responsibilities.sum(axis=0)
         weighted_X_sum = np.dot(responsibilities.T, X)
-
-        inverse_weights = 1.0 / (
-            weights[:, np.newaxis] + 10 * INF_EPS)
+        inverse_weights = 1.0 / (weights[:, np.newaxis] + 10 * INF_EPS)
 
         if 'w' in params:
             self.log_weights_ = np.log(
@@ -561,7 +547,7 @@
     def _n_parameters(self):
         """Return the number of free parameters in the model."""
-        ndim = self.means.shape[1]
+        ndim = self.means_.shape[1]
         if self._covariance_type == 'full':
             cov_params = self.n_components * ndim * (ndim + 1) / 2.
         elif self._covariance_type == 'diag':
@@ -585,7 +571,7 @@
         -------
         bic: float (the lower the better)
         """
-        return (-2 * self.score(X).sum() +
+        return (- 2 * self.score(X).sum() +
                 self._n_parameters() * np.log(X.shape[0]))
 
     def aic(self, X):
@@ -600,7 +586,7 @@
         -------
         aic: float (the lower the better)
         """
-        return -2 * self.score(X).sum() + 2 * self._n_parameters()
+        return - 2 * self.score(X).sum() + 2 * self._n_parameters()
 
 
 #########################################################################
diff --git a/sklearn/mixture/tests/test_gmm.py b/sklearn/mixture/tests/test_gmm.py
index b88c6a62b128f77e8e4c65c633762335bf31e8d9..f95e59128bf109f36bd23ea0f426bb7f733e23bf 100644
--- a/sklearn/mixture/tests/test_gmm.py
+++ b/sklearn/mixture/tests/test_gmm.py
@@ -110,26 +110,25 @@ def test_GMM_attributes():
     assert g.n_components == n_components
     assert g.covariance_type == covariance_type
 
-    g.weights = weights
-    assert_array_almost_equal(g.weights, weights)
-    assert_raises(ValueError, g.__setattr__, 'weights',
-                  2 * weights)
-    assert_raises(ValueError, g.__setattr__, 'weights', [])
-    assert_raises(ValueError, g.__setattr__, 'weights',
+    g.weights_ = weights
+    assert_array_almost_equal(g.weights_, weights)
+    assert_raises(ValueError, g._set_weights, 2 * weights)
+    assert_raises(ValueError, g._set_weights, [])
+    assert_raises(ValueError, g._set_weights,
                   np.zeros((n_components - 2, n_features)))
-
-    g.means = means
-    assert_array_almost_equal(g.means, means)
-    assert_raises(ValueError, g.__setattr__, 'means', [])
-    assert_raises(ValueError, g.__setattr__, 'means',
+
+    g.means_ = means
+    assert_array_almost_equal(g.means_, means)
+    assert_raises(ValueError, g._set_means, [])
+    assert_raises(ValueError, g._set_means,
                   np.zeros((n_components - 2, n_features)))
 
     covars = (0.1 + 2 * rng.rand(n_components, n_features)) ** 2
     g.covars_ = covars
     assert_array_almost_equal(g.covars_, covars)
-    assert_raises(ValueError, g.__setattr__, 'covars', [])
-    assert_raises(ValueError, g.__setattr__, 'covars',
-                  np.zeros((n_components - 2, n_features)))
+    assert_raises(ValueError, g._set_covars, [])
+    assert_raises(ValueError, g._set_covars,
+                  np.zeros((n_components - 2, n_features)))
 
     assert_raises(ValueError, mixture.GMM, n_components=20,
                   covariance_type='badcovariance_type')
@@ -160,13 +159,13 @@
                         covariance_type=self.covariance_type, random_state=rng)
         # Make sure the means are far apart so responsibilities.argmax()
         # picks the actual component used to generate the observations.
-        g.means = 20 * self.means
+        g.means_ = 20 * self.means
         g.covars_ = self.covars[self.covariance_type]
-        g.weights = self.weights
+        g._set_weights(self.weights)
 
         gaussidx = np.repeat(range(self.n_components), 5)
         n_samples = len(gaussidx)
-        X = rng.randn(n_samples, self.n_features) + g.means[gaussidx]
+        X = rng.randn(n_samples, self.n_features) + g.means_[gaussidx]
 
         ll, responsibilities = g.eval(X)
 
@@ -181,9 +180,9 @@
                         covariance_type=self.covariance_type, random_state=rng)
         # Make sure the means are far apart so responsibilities.argmax()
         # picks the actual component used to generate the observations.
-        g.means = 20 * self.means
+        g.means_ = 20 * self.means
         g.covars_ = np.maximum(self.covars[self.covariance_type], 0.1)
-        g.weights = self.weights
+        g._set_weights(self.weights)
 
         samples = g.rvs(n)
         self.assertEquals(samples.shape, (n, self.n_features))
@@ -191,8 +190,8 @@
     def test_train(self, params='wmc'):
         g = mixture.GMM(n_components=self.n_components,
                         covariance_type=self.covariance_type)
-        g.weights = self.weights
-        g.means = self.means
+        g._set_weights(self.weights)
+        g.means_ = self.means
         g.covars_ = 20 * self.covars[self.covariance_type]
 
         # Create a training set by sampling from the predefined distribution.
@@ -209,8 +208,7 @@
         for iter in xrange(5):
             g.fit(X, n_iter=1, params=params, init_params='')
             trainll.append(self.score(g, X))
-        g.fit(X, n_iter=10, params=params, init_params='')  # finish
-                                                             # fitting
+        g.fit(X, n_iter=10, params=params, init_params='')  # finish fitting
 
         # Note that the log likelihood will sometimes decrease by a
         # very small amount after it has more or less converged due to
@@ -222,7 +220,7 @@
             delta_min > self.threshold,
             "The min nll increase is %f which is lower than the admissible"
             " threshold of %f, for model %s. The likelihoods are %s."
-              % (delta_min, self.threshold, self.covariance_type, trainll))
+            % (delta_min, self.threshold, self.covariance_type, trainll))
 
     def test_train_degenerate(self, params='wmc'):
         """ Train on degenerate data with 0 in some dimensions
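A brief usage sketch (not part of the patch itself) of how fitted parameters are read once the `weights`/`means`/`covars` properties are gone, following the updated doctest and plot_gmm.py hunks above. The constructor and fit defaults of this scikit-learn revision are assumed, and the toy data below is hypothetical.

    import numpy as np
    from sklearn import mixture

    # Hypothetical toy data: any (n_samples, n_features) array works here.
    obs = np.concatenate((np.random.randn(100, 1),
                          10 + np.random.randn(300, 1)))

    g = mixture.GMM(n_components=2)
    g.fit(obs)

    print(g.means_)                # component means, was g.means
    print(g.covars_)               # component covariances, was g.covars
    print(np.exp(g.log_weights_))  # mixing weights are now stored as logs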