diff --git a/doc/modules/mixture.rst b/doc/modules/mixture.rst
index 6d6494b9bc4aafb7eea7b919a15bc54b9f27b38f..8249b217da5be522328b2d21e4346a05b1a7f1b2 100644
--- a/doc/modules/mixture.rst
+++ b/doc/modules/mixture.rst
@@ -19,8 +19,8 @@ components are also provided.
    :align: center
    :scale: 50%

-   **Two-component Gaussian mixture model:** *data points, and equi-probability surfaces of
-   the model.*
+   **Two-component Gaussian mixture model:** *data points, and equi-probability
+   surfaces of the model.*

 A Gaussian mixture model is a probabilistic model that assumes all the
 data points are generated from a mixture of a finite number of
@@ -51,9 +51,9 @@ the :meth:`GaussianMixture.predict` method.
 sample belonging to the various Gaussians may be retrieved using the
 :meth:`GaussianMixture.predict_proba` method.

-The :class:`GaussianMixture` comes with different options to constrain the covariance
-of the difference classes estimated: spherical, diagonal, tied or full
-covariance.
+The :class:`GaussianMixture` comes with different options to constrain the
+covariance of the different classes estimated: spherical, diagonal, tied or
+full covariance.

 .. figure:: ../auto_examples/mixture/images/plot_gmm_covariances_001.png
    :target: ../auto_examples/mixture/plot_gmm_covariances.html
@@ -72,7 +72,7 @@ Pros and cons of class :class:`GaussianMixture`
 -----------------------------------------------

 Pros
-.....
+....

 :Speed: It is the fastest algorithm for learning mixture models

diff --git a/examples/mixture/plot_gmm_covariances.py b/examples/mixture/plot_gmm_covariances.py
index e3c8d8b68b43ae1daefb201948faa93c83d8a944..dbd5be50f93e1806f0b8268a20a90cf17488838d 100644
--- a/examples/mixture/plot_gmm_covariances.py
+++ b/examples/mixture/plot_gmm_covariances.py
@@ -47,14 +47,14 @@ colors = ['navy', 'turquoise', 'darkorange']
 def make_ellipses(gmm, ax):
     for n, color in enumerate(colors):
         if gmm.covariance_type == 'full':
-            covars = gmm.covariances_[n][:2, :2]
+            covariances = gmm.covariances_[n][:2, :2]
         elif gmm.covariance_type == 'tied':
-            covars = gmm.covariances_[:2, :2]
+            covariances = gmm.covariances_[:2, :2]
         elif gmm.covariance_type == 'diag':
-            covars = np.diag(gmm.covariances_[n][:2])
+            covariances = np.diag(gmm.covariances_[n][:2])
         elif gmm.covariance_type == 'spherical':
-            covars = np.eye(gmm.means_.shape[1]) * gmm.covariances_[n]
-        v, w = np.linalg.eigh(covars)
+            covariances = np.eye(gmm.means_.shape[1]) * gmm.covariances_[n]
+        v, w = np.linalg.eigh(covariances)
         u = w[0] / np.linalg.norm(w[0])
         angle = np.arctan2(u[1], u[0])
         angle = 180 * angle / np.pi  # convert to degrees
@@ -82,9 +82,9 @@ y_test = iris.target[test_index]

 n_classes = len(np.unique(y_train))

 # Try GMMs using different types of covariances.
-estimators = dict((covar_type, GaussianMixture(n_components=n_classes, - covariance_type=covar_type, max_iter=20)) - for covar_type in ['spherical', 'diag', 'tied', 'full']) +estimators = dict((cov_type, GaussianMixture(n_components=n_classes, + covariance_type=cov_type, max_iter=20, random_state=0)) + for cov_type in ['spherical', 'diag', 'tied', 'full']) n_estimators = len(estimators) diff --git a/examples/mixture/plot_gmm_selection.py b/examples/mixture/plot_gmm_selection.py index 747dc0d8a90c7adb8065277f4a135dd75f534ea1..3ccaba5262c0d354d42daeeac8d09a149bbf2093 100644 --- a/examples/mixture/plot_gmm_selection.py +++ b/examples/mixture/plot_gmm_selection.py @@ -75,9 +75,9 @@ spl.legend([b[0] for b in bars], cv_types) # Plot the winner splot = plt.subplot(2, 1, 2) Y_ = clf.predict(X) -for i, (mean, covar, color) in enumerate(zip(clf.means_, clf.covariances_, - color_iter)): - v, w = linalg.eigh(covar) +for i, (mean, cov, color) in enumerate(zip(clf.means_, clf.covariances_, + color_iter)): + v, w = linalg.eigh(cov) if not np.any(Y_ == i): continue plt.scatter(X[Y_ == i, 0], X[Y_ == i, 1], .8, color=color) diff --git a/sklearn/mixture/gaussian_mixture.py b/sklearn/mixture/gaussian_mixture.py index b19928dbcbdbcc16fb6dd5f692bef5f975896ef2..87e215ee2c9f134d223134f7651571d4340f11f5 100644 --- a/sklearn/mixture/gaussian_mixture.py +++ b/sklearn/mixture/gaussian_mixture.py @@ -36,14 +36,14 @@ def _check_weights(weights, n_components): _check_shape(weights, (n_components,), 'weights') # check range - if (any(np.less(weights, 0)) or - any(np.greater(weights, 1))): + if (any(np.less(weights, 0.)) or + any(np.greater(weights, 1.))): raise ValueError("The parameter 'weights' should be in the range " "[0, 1], but got max value %.5f, min value %.5f" % (np.min(weights), np.max(weights))) # check normalization - if not np.allclose(np.abs(1 - np.sum(weights)), 0.0): + if not np.allclose(np.abs(1. 
- np.sum(weights)), 0.):
         raise ValueError("The parameter 'weights' should be normalized, "
                          "but got sum(weights) = %.5f" % np.sum(weights))
     return weights
@@ -72,33 +72,33 @@ def _check_means(means, n_components, n_features):
     return means


-def _check_covariance_matrix(covariance, covariance_type):
-    """Check a covariance matrix is symmetric and positive-definite."""
-    if (not np.allclose(covariance, covariance.T) or
-            np.any(np.less_equal(linalg.eigvalsh(covariance), .0))):
-        raise ValueError("'%s covariance' should be symmetric, "
-                         "positive-definite" % covariance_type)
+def _check_precision_positivity(precision, covariance_type):
+    """Check a precision vector is positive."""
+    if np.any(np.less_equal(precision, 0.0)):
+        raise ValueError("'%s precision' should be "
+                         "positive" % covariance_type)


-def _check_covariance_positivity(covariance, covariance_type):
-    """Check a covariance vector is positive-definite."""
-    if np.any(np.less_equal(covariance, 0.0)):
-        raise ValueError("'%s covariance' should be "
-                         "positive" % covariance_type)
+def _check_precision_matrix(precision, covariance_type):
+    """Check a precision matrix is symmetric and positive-definite."""
+    if not (np.allclose(precision, precision.T) and
+            np.all(linalg.eigvalsh(precision) > 0.)):
+        raise ValueError("'%s precision' should be symmetric, "
+                         "positive-definite" % covariance_type)


-def _check_covariances_full(covariances, covariance_type):
-    """Check the covariance matrices are symmetric and positive-definite."""
-    for k, cov in enumerate(covariances):
-        _check_covariance_matrix(cov, covariance_type)
+def _check_precisions_full(precisions, covariance_type):
+    """Check the precision matrices are symmetric and positive-definite."""
+    for prec in precisions:
+        _check_precision_matrix(prec, covariance_type)


-def _check_covariances(covariances, covariance_type, n_components, n_features):
-    """Validate user provided covariances.
+def _check_precisions(precisions, covariance_type, n_components, n_features):
+    """Validate user provided precisions.
Parameters
     ----------
-    covariances : array-like,
+    precisions : array-like,
         'full' : shape of (n_components, n_features, n_features)
         'tied' : shape of (n_features, n_features)
         'diag' : shape of (n_components, n_features)
@@ -114,33 +114,37 @@ def _check_covariances(covariances, covariance_type, n_components, n_features):

     Returns
     -------
-    covariances : array
+    precisions : array
     """
-    covariances = check_array(covariances, dtype=[np.float64, np.float32],
-                              ensure_2d=False,
-                              allow_nd=covariance_type is 'full')
-
-    covariances_shape = {'full': (n_components, n_features, n_features),
-                         'tied': (n_features, n_features),
-                         'diag': (n_components, n_features),
-                         'spherical': (n_components,)}
-    _check_shape(covariances, covariances_shape[covariance_type],
-                 '%s covariance' % covariance_type)
+    precisions = check_array(precisions, dtype=[np.float64, np.float32],
+                             ensure_2d=False,
+                             allow_nd=covariance_type == 'full')

-    check_functions = {'full': _check_covariances_full,
-                       'tied': _check_covariance_matrix,
-                       'diag': _check_covariance_positivity,
-                       'spherical': _check_covariance_positivity}
-    check_functions[covariance_type](covariances, covariance_type)
+    precisions_shape = {'full': (n_components, n_features, n_features),
+                        'tied': (n_features, n_features),
+                        'diag': (n_components, n_features),
+                        'spherical': (n_components,)}
+    _check_shape(precisions, precisions_shape[covariance_type],
+                 '%s precision' % covariance_type)

-    return covariances
+    check_functions = {'full': _check_precisions_full,
+                       'tied': _check_precision_matrix,
+                       'diag': _check_precision_positivity,
+                       'spherical': _check_precision_positivity}
+    check_functions[covariance_type](precisions, covariance_type)
+    return precisions


###############################################################################
# Gaussian mixture parameters estimators (used by the M-Step)

+ESTIMATE_PRECISION_ERROR_MESSAGE = ("The algorithm has diverged because of "
+                                    "too few samples per component. Try to "
+                                    "decrease the number of components, "
+                                    "or increase reg_covar.")

-def _estimate_gaussian_covariance_full(resp, X, nk, means, reg_covar):
-    """Estimate the full covariance matrices.
+
+def _estimate_gaussian_precisions_cholesky_full(resp, X, nk, means, reg_covar):
+    """Estimate the full precision matrices.

     Parameters
     ----------
@@ -156,20 +160,27 @@ def _estimate_gaussian_covariance_full(resp, X, nk, means, reg_covar):

     Returns
     -------
-    covariances : array, shape (n_components, n_features, n_features)
+    precisions_chol : array, shape (n_components, n_features, n_features)
+        The cholesky decomposition of the precision matrix.
     """
-    n_features = X.shape[1]
-    n_components = means.shape[0]
-    covariances = np.empty((n_components, n_features, n_features))
+    n_components, n_features = means.shape
+    precisions_chol = np.empty((n_components, n_features, n_features))
     for k in range(n_components):
         diff = X - means[k]
-        covariances[k] = np.dot(resp[:, k] * diff.T, diff) / nk[k]
-        covariances[k].flat[::n_features + 1] += reg_covar
-    return covariances
+        covariance = np.dot(resp[:, k] * diff.T, diff) / nk[k]
+        covariance.flat[::n_features + 1] += reg_covar
+        try:
+            cov_chol = linalg.cholesky(covariance, lower=True)
+        except linalg.LinAlgError:
+            raise ValueError(ESTIMATE_PRECISION_ERROR_MESSAGE)
+        precisions_chol[k] = linalg.solve_triangular(cov_chol,
+                                                     np.eye(n_features),
+                                                     lower=True).T
+    return precisions_chol


-def _estimate_gaussian_covariance_tied(resp, X, nk, means, reg_covar):
-    """Estimate the tied covariance matrix.
+def _estimate_gaussian_precisions_cholesky_tied(resp, X, nk, means, reg_covar): + """Estimate the tied precision matrix. Parameters ---------- @@ -185,18 +196,26 @@ def _estimate_gaussian_covariance_tied(resp, X, nk, means, reg_covar): Returns ------- - covariances : array, shape (n_features, n_features) + precisions_chol : array, shape (n_features, n_features) + The cholesky decomposition of the precision matrix. """ + n_samples, n_features = X.shape avg_X2 = np.dot(X.T, X) avg_means2 = np.dot(nk * means.T, means) covariances = avg_X2 - avg_means2 - covariances /= X.shape[0] + covariances /= n_samples covariances.flat[::len(covariances) + 1] += reg_covar - return covariances + try: + cov_chol = linalg.cholesky(covariances, lower=True) + except linalg.LinAlgError: + raise ValueError(ESTIMATE_PRECISION_ERROR_MESSAGE) + precisions_chol = linalg.solve_triangular(cov_chol, np.eye(n_features), + lower=True).T + return precisions_chol -def _estimate_gaussian_covariance_diag(resp, X, nk, means, reg_covar): - """Estimate the diagonal covariance matrices. +def _estimate_gaussian_precisions_cholesky_diag(resp, X, nk, means, reg_covar): + """Estimate the diagonal precision matrices. Parameters ---------- @@ -212,16 +231,21 @@ def _estimate_gaussian_covariance_diag(resp, X, nk, means, reg_covar): Returns ------- - covariances : array, shape (n_components, n_features) + precisions_chol : array, shape (n_components, n_features) + The cholesky decomposition of the precision matrix. """ avg_X2 = np.dot(resp.T, X * X) / nk[:, np.newaxis] avg_means2 = means ** 2 avg_X_means = means * np.dot(resp.T, X) / nk[:, np.newaxis] - return avg_X2 - 2 * avg_X_means + avg_means2 + reg_covar + covariances = avg_X2 - 2 * avg_X_means + avg_means2 + reg_covar + if np.any(np.less_equal(covariances, 0.0)): + raise ValueError(ESTIMATE_PRECISION_ERROR_MESSAGE) + return 1. / np.sqrt(covariances) -def _estimate_gaussian_covariance_spherical(resp, X, nk, means, reg_covar): - """Estimate the spherical covariance matrices. +def _estimate_gaussian_precisions_cholesky_spherical(resp, X, nk, means, + reg_covar): + """Estimate the spherical precision matrices. Parameters ---------- @@ -237,11 +261,16 @@ def _estimate_gaussian_covariance_spherical(resp, X, nk, means, reg_covar): Returns ------- - covariances : array, shape (n_components,) + precisions_chol : array, shape (n_components,) + The cholesky decomposition of the precision matrix. """ - covariances = _estimate_gaussian_covariance_diag(resp, X, nk, means, - reg_covar) - return covariances.mean(axis=1) + avg_X2 = np.dot(resp.T, X * X) / nk[:, np.newaxis] + avg_means2 = means ** 2 + avg_X_means = means * np.dot(resp.T, X) / nk[:, np.newaxis] + covariances = (avg_X2 - 2 * avg_X_means + avg_means2 + reg_covar).mean(1) + if np.any(np.less_equal(covariances, 0.0)): + raise ValueError(ESTIMATE_PRECISION_ERROR_MESSAGE) + return 1. / np.sqrt(covariances) def _estimate_gaussian_parameters(X, resp, reg_covar, covariance_type): @@ -256,10 +285,10 @@ def _estimate_gaussian_parameters(X, resp, reg_covar, covariance_type): The responsibilities for each data sample in X. reg_covar : float - The regularization added to each covariance matrices. + The regularization added to the diagonal of the covariance matrices. covariance_type : {'full', 'tied', 'diag', 'spherical'} - The type of covariance matrices. + The type of precision matrices. 
Returns
     -------
@@ -269,29 +298,25 @@
     nk : array, shape (n_components,)
         The numbers of data samples in the current components.

     means : array, shape (n_components, n_features)
         The centers of the current components.

-    covariances : array
-        The sample covariances of the current components.
-        The shape depends of the covariance_type.
+    precisions_cholesky : array
+        The cholesky decomposition of sample precisions of the current
+        components. The shape depends on the covariance_type.
     """
-    compute_covariance = {
-        "full": _estimate_gaussian_covariance_full,
-        "tied": _estimate_gaussian_covariance_tied,
-        "diag": _estimate_gaussian_covariance_diag,
-        "spherical": _estimate_gaussian_covariance_spherical}
-
     nk = resp.sum(axis=0) + 10 * np.finfo(resp.dtype).eps
     means = np.dot(resp.T, X) / nk[:, np.newaxis]
-    covariances = compute_covariance[covariance_type](
-        resp, X, nk, means, reg_covar)
-
-    return nk, means, covariances
+    precs_chol = {"full": _estimate_gaussian_precisions_cholesky_full,
+                  "tied": _estimate_gaussian_precisions_cholesky_tied,
+                  "diag": _estimate_gaussian_precisions_cholesky_diag,
+                  "spherical": _estimate_gaussian_precisions_cholesky_spherical
+                  }[covariance_type](resp, X, nk, means, reg_covar)
+    return nk, means, precs_chol


###############################################################################
# Gaussian mixture probability estimators

-def _estimate_log_gaussian_prob_full(X, means, covariances):
-    """Estimate the log Gaussian probability for 'full' covariance.
+def _estimate_log_gaussian_prob_full(X, means, precisions_chol):
+    """Estimate the log Gaussian probability for 'full' precision.

     Parameters
     ----------
@@ -299,33 +324,26 @@

     means : array-like, shape (n_components, n_features)

-    covariances : array-like, shape (n_components, n_features, n_features)
+    precisions_chol : array-like, shape (n_components, n_features, n_features)
+        Cholesky decompositions of the precision matrices.

     Returns
     -------
     log_prob : array, shape (n_samples, n_components)
     """
     n_samples, n_features = X.shape
-    n_components = means.shape[0]
+    n_components, _ = means.shape
     log_prob = np.empty((n_samples, n_components))
-    for k, (mu, cov) in enumerate(zip(means, covariances)):
-        try:
-            cov_chol = linalg.cholesky(cov, lower=True)
-        except linalg.LinAlgError:
-            raise ValueError("The algorithm has diverged because of too "
-                             "few samples per components. "
-                             "Try to decrease the number of components, or "
-                             "increase reg_covar.")
-        cv_log_det = 2. * np.sum(np.log(np.diagonal(cov_chol)))
-        cv_sol = linalg.solve_triangular(cov_chol, (X - mu).T, lower=True).T
-        log_prob[:, k] = - .5 * (n_features * np.log(2. * np.pi) +
-                                 cv_log_det +
-                                 np.sum(np.square(cv_sol), axis=1))
+    for k, (mu, prec_chol) in enumerate(zip(means, precisions_chol)):
+        log_det = -2. * np.sum(np.log(np.diagonal(prec_chol)))
+        y = np.dot(X - mu, prec_chol)
+        log_prob[:, k] = -.5 * (n_features * np.log(2. * np.pi) + log_det +
+                                np.sum(np.square(y), axis=1))
     return log_prob


-def _estimate_log_gaussian_prob_tied(X, means, covariances):
-    """Estimate the log Gaussian probability for 'tied' covariance.
+def _estimate_log_gaussian_prob_tied(X, means, precision_chol):
+    """Estimate the log Gaussian probability for 'tied' precision.
Parameters ---------- @@ -333,33 +351,26 @@ def _estimate_log_gaussian_prob_tied(X, means, covariances): means : array-like, shape (n_components, n_features) - covariances : array-like, shape (n_features, n_features) + precision_chol : array-like, shape (n_features, n_features) + Cholesky decomposition of the precision matrix. Returns ------- log_prob : array-like, shape (n_samples, n_components) """ n_samples, n_features = X.shape - n_components = means.shape[0] + n_components, _ = means.shape log_prob = np.empty((n_samples, n_components)) - try: - cov_chol = linalg.cholesky(covariances, lower=True) - except linalg.LinAlgError: - raise ValueError("The algorithm has diverged because of too " - "few samples per components. " - "Try to decrease the number of components, or " - "increase reg_covar.") - cv_log_det = 2. * np.sum(np.log(np.diagonal(cov_chol))) + log_det = -2. * np.sum(np.log(np.diagonal(precision_chol))) for k, mu in enumerate(means): - cv_sol = linalg.solve_triangular(cov_chol, (X - mu).T, - lower=True).T - log_prob[:, k] = np.sum(np.square(cv_sol), axis=1) - log_prob = - .5 * (n_features * np.log(2. * np.pi) + cv_log_det + log_prob) + y = np.dot(X - mu, precision_chol) + log_prob[:, k] = np.sum(np.square(y), axis=1) + log_prob = -.5 * (n_features * np.log(2. * np.pi) + log_det + log_prob) return log_prob -def _estimate_log_gaussian_prob_diag(X, means, covariances): - """Estimate the log Gaussian probability for 'diag' covariance. +def _estimate_log_gaussian_prob_diag(X, means, precisions_chol): + """Estimate the log Gaussian probability for 'diag' precision. Parameters ---------- @@ -367,28 +378,25 @@ def _estimate_log_gaussian_prob_diag(X, means, covariances): means : array-like, shape (n_components, n_features) - covariances : array-like, shape (n_components, n_features) + precisions_chol : array-like, shape (n_components, n_features) + Cholesky decompositions of the precision matrices. Returns ------- log_prob : array-like, shape (n_samples, n_components) """ - if np.any(np.less_equal(covariances, 0.0)): - raise ValueError("The algorithm has diverged because of too " - "few samples per components. " - "Try to decrease the number of components, or " - "increase reg_covar.") n_samples, n_features = X.shape - log_prob = - .5 * (n_features * np.log(2. * np.pi) + - np.sum(np.log(covariances), 1) + - np.sum((means ** 2 / covariances), 1) - - 2. * np.dot(X, (means / covariances).T) + - np.dot(X ** 2, (1. / covariances).T)) + precisions = precisions_chol ** 2 + log_prob = -.5 * (n_features * np.log(2. * np.pi) - + np.sum(np.log(precisions), 1) + + np.sum((means ** 2 * precisions), 1) - + 2. * np.dot(X, (means * precisions).T) + + np.dot(X ** 2, precisions.T)) return log_prob -def _estimate_log_gaussian_prob_spherical(X, means, covariances): - """Estimate the log Gaussian probability for 'spherical' covariance. +def _estimate_log_gaussian_prob_spherical(X, means, precisions_chol): + """Estimate the log Gaussian probability for 'spherical' precision. Parameters ---------- @@ -396,23 +404,20 @@ def _estimate_log_gaussian_prob_spherical(X, means, covariances): means : array-like, shape (n_components, n_features) - covariances : array-like, shape (n_components, ) + precisions_chol : array-like, shape (n_components, ) + Cholesky decompositions of the precision matrices. Returns ------- log_prob : array-like, shape (n_samples, n_components) """ - if np.any(np.less_equal(covariances, 0.0)): - raise ValueError("The algorithm has diverged because of too " - "few samples per components. 
" - "Try to decrease the number of components, or " - "increase reg_covar.") n_samples, n_features = X.shape - log_prob = - .5 * (n_features * np.log(2 * np.pi) + - n_features * np.log(covariances) + - np.sum(means ** 2, 1) / covariances - - 2 * np.dot(X, means.T / covariances) + - np.outer(np.sum(X ** 2, axis=1), 1. / covariances)) + precisions = precisions_chol ** 2 + log_prob = -.5 * (n_features * np.log(2 * np.pi) - + n_features * np.log(precisions) + + np.sum(means ** 2, 1) * precisions - + 2 * np.dot(X, means.T * precisions) + + np.outer(np.sum(X ** 2, axis=1), precisions)) return log_prob @@ -453,7 +458,7 @@ class GaussianMixture(BaseMixture): init_params : {'kmeans', 'random'}, defaults to 'kmeans'. The method used to initialize the weights, the means and the - covariances. + precisions. Must be one of:: 'kmeans' : responsibilities are initialized using kmeans. 'random' : responsibilities are initialized randomly. @@ -466,9 +471,10 @@ class GaussianMixture(BaseMixture): The user-provided initial means, defaults to None, If it None, means are initialized using the `init_params` method. - covariances_init: array-like, optional. - The user-provided initial covariances, defaults to None. - If it None, covariances are initialized using the 'init_params' method. + precisions_init: array-like, optional. + The user-provided initial precisions (inverse of the covariance + matrices), defaults to None. + If it None, precisions are initialized using the 'init_params' method. The shape depends on 'covariance_type':: (n_components,) if 'spherical', (n_features, n_features) if 'tied', @@ -493,11 +499,9 @@ class GaussianMixture(BaseMixture): ---------- weights_ : array, shape (n_components,) The weights of each mixture components. - `weights_` will not exist before a call to fit. means_ : array, shape (n_components, n_features) The mean of each mixture component. - `means_` will not exist before a call to fit. covariances_ : array The covariance of each mixture component. @@ -506,20 +510,43 @@ class GaussianMixture(BaseMixture): (n_features, n_features) if 'tied', (n_components, n_features) if 'diag', (n_components, n_features, n_features) if 'full' - `covariances_` will not exist before a call to fit. + + precisions_ : array + The precision matrices for each component in the mixture. A precision + matrix is the inverse of a covariance matrix. A covariance matrix is + symmetric positive definite so the mixture of Gaussian can be + equivalently parameterized by the precision matrices. Storing the + precision matrices instead of the covariance matrices makes it more + efficient to compute the log-likelihood of new samples at test time. + The shape depends on `covariance_type`:: + (n_components,) if 'spherical', + (n_features, n_features) if 'tied', + (n_components, n_features) if 'diag', + (n_components, n_features, n_features) if 'full' + + precisions_cholesky_ : array + The cholesky decomposition of the precision matrices of each mixture + component. A precision matrix is the inverse of a covariance matrix. + A covariance matrix is symmetric positive definite so the mixture of + Gaussian can be equivalently parameterized by the precision matrices. + Storing the precision matrices instead of the covariance matrices makes + it more efficient to compute the log-likelihood of new samples at test + time. 
The shape depends on `covariance_type`:: + (n_components,) if 'spherical', + (n_features, n_features) if 'tied', + (n_components, n_features) if 'diag', + (n_components, n_features, n_features) if 'full' converged_ : bool True when convergence was reached in fit(), False otherwise. - `converged_` will not exist before a call to fit. n_iter_ : int Number of step used by the best fit of EM to reach the convergence. - `n_iter_` will not exist before a call to fit. """ def __init__(self, n_components=1, covariance_type='full', tol=1e-3, reg_covar=1e-6, max_iter=100, n_init=1, init_params='kmeans', - weights_init=None, means_init=None, covariances_init=None, + weights_init=None, means_init=None, precisions_init=None, random_state=None, warm_start=False, verbose=0, verbose_interval=10): super(GaussianMixture, self).__init__( @@ -531,10 +558,11 @@ class GaussianMixture(BaseMixture): self.covariance_type = covariance_type self.weights_init = weights_init self.means_init = means_init - self.covariances_init = covariances_init + self.precisions_init = precisions_init def _check_parameters(self, X): """Check the Gaussian mixture parameters are well defined.""" + _, n_features = X.shape if self.covariance_type not in ['spherical', 'tied', 'diag', 'full']: raise ValueError("Invalid value for 'covariance_type': %s " "'covariance_type' should be in " @@ -547,13 +575,13 @@ class GaussianMixture(BaseMixture): if self.means_init is not None: self.means_init = _check_means(self.means_init, - self.n_components, X.shape[1]) + self.n_components, n_features) - if self.covariances_init is not None: - self.covariances_init = _check_covariances(self.covariances_init, - self.covariance_type, - self.n_components, - X.shape[1]) + if self.precisions_init is not None: + self.precisions_init = _check_precisions(self.precisions_init, + self.covariance_type, + self.n_components, + n_features) def _initialize(self, X, resp): """Initialization of the Gaussian mixture parameters. 
@@ -564,60 +592,92 @@ class GaussianMixture(BaseMixture):

         resp : array-like, shape (n_samples, n_components)
         """
-        weights, means, covariances = _estimate_gaussian_parameters(
+        n_samples, _ = X.shape
+
+        weights, means, precisions_cholesky = _estimate_gaussian_parameters(
             X, resp, self.reg_covar, self.covariance_type)
-        weights /= X.shape[0]
+        weights /= n_samples

         self.weights_ = (weights if self.weights_init is None
                          else self.weights_init)
         self.means_ = means if self.means_init is None else self.means_init
-        self.covariances_ = (covariances if self.covariances_init is None
-                             else self.covariances_init)
+
+        if self.precisions_init is None:
+            self.precisions_cholesky_ = precisions_cholesky
+        elif self.covariance_type == 'full':
+            self.precisions_cholesky_ = np.array(
+                [linalg.cholesky(prec_init, lower=True)
+                 for prec_init in self.precisions_init])
+        elif self.covariance_type == 'tied':
+            self.precisions_cholesky_ = linalg.cholesky(self.precisions_init,
+                                                        lower=True)
+        else:
+            self.precisions_cholesky_ = self.precisions_init

     def _e_step(self, X):
         log_prob_norm, _, log_resp = self._estimate_log_prob_resp(X)
         return np.mean(log_prob_norm), np.exp(log_resp)

     def _m_step(self, X, resp):
-        self.weights_, self.means_, self.covariances_ = (
+        self.weights_, self.means_, self.precisions_cholesky_ = (
             _estimate_gaussian_parameters(X, resp, self.reg_covar,
                                           self.covariance_type))
         self.weights_ /= X.shape[0]

     def _estimate_log_prob(self, X):
-        estimate_log_prob_functions = {
-            "full": _estimate_log_gaussian_prob_full,
-            "tied": _estimate_log_gaussian_prob_tied,
-            "diag": _estimate_log_gaussian_prob_diag,
-            "spherical": _estimate_log_gaussian_prob_spherical
-        }
-        return estimate_log_prob_functions[self.covariance_type](
-            X, self.means_, self.covariances_)
+        return {"full": _estimate_log_gaussian_prob_full,
+                "tied": _estimate_log_gaussian_prob_tied,
+                "diag": _estimate_log_gaussian_prob_diag,
+                "spherical": _estimate_log_gaussian_prob_spherical
+                }[self.covariance_type](X, self.means_,
+                                        self.precisions_cholesky_)

     def _estimate_log_weights(self):
         return np.log(self.weights_)

     def _check_is_fitted(self):
-        check_is_fitted(self, ['weights_', 'means_', 'covariances_'])
+        check_is_fitted(self, ['weights_', 'means_', 'precisions_cholesky_'])

     def _get_parameters(self):
-        return self.weights_, self.means_, self.covariances_
+        return self.weights_, self.means_, self.precisions_cholesky_

     def _set_parameters(self, params):
-        self.weights_, self.means_, self.covariances_ = params
+        self.weights_, self.means_, self.precisions_cholesky_ = params
+
+        # Attributes computation
+        _, n_features = self.means_.shape
+
+        if self.covariance_type == 'full':
+            self.precisions_ = np.empty(self.precisions_cholesky_.shape)
+            self.covariances_ = np.empty(self.precisions_cholesky_.shape)
+            for k, prec_chol in enumerate(self.precisions_cholesky_):
+                self.precisions_[k] = np.dot(prec_chol, prec_chol.T)
+                cov_chol = linalg.solve_triangular(prec_chol,
+                                                   np.eye(n_features))
+                self.covariances_[k] = np.dot(cov_chol.T, cov_chol)
+
+        elif self.covariance_type == 'tied':
+            self.precisions_ = np.dot(self.precisions_cholesky_,
+                                      self.precisions_cholesky_.T)
+            cov_chol = linalg.solve_triangular(self.precisions_cholesky_,
+                                               np.eye(n_features))
+            self.covariances_ = np.dot(cov_chol.T, cov_chol)
+        else:
+            self.precisions_ = self.precisions_cholesky_ ** 2
+            self.covariances_ = 1.
/ self.precisions_ def _n_parameters(self): """Return the number of free parameters in the model.""" - ndim = self.means_.shape[1] + _, n_features = self.means_.shape if self.covariance_type == 'full': - cov_params = self.n_components * ndim * (ndim + 1) / 2. + cov_params = self.n_components * n_features * (n_features + 1) / 2. elif self.covariance_type == 'diag': - cov_params = self.n_components * ndim + cov_params = self.n_components * n_features elif self.covariance_type == 'tied': - cov_params = ndim * (ndim + 1) / 2. + cov_params = n_features * (n_features + 1) / 2. elif self.covariance_type == 'spherical': cov_params = self.n_components - mean_params = ndim * self.n_components + mean_params = n_features * self.n_components return int(cov_params + mean_params + self.n_components - 1) def bic(self, X): diff --git a/sklearn/mixture/tests/test_gaussian_mixture.py b/sklearn/mixture/tests/test_gaussian_mixture.py index 64cdbe54c9f30992855612f551f573df54c07776..8e3e5516d7d279f4f42dbcb0b76a65fac223d405 100644 --- a/sklearn/mixture/tests/test_gaussian_mixture.py +++ b/sklearn/mixture/tests/test_gaussian_mixture.py @@ -3,18 +3,18 @@ import warnings import numpy as np -from scipy import stats +from scipy import stats, linalg from sklearn.covariance import EmpiricalCovariance from sklearn.datasets.samples_generator import make_spd_matrix from sklearn.externals.six.moves import cStringIO as StringIO from sklearn.metrics.cluster import adjusted_rand_score from sklearn.mixture.gaussian_mixture import GaussianMixture -from sklearn.mixture.gaussian_mixture import _estimate_gaussian_covariance_diag -from sklearn.mixture.gaussian_mixture import _estimate_gaussian_covariance_full from sklearn.mixture.gaussian_mixture import ( - _estimate_gaussian_covariance_spherical) -from sklearn.mixture.gaussian_mixture import _estimate_gaussian_covariance_tied + _estimate_gaussian_precisions_cholesky_full, + _estimate_gaussian_precisions_cholesky_tied, + _estimate_gaussian_precisions_cholesky_diag, + _estimate_gaussian_precisions_cholesky_spherical) from sklearn.exceptions import ConvergenceWarning, NotFittedError from sklearn.utils.extmath import fast_logdet from sklearn.utils.testing import assert_allclose @@ -32,28 +32,28 @@ from sklearn.utils.testing import assert_warns_message COVARIANCE_TYPE = ['full', 'tied', 'diag', 'spherical'] -def generate_data(n_samples, n_features, weights, means, covariances, +def generate_data(n_samples, n_features, weights, means, precisions, covariance_type): rng = np.random.RandomState(0) X = [] if covariance_type == 'spherical': for _, (w, m, c) in enumerate(zip(weights, means, - covariances['spherical'])): + precisions['spherical'])): X.append(rng.multivariate_normal(m, c * np.eye(n_features), int(np.round(w * n_samples)))) if covariance_type == 'diag': for _, (w, m, c) in enumerate(zip(weights, means, - covariances['diag'])): + precisions['diag'])): X.append(rng.multivariate_normal(m, np.diag(c), int(np.round(w * n_samples)))) if covariance_type == 'tied': for _, (w, m) in enumerate(zip(weights, means)): - X.append(rng.multivariate_normal(m, covariances['tied'], + X.append(rng.multivariate_normal(m, precisions['tied'], int(np.round(w * n_samples)))) if covariance_type == 'full': for _, (w, m, c) in enumerate(zip(weights, means, - covariances['full'])): + precisions['full'])): X.append(rng.multivariate_normal(m, c, int(np.round(w * n_samples)))) @@ -75,13 +75,19 @@ class RandomData(object): 'spherical': .5 + rng.rand(n_components), 'diag': (.5 + rng.rand(n_components, 
n_features)) ** 2, 'tied': make_spd_matrix(n_features, random_state=rng), - 'full': np.array([make_spd_matrix( - n_features, random_state=rng) * .5 + 'full': np.array([ + make_spd_matrix(n_features, random_state=rng) * .5 for _ in range(n_components)])} + self.precisions = { + 'spherical': 1. / self.covariances['spherical'], + 'diag': 1. / self.covariances['diag'], + 'tied': linalg.inv(self.covariances['tied']), + 'full': np.array([linalg.inv(covariance) + for covariance in self.covariances['full']])} self.X = dict(zip(COVARIANCE_TYPE, [generate_data( n_samples, n_features, self.weights, self.means, self.covariances, - cov_type) for cov_type in COVARIANCE_TYPE])) + covar_type) for covar_type in COVARIANCE_TYPE])) self.Y = np.hstack([k * np.ones(int(np.round(w * n_samples))) for k, w in enumerate(self.weights)]) @@ -198,9 +204,8 @@ def test_check_weights(): g.weights_init = weights_bad_shape assert_raise_message(ValueError, "The parameter 'weights' should have the shape of " - "(%d,), " - "but got %s" % (n_components, - str(weights_bad_shape.shape)), + "(%d,), but got %s" % + (n_components, str(weights_bad_shape.shape)), g.fit, X) # Check bad range @@ -253,27 +258,27 @@ def test_check_means(): assert_array_equal(means, g.means_init) -def test_check_covariances(): +def test_check_precisions(): rng = np.random.RandomState(0) rand_data = RandomData(rng) n_components, n_features = rand_data.n_components, rand_data.n_features - # Define the bad covariances for each covariance_type - covariances_bad_shape = { - 'full': rng.rand(n_components + 1, n_features, n_features), - 'tied': rng.rand(n_features + 1, n_features + 1), - 'diag': rng.rand(n_components + 1, n_features), - 'spherical': rng.rand(n_components + 1)} - - # Define not positive-definite covariances - covariances_not_pos = rng.rand(n_components, n_features, n_features) - covariances_not_pos[0] = np.eye(n_features) - covariances_not_pos[0, 0, 0] = -1. - - covariances_not_positive = { - 'full': covariances_not_pos, - 'tied': covariances_not_pos[0], + # Define the bad precisions for each covariance_type + precisions_bad_shape = { + 'full': np.ones((n_components + 1, n_features, n_features)), + 'tied': np.ones((n_features + 1, n_features + 1)), + 'diag': np.ones((n_components + 1, n_features)), + 'spherical': np.ones((n_components + 1))} + + # Define not positive-definite precisions + precisions_not_pos = np.ones((n_components, n_features, n_features)) + precisions_not_pos[0] = np.eye(n_features) + precisions_not_pos[0, 0, 0] = -1. + + precisions_not_positive = { + 'full': precisions_not_pos, + 'tied': precisions_not_pos[0], 'diag': -1. * np.ones((n_components, n_features)), 'spherical': -1. 
* np.ones(n_components)}
@@ -283,33 +288,35 @@
         'diag': 'positive',
         'spherical': 'positive'}

-    for cov_type in ['full', 'tied', 'diag', 'spherical']:
-        X = rand_data.X[cov_type]
+    for covar_type in COVARIANCE_TYPE:
+        X = RandomData(rng).X[covar_type]
         g = GaussianMixture(n_components=n_components,
-                            covariance_type=cov_type)
+                            covariance_type=covar_type,
+                            random_state=rng)

-        # Check covariance with bad shapes
-        g.covariances_init = covariances_bad_shape[cov_type]
+        # Check precisions with bad shapes
+        g.precisions_init = precisions_bad_shape[covar_type]
         assert_raise_message(ValueError,
-                             "The parameter '%s covariance' should have "
-                             "the shape of" % cov_type,
+                             "The parameter '%s precision' should have "
+                             "the shape of" % covar_type,
                              g.fit, X)

-        # Check not positive covariances
-        g.covariances_init = covariances_not_positive[cov_type]
+        # Check not positive precisions
+        g.precisions_init = precisions_not_positive[covar_type]
         assert_raise_message(ValueError,
-                             "'%s covariance' should be %s"
-                             % (cov_type, not_positive_errors[cov_type]),
+                             "'%s precision' should be %s"
+                             % (covar_type, not_positive_errors[covar_type]),
                              g.fit, X)

-        # Check the correct init of covariances_init
-        g.covariances_init = rand_data.covariances[cov_type]
+        # Check the correct init of precisions_init
+        g.precisions_init = rand_data.precisions[covar_type]
         g.fit(X)
-        assert_array_equal(rand_data.covariances[cov_type], g.covariances_init)
+        assert_array_equal(rand_data.precisions[covar_type], g.precisions_init)


 def test_suffstat_sk_full():
-    # compare the EmpiricalCovariance.covariance fitted on X*sqrt(resp)
+    # compare the precision matrix computed from the
+    # EmpiricalCovariance.covariance fitted on X*sqrt(resp)
     # with _sufficient_sk_full, n_components=1
     rng = np.random.RandomState(0)
     n_samples, n_features = 500, 2
@@ -320,21 +327,25 @@
     X_resp = np.sqrt(resp) * X
     nk = np.array([n_samples])
     xk = np.zeros((1, n_features))
-    covars_pred = _estimate_gaussian_covariance_full(resp, X, nk, xk, 0)
+    precs_pred = _estimate_gaussian_precisions_cholesky_full(resp, X,
+                                                             nk, xk, 0)
+    covars_pred = linalg.inv(np.dot(precs_pred[0], precs_pred[0].T))
     ecov = EmpiricalCovariance(assume_centered=True)
     ecov.fit(X_resp)
-    assert_almost_equal(ecov.error_norm(covars_pred[0], norm='frobenius'), 0)
-    assert_almost_equal(ecov.error_norm(covars_pred[0], norm='spectral'), 0)
+    assert_almost_equal(ecov.error_norm(covars_pred, norm='frobenius'), 0)
+    assert_almost_equal(ecov.error_norm(covars_pred, norm='spectral'), 0)

     # special case 2, assuming resp are all ones
     resp = np.ones((n_samples, 1))
     nk = np.array([n_samples])
-    xk = X.mean().reshape((1, -1))
-    covars_pred = _estimate_gaussian_covariance_full(resp, X, nk, xk, 0)
+    xk = X.mean(axis=0).reshape((1, -1))
+    precs_pred = _estimate_gaussian_precisions_cholesky_full(resp, X,
+                                                             nk, xk, 0)
+    covars_pred = linalg.inv(np.dot(precs_pred[0], precs_pred[0].T))
     ecov = EmpiricalCovariance(assume_centered=False)
     ecov.fit(X)
-    assert_almost_equal(ecov.error_norm(covars_pred[0], norm='frobenius'), 0)
-    assert_almost_equal(ecov.error_norm(covars_pred[0], norm='spectral'), 0)
+    assert_almost_equal(ecov.error_norm(covars_pred, norm='frobenius'), 0)
+    assert_almost_equal(ecov.error_norm(covars_pred, norm='spectral'), 0)


 def test_suffstat_sk_tied():
@@ -347,11 +358,18 @@
     X = rng.rand(n_samples, n_features)
     nk = resp.sum(axis=0)
     xk = np.dot(resp.T, X) / nk[:, np.newaxis]
-    covars_pred_full =
_estimate_gaussian_covariance_full(resp, X, nk, xk, 0) + + precs_pred_full = _estimate_gaussian_precisions_cholesky_full(resp, X, + nk, xk, 0) + covars_pred_full = [linalg.inv(np.dot(precision_chol, precision_chol.T)) + for precision_chol in precs_pred_full] covars_pred_full = np.sum(nk[:, np.newaxis, np.newaxis] * covars_pred_full, 0) / n_samples - covars_pred_tied = _estimate_gaussian_covariance_tied(resp, X, nk, xk, 0) + precs_pred_tied = _estimate_gaussian_precisions_cholesky_tied(resp, X, + nk, xk, 0) + covars_pred_tied = linalg.inv(np.dot(precs_pred_tied, precs_pred_tied.T)) + ecov = EmpiricalCovariance() ecov.covariance_ = covars_pred_full assert_almost_equal(ecov.error_norm(covars_pred_tied, norm='frobenius'), 0) @@ -368,14 +386,19 @@ def test_suffstat_sk_diag(): X = rng.rand(n_samples, n_features) nk = resp.sum(axis=0) xk = np.dot(resp.T, X) / nk[:, np.newaxis] - covars_pred_full = _estimate_gaussian_covariance_full(resp, X, nk, xk, 0) - covars_pred_full = np.array([np.diag(np.diag(d)) for d in - covars_pred_full]) - covars_pred_diag = _estimate_gaussian_covariance_diag(resp, X, nk, xk, 0) - covars_pred_diag = np.array([np.diag(d) for d in covars_pred_diag]) + precs_pred_full = _estimate_gaussian_precisions_cholesky_full(resp, X, + nk, xk, 0) + covars_pred_full = [linalg.inv(np.dot(precision_chol, precision_chol.T)) + for precision_chol in precs_pred_full] + + precs_pred_diag = _estimate_gaussian_precisions_cholesky_diag(resp, X, + nk, xk, 0) + covars_pred_diag = np.array([np.diag(1. / d) ** 2 + for d in precs_pred_diag]) + ecov = EmpiricalCovariance() for (cov_full, cov_diag) in zip(covars_pred_full, covars_pred_diag): - ecov.covariance_ = cov_full + ecov.covariance_ = np.diag(np.diag(cov_full)) assert_almost_equal(ecov.error_norm(cov_diag, norm='frobenius'), 0) assert_almost_equal(ecov.error_norm(cov_diag, norm='spectral'), 0) @@ -391,11 +414,11 @@ def test_gaussian_suffstat_sk_spherical(): resp = np.ones((n_samples, 1)) nk = np.array([n_samples]) xk = X.mean() - covars_pred_spherical = _estimate_gaussian_covariance_spherical(resp, X, - nk, xk, 0) - covars_pred_spherical2 = (np.dot(X.flatten().T, X.flatten()) / - (n_features * n_samples)) - assert_almost_equal(covars_pred_spherical, covars_pred_spherical2) + precs_pred_spherical = _estimate_gaussian_precisions_cholesky_spherical( + resp, X, nk, xk, 0) + covars_pred_spherical = (np.dot(X.flatten().T, X.flatten()) / + (n_features * n_samples)) + assert_almost_equal(1. / precs_pred_spherical ** 2, covars_pred_spherical) def _naive_lmvnpdf_diag(X, means, covars): @@ -426,29 +449,33 @@ def test_gaussian_mixture_log_probabilities(): log_prob_naive = _naive_lmvnpdf_diag(X, means, covars_diag) # full covariances - covars_full = np.array([np.diag(x) for x in covars_diag]) + precs_full = np.array([np.diag(1. / np.sqrt(x)) for x in covars_diag]) - log_prob = _estimate_log_gaussian_prob_full(X, means, covars_full) + log_prob = _estimate_log_gaussian_prob_full(X, means, precs_full) assert_array_almost_equal(log_prob, log_prob_naive) # diag covariances - log_prob = _estimate_log_gaussian_prob_diag(X, means, covars_diag) + precs_chol_diag = 1. / np.sqrt(covars_diag) + log_prob = _estimate_log_gaussian_prob_diag(X, means, precs_chol_diag) assert_array_almost_equal(log_prob, log_prob_naive) # tied - covars_tied = covars_full.mean(axis=0) + covars_tied = np.array([x for x in covars_diag]).mean(axis=0) + precs_tied = np.diag(np.sqrt(1. 
/ covars_tied)) + log_prob_naive = _naive_lmvnpdf_diag(X, means, - [np.diag(covars_tied)] * n_components) - log_prob = _estimate_log_gaussian_prob_tied(X, means, covars_tied) + [covars_tied] * n_components) + log_prob = _estimate_log_gaussian_prob_tied(X, means, precs_tied) + assert_array_almost_equal(log_prob, log_prob_naive) # spherical covars_spherical = covars_diag.mean(axis=1) + precs_spherical = 1. / np.sqrt(covars_diag.mean(axis=1)) log_prob_naive = _naive_lmvnpdf_diag(X, means, [[k] * n_features for k in covars_spherical]) - log_prob = _estimate_log_gaussian_prob_spherical(X, means, - covars_spherical) + log_prob = _estimate_log_gaussian_prob_spherical(X, means, precs_spherical) assert_array_almost_equal(log_prob, log_prob_naive) # skip tests on weighted_log_probabilities, log_weights @@ -463,33 +490,33 @@ def test_gaussian_mixture_estimate_log_prob_resp(): n_components = rand_data.n_components X = rng.rand(n_samples, n_features) - for cov_type in COVARIANCE_TYPE: + for covar_type in COVARIANCE_TYPE: weights = rand_data.weights means = rand_data.means - covariances = rand_data.covariances[cov_type] + precisions = rand_data.precisions[covar_type] g = GaussianMixture(n_components=n_components, random_state=rng, weights_init=weights, means_init=means, - covariances_init=covariances, - covariance_type=cov_type) + precisions_init=precisions, + covariance_type=covar_type) g.fit(X) resp = g.predict_proba(X) assert_array_almost_equal(resp.sum(axis=1), np.ones(n_samples)) assert_array_equal(g.weights_init, weights) assert_array_equal(g.means_init, means) - assert_array_equal(g.covariances_init, covariances) + assert_array_equal(g.precisions_init, precisions) def test_gaussian_mixture_predict_predict_proba(): rng = np.random.RandomState(0) rand_data = RandomData(rng) - for cov_type in COVARIANCE_TYPE: - X = rand_data.X[cov_type] + for covar_type in COVARIANCE_TYPE: + X = rand_data.X[covar_type] Y = rand_data.Y g = GaussianMixture(n_components=rand_data.n_components, random_state=rng, weights_init=rand_data.weights, means_init=rand_data.means, - covariances_init=rand_data.covariances[cov_type], - covariance_type=cov_type) + precisions_init=rand_data.precisions[covar_type], + covariance_type=covar_type) # Check a warning message arrive if we don't do fit assert_raise_message(NotFittedError, @@ -511,12 +538,13 @@ def test_gaussian_mixture_fit(): n_features = rand_data.n_features n_components = rand_data.n_components - for cov_type in COVARIANCE_TYPE: - X = rand_data.X[cov_type] - g = GaussianMixture(n_components=n_components, n_init=20, max_iter=100, + for covar_type in COVARIANCE_TYPE: + X = rand_data.X[covar_type] + g = GaussianMixture(n_components=n_components, n_init=20, reg_covar=0, random_state=rng, - covariance_type=cov_type) + covariance_type=covar_type) g.fit(X) + # needs more data to pass the test with rtol=1e-7 assert_allclose(np.sort(g.weights_), np.sort(rand_data.weights), rtol=0.1, atol=1e-2) @@ -526,28 +554,29 @@ def test_gaussian_mixture_fit(): assert_allclose(g.means_[arg_idx1], rand_data.means[arg_idx2], rtol=0.1, atol=1e-2) - if cov_type == 'spherical': - cov_pred = np.array([np.eye(n_features) * c - for c in g.covariances_]) - cov_test = np.array([np.eye(n_features) * c for c in - rand_data.covariances['spherical']]) - elif cov_type == 'diag': - cov_pred = np.array([np.diag(d) for d in g.covariances_]) - cov_test = np.array([np.diag(d) for d in - rand_data.covariances['diag']]) - elif cov_type == 'tied': - cov_pred = np.array([g.covariances_] * n_components) - cov_test = 
np.array([rand_data.covariances['tied']] * n_components) - elif cov_type == 'full': - cov_pred = g.covariances_ - cov_test = rand_data.covariances['full'] - arg_idx1 = np.trace(cov_pred, axis1=1, axis2=2).argsort() - arg_idx2 = np.trace(cov_test, axis1=1, axis2=2).argsort() + if covar_type == 'full': + prec_pred = g.precisions_ + prec_test = rand_data.precisions['full'] + elif covar_type == 'tied': + prec_pred = np.array([g.precisions_] * n_components) + prec_test = np.array([rand_data.precisions['tied']] * n_components) + elif covar_type == 'spherical': + prec_pred = np.array([np.eye(n_features) * c + for c in g.precisions_]) + prec_test = np.array([np.eye(n_features) * c for c in + rand_data.precisions['spherical']]) + elif covar_type == 'diag': + prec_pred = np.array([np.diag(d) for d in g.precisions_]) + prec_test = np.array([np.diag(d) for d in + rand_data.precisions['diag']]) + + arg_idx1 = np.trace(prec_pred, axis1=1, axis2=2).argsort() + arg_idx2 = np.trace(prec_test, axis1=1, axis2=2).argsort() for k, h in zip(arg_idx1, arg_idx2): ecov = EmpiricalCovariance() - ecov.covariance_ = cov_test[h] + ecov.covariance_ = prec_test[h] # the accuracy depends on the number of data and randomness, rng - assert_allclose(ecov.error_norm(cov_pred[k]), 0, atol=0.1) + assert_allclose(ecov.error_norm(prec_pred[k]), 0, atol=0.1) def test_gaussian_mixture_fit_best_params(): @@ -555,19 +584,18 @@ def test_gaussian_mixture_fit_best_params(): rand_data = RandomData(rng) n_components = rand_data.n_components n_init = 10 - for cov_type in COVARIANCE_TYPE: - X = rand_data.X[cov_type] - g = GaussianMixture(n_components=n_components, n_init=1, - max_iter=100, reg_covar=0, random_state=rng, - covariance_type=cov_type) + for covar_type in COVARIANCE_TYPE: + X = rand_data.X[covar_type] + g = GaussianMixture(n_components=n_components, n_init=1, reg_covar=0, + random_state=rng, covariance_type=covar_type) ll = [] for _ in range(n_init): g.fit(X) ll.append(g.score(X)) ll = np.array(ll) g_best = GaussianMixture(n_components=n_components, - n_init=n_init, max_iter=100, reg_covar=0, - random_state=rng, covariance_type=cov_type) + n_init=n_init, reg_covar=0, random_state=rng, + covariance_type=covar_type) g_best.fit(X) assert_almost_equal(ll.min(), g_best.score(X)) @@ -577,11 +605,11 @@ def test_gaussian_mixture_fit_convergence_warning(): rand_data = RandomData(rng, scale=1) n_components = rand_data.n_components max_iter = 1 - for cov_type in COVARIANCE_TYPE: - X = rand_data.X[cov_type] + for covar_type in COVARIANCE_TYPE: + X = rand_data.X[covar_type] g = GaussianMixture(n_components=n_components, n_init=1, max_iter=max_iter, reg_covar=0, random_state=rng, - covariance_type=cov_type) + covariance_type=covar_type) assert_warns_message(ConvergenceWarning, 'Initialization %d did not converged. 
' 'Try different init parameters, ' @@ -659,14 +687,14 @@ def test_gaussian_mixture_verbose(): rng = np.random.RandomState(0) rand_data = RandomData(rng) n_components = rand_data.n_components - for cov_type in COVARIANCE_TYPE: - X = rand_data.X[cov_type] - g = GaussianMixture(n_components=n_components, n_init=1, - max_iter=100, reg_covar=0, random_state=rng, - covariance_type=cov_type, verbose=1) - h = GaussianMixture(n_components=n_components, n_init=1, - max_iter=100, reg_covar=0, random_state=rng, - covariance_type=cov_type, verbose=2) + for covar_type in COVARIANCE_TYPE: + X = rand_data.X[covar_type] + g = GaussianMixture(n_components=n_components, n_init=1, reg_covar=0, + random_state=rng, covariance_type=covar_type, + verbose=1) + h = GaussianMixture(n_components=n_components, n_init=1, reg_covar=0, + random_state=rng, covariance_type=covar_type, + verbose=2) old_stdout = sys.stdout sys.stdout = StringIO() try: @@ -699,7 +727,7 @@ def test_warm_start(): assert_almost_equal(g.weights_, h.weights_) assert_almost_equal(g.means_, h.means_) - assert_almost_equal(g.covariances_, h.covariances_) + assert_almost_equal(g.precisions_, h.precisions_) assert_greater(score2, score1) # Assert that by using warm_start we can converge to a good solution @@ -720,16 +748,16 @@ def test_warm_start(): def test_score(): - cov_type = 'full' + covar_type = 'full' rng = np.random.RandomState(0) rand_data = RandomData(rng, scale=7) n_components = rand_data.n_components - X = rand_data.X[cov_type] + X = rand_data.X[covar_type] # Check the error message if we don't call fit gmm1 = GaussianMixture(n_components=n_components, n_init=1, max_iter=1, reg_covar=0, random_state=rng, - covariance_type=cov_type) + covariance_type=covar_type) assert_raise_message(NotFittedError, "This GaussianMixture instance is not fitted " "yet. Call 'fit' with appropriate arguments " @@ -744,23 +772,22 @@ def test_score(): assert_almost_equal(gmm_score, gmm_score_proba) # Check if the score increase - gmm2 = GaussianMixture(n_components=n_components, n_init=1, - max_iter=1000, reg_covar=0, random_state=rng, - covariance_type=cov_type).fit(X) + gmm2 = GaussianMixture(n_components=n_components, n_init=1, reg_covar=0, + random_state=rng, + covariance_type=covar_type).fit(X) assert_greater(gmm2.score(X), gmm1.score(X)) def test_score_samples(): - cov_type = 'full' + covar_type = 'full' rng = np.random.RandomState(0) rand_data = RandomData(rng, scale=7) n_components = rand_data.n_components - X = rand_data.X[cov_type] + X = rand_data.X[covar_type] # Check the error message if we don't call fit - gmm = GaussianMixture(n_components=n_components, n_init=1, - max_iter=100, reg_covar=0, random_state=rng, - covariance_type=cov_type) + gmm = GaussianMixture(n_components=n_components, n_init=1, reg_covar=0, + random_state=rng, covariance_type=covar_type) assert_raise_message(NotFittedError, "This GaussianMixture instance is not fitted " "yet. 
Call 'fit' with appropriate arguments "
@@ -777,10 +804,10 @@ def test_monotonic_likelihood():
     rand_data = RandomData(rng, scale=7)
     n_components = rand_data.n_components

-    for cov_type in COVARIANCE_TYPE:
-        X = rand_data.X[cov_type]
+    for covar_type in COVARIANCE_TYPE:
+        X = rand_data.X[covar_type]
         gmm = GaussianMixture(n_components=n_components,
-                              covariance_type=cov_type, reg_covar=0,
+                              covariance_type=covar_type, reg_covar=0,
                               warm_start=True, max_iter=1, random_state=rng,
                               tol=1e-7)
         current_log_likelihood = -np.infty
@@ -810,9 +837,9 @@ def test_regularisation():
     X = np.vstack((np.ones((n_samples // 2, n_features)),
                    np.zeros((n_samples // 2, n_features))))

-    for cov_type in COVARIANCE_TYPE:
-        gmm = GaussianMixture(n_components=n_samples, covariance_type=cov_type,
-                              reg_covar=0, random_state=rng)
+    for covar_type in COVARIANCE_TYPE:
+        gmm = GaussianMixture(n_components=n_samples, reg_covar=0,
+                              covariance_type=covar_type, random_state=rng)

         with warnings.catch_warnings():
             warnings.simplefilter("ignore", RuntimeWarning)
@@ -823,3 +850,23 @@
                              "or increase reg_covar.", gmm.fit, X)

             gmm.set_params(reg_covar=1e-6).fit(X)
+
+
+def test_property():
+    rng = np.random.RandomState(0)
+    rand_data = RandomData(rng, scale=7)
+    n_components = rand_data.n_components
+
+    for covar_type in COVARIANCE_TYPE:
+        X = rand_data.X[covar_type]
+        gmm = GaussianMixture(n_components=n_components,
+                              covariance_type=covar_type, random_state=rng)
+        gmm.fit(X)
+        if covar_type == 'full':
+            for prec, covar in zip(gmm.precisions_, gmm.covariances_):
+                assert_array_almost_equal(linalg.inv(prec), covar)
+        elif covar_type == 'tied':
+            assert_array_almost_equal(linalg.inv(gmm.precisions_),
+                                      gmm.covariances_)
+        else:
+            assert_array_almost_equal(gmm.precisions_, 1. / gmm.covariances_)
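
Editorial note on the refactoring above: the patch stops storing per-component covariance matrices and instead keeps the Cholesky factor of each precision matrix, computed once in the M-step, so that scoring a sample afterwards costs one triangular matrix product instead of a factorization or a solve. The following standalone sketch (not part of the patch; all variable names are local to the example) replays that arithmetic for a single 'full' component, mirroring _estimate_gaussian_precisions_cholesky_full and _estimate_log_gaussian_prob_full, and checks the result against scipy's reference density:

import numpy as np
from scipy import linalg, stats

rng = np.random.RandomState(0)
n_features = 3

# A random symmetric positive-definite "covariance" for one component.
A = rng.randn(n_features, n_features)
covariance = np.dot(A, A.T) + n_features * np.eye(n_features)
mean = rng.randn(n_features)
X = rng.randn(5, n_features)

# M-step side: factor once. prec_chol is the transposed inverse of the
# lower Cholesky factor of the covariance, as in the patch.
cov_chol = linalg.cholesky(covariance, lower=True)
prec_chol = linalg.solve_triangular(cov_chol, np.eye(n_features),
                                    lower=True).T

# Scoring side: one matrix product per sample, no solve and no inversion.
log_det = -2. * np.sum(np.log(np.diagonal(prec_chol)))  # log|covariance|
y = np.dot(X - mean, prec_chol)
log_prob = -.5 * (n_features * np.log(2. * np.pi) + log_det +
                  np.sum(np.square(y), axis=1))

# The result matches the reference density.
assert np.allclose(log_prob,
                   stats.multivariate_normal(mean, covariance).logpdf(X))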
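
A second sketch, continuing from the variables above and equally informal, mirrors the 'full' branch of GaussianMixture._set_parameters in the patch, which rebuilds the public precisions_ and covariances_ attributes from the stored precisions_cholesky_ factor. Because prec_chol is upper triangular, the precision is factor times its transpose, while the covariance comes back from one more triangular solve:

# Rebuilding the public attributes from the stored factor.
precision = np.dot(prec_chol, prec_chol.T)                   # -> precisions_
cov_chol_back = linalg.solve_triangular(prec_chol, np.eye(n_features))
covariance_back = np.dot(cov_chol_back.T, cov_chol_back)     # -> covariances_

assert np.allclose(precision, linalg.inv(covariance))
assert np.allclose(covariance_back, covariance)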