diff --git a/doc/modules/mixture.rst b/doc/modules/mixture.rst
index 6d6494b9bc4aafb7eea7b919a15bc54b9f27b38f..8249b217da5be522328b2d21e4346a05b1a7f1b2 100644
--- a/doc/modules/mixture.rst
+++ b/doc/modules/mixture.rst
@@ -19,8 +19,8 @@ components are also provided.
    :align: center
    :scale: 50%
 
-   **Two-component Gaussian mixture model:** *data points, and equi-probability surfaces of
-   the model.*
+   **Two-component Gaussian mixture model:** *data points, and equi-probability
+   surfaces of the model.*
 
 A Gaussian mixture model is a probabilistic model that assumes all the
 data points are generated from a mixture of a finite number of
@@ -51,9 +51,9 @@ the :meth:`GaussianMixture.predict` method.
     sample belonging to the various Gaussians may be retrieved using the
     :meth:`GaussianMixture.predict_proba` method.
 
-The :class:`GaussianMixture` comes with different options to constrain the covariance
-of the difference classes estimated: spherical, diagonal, tied or full
-covariance.
+The :class:`GaussianMixture` comes with different options to constrain the
+covariance of the different classes estimated: spherical, diagonal, tied or
+full covariance.
 
 .. figure:: ../auto_examples/mixture/images/plot_gmm_covariances_001.png
    :target: ../auto_examples/mixture/plot_gmm_covariances.html
@@ -72,7 +72,7 @@ Pros and cons of class :class:`GaussianMixture`
 -----------------------------------------------
 
 Pros
-.....
+....
 
 :Speed: It is the fastest algorithm for learning mixture models
 
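Reviewer note: to make the four covariance constraints discussed above concrete, here is a minimal sketch (not part of this patch) fitting :class:`GaussianMixture` with each `covariance_type` and printing the shape of the fitted `covariances_`; the toy data and `n_components=2` are illustrative assumptions.

```python
# Minimal sketch of the four covariance constraints; the data and the
# number of components are illustrative assumptions.
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.RandomState(0)
X = np.vstack([rng.randn(100, 2), rng.randn(100, 2) + 5])

for covariance_type in ['spherical', 'diag', 'tied', 'full']:
    gm = GaussianMixture(n_components=2, covariance_type=covariance_type,
                         random_state=0).fit(X)
    # shapes: 'spherical' (2,), 'diag' (2, 2), 'tied' (2, 2), 'full' (2, 2, 2)
    print(covariance_type, np.shape(gm.covariances_))
```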
diff --git a/examples/mixture/plot_gmm_covariances.py b/examples/mixture/plot_gmm_covariances.py
index e3c8d8b68b43ae1daefb201948faa93c83d8a944..dbd5be50f93e1806f0b8268a20a90cf17488838d 100644
--- a/examples/mixture/plot_gmm_covariances.py
+++ b/examples/mixture/plot_gmm_covariances.py
@@ -47,14 +47,14 @@ colors = ['navy', 'turquoise', 'darkorange']
 def make_ellipses(gmm, ax):
     for n, color in enumerate(colors):
         if gmm.covariance_type == 'full':
-            covars = gmm.covariances_[n][:2, :2]
+            covariances = gmm.covariances_[n][:2, :2]
         elif gmm.covariance_type == 'tied':
-            covars = gmm.covariances_[:2, :2]
+            covariances = gmm.covariances_[:2, :2]
         elif gmm.covariance_type == 'diag':
-            covars = np.diag(gmm.covariances_[n][:2])
+            covariances = np.diag(gmm.covariances_[n][:2])
         elif gmm.covariance_type == 'spherical':
-            covars = np.eye(gmm.means_.shape[1]) * gmm.covariances_[n]
-        v, w = np.linalg.eigh(covars)
+            covariances = np.eye(gmm.means_.shape[1]) * gmm.covariances_[n]
+        v, w = np.linalg.eigh(covariances)
         u = w[0] / np.linalg.norm(w[0])
         angle = np.arctan2(u[1], u[0])
         angle = 180 * angle / np.pi  # convert to degrees
@@ -82,9 +82,9 @@ y_test = iris.target[test_index]
 n_classes = len(np.unique(y_train))
 
 # Try GMMs using different types of covariances.
-estimators = dict((covar_type, GaussianMixture(n_components=n_classes,
-                   covariance_type=covar_type, max_iter=20))
-                  for covar_type in ['spherical', 'diag', 'tied', 'full'])
+estimators = dict((cov_type, GaussianMixture(n_components=n_classes,
+                   covariance_type=cov_type, max_iter=20, random_state=0))
+                  for cov_type in ['spherical', 'diag', 'tied', 'full'])
 
 n_estimators = len(estimators)
 
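Reviewer note: `make_ellipses` above reduces each 2x2 covariance block to an ellipse via an eigendecomposition. A standalone sketch of that conversion follows; the matrix is an illustrative assumption, and since `eigh` returns ascending eigenvalues with eigenvectors as columns, the orientation is read from the last column here.

```python
# Sketch: turn a 2x2 covariance into an ellipse angle and axis lengths.
import numpy as np

covariance = np.array([[2.0, 0.8],
                       [0.8, 1.0]])           # illustrative 2x2 covariance
v, w = np.linalg.eigh(covariance)             # v ascending, w has column vectors
u = w[:, -1]                                  # eigenvector of largest eigenvalue
angle = np.degrees(np.arctan2(u[1], u[0]))    # ellipse orientation in degrees
width, height = 2. * np.sqrt(v[::-1])         # axis lengths (about 1 std dev)
print(angle, width, height)
```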
diff --git a/examples/mixture/plot_gmm_selection.py b/examples/mixture/plot_gmm_selection.py
index 747dc0d8a90c7adb8065277f4a135dd75f534ea1..3ccaba5262c0d354d42daeeac8d09a149bbf2093 100644
--- a/examples/mixture/plot_gmm_selection.py
+++ b/examples/mixture/plot_gmm_selection.py
@@ -75,9 +75,9 @@ spl.legend([b[0] for b in bars], cv_types)
 # Plot the winner
 splot = plt.subplot(2, 1, 2)
 Y_ = clf.predict(X)
-for i, (mean, covar, color) in enumerate(zip(clf.means_, clf.covariances_,
-                                             color_iter)):
-    v, w = linalg.eigh(covar)
+for i, (mean, cov, color) in enumerate(zip(clf.means_, clf.covariances_,
+                                           color_iter)):
+    v, w = linalg.eigh(cov)
     if not np.any(Y_ == i):
         continue
     plt.scatter(X[Y_ == i, 0], X[Y_ == i, 1], .8, color=color)
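Reviewer note: the surrounding example selects the model with the lowest BIC before plotting it; a condensed sketch of that selection loop, with toy data and search ranges as illustrative assumptions:

```python
# Condensed sketch of the BIC-based model selection in this example.
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.RandomState(0)
X = np.vstack([rng.randn(200, 2), rng.randn(200, 2) + 4])

lowest_bic, best_gmm = np.inf, None
for cv_type in ['spherical', 'tied', 'diag', 'full']:
    for n_components in range(1, 5):
        gmm = GaussianMixture(n_components=n_components,
                              covariance_type=cv_type,
                              random_state=0).fit(X)
        bic = gmm.bic(X)            # lower is better
        if bic < lowest_bic:
            lowest_bic, best_gmm = bic, gmm
print(best_gmm.covariance_type, best_gmm.n_components, lowest_bic)
```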
diff --git a/sklearn/mixture/gaussian_mixture.py b/sklearn/mixture/gaussian_mixture.py
index b19928dbcbdbcc16fb6dd5f692bef5f975896ef2..87e215ee2c9f134d223134f7651571d4340f11f5 100644
--- a/sklearn/mixture/gaussian_mixture.py
+++ b/sklearn/mixture/gaussian_mixture.py
@@ -36,14 +36,14 @@ def _check_weights(weights, n_components):
     _check_shape(weights, (n_components,), 'weights')
 
     # check range
-    if (any(np.less(weights, 0)) or
-            any(np.greater(weights, 1))):
+    if (any(np.less(weights, 0.)) or
+            any(np.greater(weights, 1.))):
         raise ValueError("The parameter 'weights' should be in the range "
                          "[0, 1], but got max value %.5f, min value %.5f"
                          % (np.min(weights), np.max(weights)))
 
     # check normalization
-    if not np.allclose(np.abs(1 - np.sum(weights)), 0.0):
+    if not np.allclose(np.abs(1. - np.sum(weights)), 0.):
         raise ValueError("The parameter 'weights' should be normalized, "
                          "but got sum(weights) = %.5f" % np.sum(weights))
     return weights
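Reviewer note: a tiny sketch of the two validations in `_check_weights` above, on a hand-made weights vector (the vector itself is an illustrative assumption):

```python
# Sketch of the range and normalization checks in _check_weights.
import numpy as np

weights = np.array([0.2, 0.3, 0.5])       # illustrative, already normalized
in_range = not (any(np.less(weights, 0.)) or any(np.greater(weights, 1.)))
normalized = np.allclose(np.abs(1. - np.sum(weights)), 0.)
assert in_range and normalized
```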
@@ -72,33 +72,33 @@ def _check_means(means, n_components, n_features):
     return means
 
 
-def _check_covariance_matrix(covariance, covariance_type):
-    """Check a covariance matrix is symmetric and positive-definite."""
-    if (not np.allclose(covariance, covariance.T) or
-            np.any(np.less_equal(linalg.eigvalsh(covariance), .0))):
-        raise ValueError("'%s covariance' should be symmetric, "
-                         "positive-definite" % covariance_type)
+def _check_precision_positivity(precision, covariance_type):
+    """Check a precision vector is positive-definite."""
+    if np.any(np.less_equal(precision, 0.0)):
+        raise ValueError("'%s precision' should be "
+                         "positive" % covariance_type)
 
 
-def _check_covariance_positivity(covariance, covariance_type):
-    """Check a covariance vector is positive-definite."""
-    if np.any(np.less_equal(covariance, 0.0)):
-        raise ValueError("'%s covariance' should be "
-                         "positive" % covariance_type)
+def _check_precision_matrix(precision, covariance_type):
+    """Check a precision matrix is symmetric and positive-definite."""
+    if not (np.allclose(precision, precision.T) and
+            np.all(linalg.eigvalsh(precision) > 0.)):
+        raise ValueError("'%s precision' should be symmetric, "
+                         "positive-definite" % covariance_type)
 
 
-def _check_covariances_full(covariances, covariance_type):
-    """Check the covariance matrices are symmetric and positive-definite."""
-    for k, cov in enumerate(covariances):
-        _check_covariance_matrix(cov, covariance_type)
+def _check_precisions_full(precisions, covariance_type):
+    """Check the precision matrices are symmetric and positive-definite."""
+    for prec in precisions:
+        _check_precision_matrix(prec, covariance_type)
 
 
-def _check_covariances(covariances, covariance_type, n_components, n_features):
-    """Validate user provided covariances.
+def _check_precisions(precisions, covariance_type, n_components, n_features):
+    """Validate user provided precisions.
 
     Parameters
     ----------
-    covariances : array-like,
+    precisions : array-like,
         'full' : shape of (n_components, n_features, n_features)
         'tied' : shape of (n_features, n_features)
         'diag' : shape of (n_components, n_features)
@@ -114,33 +114,37 @@ def _check_covariances(covariances, covariance_type, n_components, n_features):
 
     Returns
     -------
-    covariances : array
+    precisions : array
     """
-    covariances = check_array(covariances, dtype=[np.float64, np.float32],
-                              ensure_2d=False,
-                              allow_nd=covariance_type is 'full')
-
-    covariances_shape = {'full': (n_components, n_features, n_features),
-                         'tied': (n_features, n_features),
-                         'diag': (n_components, n_features),
-                         'spherical': (n_components,)}
-    _check_shape(covariances, covariances_shape[covariance_type],
-                 '%s covariance' % covariance_type)
+    precisions = check_array(precisions, dtype=[np.float64, np.float32],
+                             ensure_2d=False,
+                             allow_nd=covariance_type == 'full')
 
-    check_functions = {'full': _check_covariances_full,
-                       'tied': _check_covariance_matrix,
-                       'diag': _check_covariance_positivity,
-                       'spherical': _check_covariance_positivity}
-    check_functions[covariance_type](covariances, covariance_type)
+    precisions_shape = {'full': (n_components, n_features, n_features),
+                        'tied': (n_features, n_features),
+                        'diag': (n_components, n_features),
+                        'spherical': (n_components,)}
+    _check_shape(precisions, precisions_shape[covariance_type],
+                 '%s precision' % covariance_type)
 
-    return covariances
+    check_precisions = {'full': _check_precisions_full,
+                        'tied': _check_precision_matrix,
+                        'diag': _check_precision_positivity,
+                        'spherical': _check_precision_positivity}
+    check_precisions[covariance_type](precisions, covariance_type)
+    return precisions
 
 
 ###############################################################################
 # Gaussian mixture parameters estimators (used by the M-Step)
+ESTIMATE_PRECISION_ERROR_MESSAGE = ("The algorithm has diverged because of "
+                                    "too few samples per component. Try to "
+                                    "decrease the number of components, "
+                                    "or increase reg_covar.")
 
-def _estimate_gaussian_covariance_full(resp, X, nk, means, reg_covar):
-    """Estimate the full covariance matrices.
+
+def _estimate_gaussian_precisions_cholesky_full(resp, X, nk, means, reg_covar):
+    """Estimate the full precision matrices.
 
     Parameters
     ----------
@@ -156,20 +160,27 @@ def _estimate_gaussian_covariance_full(resp, X, nk, means, reg_covar):
 
     Returns
     -------
-    covariances : array, shape (n_components, n_features, n_features)
+    precisions_chol : array, shape (n_components, n_features, n_features)
+        The Cholesky decomposition of the precision matrix of each component.
     """
-    n_features = X.shape[1]
-    n_components = means.shape[0]
-    covariances = np.empty((n_components, n_features, n_features))
+    n_components, n_features = means.shape
+    precisions_chol = np.empty((n_components, n_features, n_features))
     for k in range(n_components):
         diff = X - means[k]
-        covariances[k] = np.dot(resp[:, k] * diff.T, diff) / nk[k]
-        covariances[k].flat[::n_features + 1] += reg_covar
-    return covariances
+        covariance = np.dot(resp[:, k] * diff.T, diff) / nk[k]
+        covariance.flat[::n_features + 1] += reg_covar
+        try:
+            cov_chol = linalg.cholesky(covariance, lower=True)
+        except linalg.LinAlgError:
+            raise ValueError(ESTIMATE_PRECISION_ERROR_MESSAGE)
+        precisions_chol[k] = linalg.solve_triangular(cov_chol,
+                                                     np.eye(n_features),
+                                                     lower=True).T
+    return precisions_chol
 
 
-def _estimate_gaussian_covariance_tied(resp, X, nk, means, reg_covar):
-    """Estimate the tied covariance matrix.
+def _estimate_gaussian_precisions_cholesky_tied(resp, X, nk, means, reg_covar):
+    """Estimate the tied precision matrix.
 
     Parameters
     ----------
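Reviewer note: the 'full' estimator above relies on the identity that inverting the lower Cholesky factor `L` of a covariance (and transposing) yields a factor of the precision, i.e. `prec_chol @ prec_chol.T == inv(covariance)`. A sketch checking this on a random SPD matrix (the matrix construction is an illustrative assumption):

```python
# Verify the Cholesky identity used by
# _estimate_gaussian_precisions_cholesky_full.
import numpy as np
from scipy import linalg

rng = np.random.RandomState(0)
A = rng.randn(4, 4)
covariance = np.dot(A, A.T) + 4 * np.eye(4)       # illustrative SPD matrix

cov_chol = linalg.cholesky(covariance, lower=True)          # cov = L L^T
prec_chol = linalg.solve_triangular(cov_chol, np.eye(4),
                                    lower=True).T           # L^{-T}
precision = np.dot(prec_chol, prec_chol.T)                  # L^{-T} L^{-1}
assert np.allclose(precision, linalg.inv(covariance))
```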
@@ -185,18 +196,26 @@ def _estimate_gaussian_covariance_tied(resp, X, nk, means, reg_covar):
 
     Returns
     -------
-    covariances : array, shape (n_features, n_features)
+    precisions_chol : array, shape (n_features, n_features)
+        The Cholesky decomposition of the precision matrix.
     """
+    n_samples, n_features = X.shape
     avg_X2 = np.dot(X.T, X)
     avg_means2 = np.dot(nk * means.T, means)
     covariances = avg_X2 - avg_means2
-    covariances /= X.shape[0]
+    covariances /= n_samples
     covariances.flat[::len(covariances) + 1] += reg_covar
-    return covariances
+    try:
+        cov_chol = linalg.cholesky(covariances, lower=True)
+    except linalg.LinAlgError:
+        raise ValueError(ESTIMATE_PRECISION_ERROR_MESSAGE)
+    precisions_chol = linalg.solve_triangular(cov_chol, np.eye(n_features),
+                                              lower=True).T
+    return precisions_chol
 
 
-def _estimate_gaussian_covariance_diag(resp, X, nk, means, reg_covar):
-    """Estimate the diagonal covariance matrices.
+def _estimate_gaussian_precisions_cholesky_diag(resp, X, nk, means, reg_covar):
+    """Estimate the diagonal precision matrices.
 
     Parameters
     ----------
@@ -212,16 +231,21 @@ def _estimate_gaussian_covariance_diag(resp, X, nk, means, reg_covar):
 
     Returns
     -------
-    covariances : array, shape (n_components, n_features)
+    precisions_chol : array, shape (n_components, n_features)
+        The Cholesky decomposition of the precision matrix of each component.
     """
     avg_X2 = np.dot(resp.T, X * X) / nk[:, np.newaxis]
     avg_means2 = means ** 2
     avg_X_means = means * np.dot(resp.T, X) / nk[:, np.newaxis]
-    return avg_X2 - 2 * avg_X_means + avg_means2 + reg_covar
+    covariances = avg_X2 - 2 * avg_X_means + avg_means2 + reg_covar
+    if np.any(np.less_equal(covariances, 0.0)):
+        raise ValueError(ESTIMATE_PRECISION_ERROR_MESSAGE)
+    return 1. / np.sqrt(covariances)
 
 
-def _estimate_gaussian_covariance_spherical(resp, X, nk, means, reg_covar):
-    """Estimate the spherical covariance matrices.
+def _estimate_gaussian_precisions_cholesky_spherical(resp, X, nk, means,
+                                                     reg_covar):
+    """Estimate the spherical precision matrices.
 
     Parameters
     ----------
@@ -237,11 +261,16 @@ def _estimate_gaussian_covariance_spherical(resp, X, nk, means, reg_covar):
 
     Returns
     -------
-    covariances : array, shape (n_components,)
+    precisions_chol : array, shape (n_components,)
+        The Cholesky decomposition of the precision matrix of each component.
     """
-    covariances = _estimate_gaussian_covariance_diag(resp, X, nk, means,
-                                                     reg_covar)
-    return covariances.mean(axis=1)
+    avg_X2 = np.dot(resp.T, X * X) / nk[:, np.newaxis]
+    avg_means2 = means ** 2
+    avg_X_means = means * np.dot(resp.T, X) / nk[:, np.newaxis]
+    covariances = (avg_X2 - 2 * avg_X_means + avg_means2 + reg_covar).mean(1)
+    if np.any(np.less_equal(covariances, 0.0)):
+        raise ValueError(ESTIMATE_PRECISION_ERROR_MESSAGE)
+    return 1. / np.sqrt(covariances)
 
 
 def _estimate_gaussian_parameters(X, resp, reg_covar, covariance_type):
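Reviewer note: for 'diag' and 'spherical' the covariance is axis-aligned, so the "Cholesky factor" of the precision computed above is simply `1 / sqrt(variance)` per dimension; a one-line sanity check (values illustrative):

```python
# For diagonal covariances, prec_chol = 1 / sqrt(var), so prec_chol**2 == 1/var.
import numpy as np

variances = np.array([0.5, 2.0, 4.0])     # illustrative per-feature variances
prec_chol = 1. / np.sqrt(variances)
assert np.allclose(prec_chol ** 2, 1. / variances)
```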
@@ -256,10 +285,10 @@ def _estimate_gaussian_parameters(X, resp, reg_covar, covariance_type):
         The responsibilities for each data sample in X.
 
     reg_covar : float
-        The regularization added to each covariance matrices.
+        The regularization added to the diagonal of the covariance matrices.
 
     covariance_type : {'full', 'tied', 'diag', 'spherical'}
-        The type of covariance matrices.
+        The type of precision matrices.
 
     Returns
     -------
@@ -269,29 +298,25 @@ def _estimate_gaussian_parameters(X, resp, reg_covar, covariance_type):
     means : array, shape (n_components, n_features)
         The centers of the current components.
 
-    covariances : array
-        The sample covariances of the current components.
-        The shape depends of the covariance_type.
+    precisions_cholesky : array
+        The Cholesky decomposition of the sample precisions of the current
+        components. The shape depends on the covariance_type.
     """
-    compute_covariance = {
-        "full": _estimate_gaussian_covariance_full,
-        "tied": _estimate_gaussian_covariance_tied,
-        "diag": _estimate_gaussian_covariance_diag,
-        "spherical": _estimate_gaussian_covariance_spherical}
-
     nk = resp.sum(axis=0) + 10 * np.finfo(resp.dtype).eps
     means = np.dot(resp.T, X) / nk[:, np.newaxis]
-    covariances = compute_covariance[covariance_type](
-        resp, X, nk, means, reg_covar)
-
-    return nk, means, covariances
+    precs_chol = {"full": _estimate_gaussian_precisions_cholesky_full,
+                  "tied": _estimate_gaussian_precisions_cholesky_tied,
+                  "diag": _estimate_gaussian_precisions_cholesky_diag,
+                  "spherical": _estimate_gaussian_precisions_cholesky_spherical
+                  }[covariance_type](resp, X, nk, means, reg_covar)
+    return nk, means, precs_chol
 
 
 ###############################################################################
 # Gaussian mixture probability estimators
 
-def _estimate_log_gaussian_prob_full(X, means, covariances):
-    """Estimate the log Gaussian probability for 'full' covariance.
+def _estimate_log_gaussian_prob_full(X, means, precisions_chol):
+    """Estimate the log Gaussian probability for 'full' precision.
 
     Parameters
     ----------
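Reviewer note: a sketch of the weighted sufficient statistics computed at the top of `_estimate_gaussian_parameters` (the responsibilities here are a random illustrative matrix):

```python
# Soft counts and weighted means from responsibilities, as in the M-step.
import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(6, 2)
resp = rng.rand(6, 3)
resp /= resp.sum(axis=1, keepdims=True)                # rows sum to one

nk = resp.sum(axis=0) + 10 * np.finfo(resp.dtype).eps  # soft component counts
means = np.dot(resp.T, X) / nk[:, np.newaxis]          # weighted means
print(nk, means, sep="\n")
```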
@@ -299,33 +324,26 @@ def _estimate_log_gaussian_prob_full(X, means, covariances):
 
     means : array-like, shape (n_components, n_features)
 
-    covariances : array-like, shape (n_components, n_features, n_features)
+    precisions_chol : array-like, shape (n_components, n_features, n_features)
+        Cholesky decompositions of the precision matrices.
 
     Returns
     -------
     log_prob : array, shape (n_samples, n_components)
     """
     n_samples, n_features = X.shape
-    n_components = means.shape[0]
+    n_components, _ = means.shape
     log_prob = np.empty((n_samples, n_components))
-    for k, (mu, cov) in enumerate(zip(means, covariances)):
-        try:
-            cov_chol = linalg.cholesky(cov, lower=True)
-        except linalg.LinAlgError:
-            raise ValueError("The algorithm has diverged because of too "
-                             "few samples per components. "
-                             "Try to decrease the number of components, or "
-                             "increase reg_covar.")
-        cv_log_det = 2. * np.sum(np.log(np.diagonal(cov_chol)))
-        cv_sol = linalg.solve_triangular(cov_chol, (X - mu).T, lower=True).T
-        log_prob[:, k] = - .5 * (n_features * np.log(2. * np.pi) +
-                                 cv_log_det +
-                                 np.sum(np.square(cv_sol), axis=1))
+    for k, (mu, prec_chol) in enumerate(zip(means, precisions_chol)):
+        log_det = -2. * np.sum(np.log(np.diagonal(prec_chol)))
+        y = np.dot(X - mu, prec_chol)
+        log_prob[:, k] = -.5 * (n_features * np.log(2. * np.pi) + log_det +
+                                np.sum(np.square(y), axis=1))
     return log_prob
 
 
-def _estimate_log_gaussian_prob_tied(X, means, covariances):
-    """Estimate the log Gaussian probability for 'tied' covariance.
+def _estimate_log_gaussian_prob_tied(X, means, precision_chol):
+    """Estimate the log Gaussian probability for 'tied' precision.
 
     Parameters
     ----------
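Reviewer note: the rewritten 'full' log-density can be cross-checked against `scipy.stats.multivariate_normal.logpdf`; a sketch for a single component with illustrative data:

```python
# Cross-check the precision-Cholesky log-density against scipy's reference.
import numpy as np
from scipy import linalg, stats

rng = np.random.RandomState(0)
n_features = 3
X = rng.randn(10, n_features)
mu = rng.randn(n_features)
A = rng.randn(n_features, n_features)
cov = np.dot(A, A.T) + n_features * np.eye(n_features)  # illustrative SPD

prec_chol = linalg.solve_triangular(linalg.cholesky(cov, lower=True),
                                    np.eye(n_features), lower=True).T
log_det = -2. * np.sum(np.log(np.diagonal(prec_chol)))  # == log det(cov)
y = np.dot(X - mu, prec_chol)
log_prob = -.5 * (n_features * np.log(2. * np.pi) + log_det +
                  np.sum(np.square(y), axis=1))
assert np.allclose(log_prob, stats.multivariate_normal.logpdf(X, mu, cov))
```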
@@ -333,33 +351,26 @@ def _estimate_log_gaussian_prob_tied(X, means, covariances):
 
     means : array-like, shape (n_components, n_features)
 
-    covariances : array-like, shape (n_features, n_features)
+    precision_chol : array-like, shape (n_features, n_features)
+        Cholesky decomposition of the precision matrix.
 
     Returns
     -------
     log_prob : array-like, shape (n_samples, n_components)
     """
     n_samples, n_features = X.shape
-    n_components = means.shape[0]
+    n_components, _ = means.shape
     log_prob = np.empty((n_samples, n_components))
-    try:
-        cov_chol = linalg.cholesky(covariances, lower=True)
-    except linalg.LinAlgError:
-        raise ValueError("The algorithm has diverged because of too "
-                         "few samples per components. "
-                         "Try to decrease the number of components, or "
-                         "increase reg_covar.")
-    cv_log_det = 2. * np.sum(np.log(np.diagonal(cov_chol)))
+    log_det = -2. * np.sum(np.log(np.diagonal(precision_chol)))
     for k, mu in enumerate(means):
-        cv_sol = linalg.solve_triangular(cov_chol, (X - mu).T,
-                                         lower=True).T
-        log_prob[:, k] = np.sum(np.square(cv_sol), axis=1)
-    log_prob = - .5 * (n_features * np.log(2. * np.pi) + cv_log_det + log_prob)
+        y = np.dot(X - mu, precision_chol)
+        log_prob[:, k] = np.sum(np.square(y), axis=1)
+    log_prob = -.5 * (n_features * np.log(2. * np.pi) + log_det + log_prob)
     return log_prob
 
 
-def _estimate_log_gaussian_prob_diag(X, means, covariances):
-    """Estimate the log Gaussian probability for 'diag' covariance.
+def _estimate_log_gaussian_prob_diag(X, means, precisions_chol):
+    """Estimate the log Gaussian probability for 'diag' precision.
 
     Parameters
     ----------
@@ -367,28 +378,25 @@ def _estimate_log_gaussian_prob_diag(X, means, covariances):
 
     means : array-like, shape (n_components, n_features)
 
-    covariances : array-like, shape (n_components, n_features)
+    precisions_chol : array-like, shape (n_components, n_features)
+        Cholesky decompositions of the precision matrices.
 
     Returns
     -------
     log_prob : array-like, shape (n_samples, n_components)
     """
-    if np.any(np.less_equal(covariances, 0.0)):
-        raise ValueError("The algorithm has diverged because of too "
-                         "few samples per components. "
-                         "Try to decrease the number of components, or "
-                         "increase reg_covar.")
     n_samples, n_features = X.shape
-    log_prob = - .5 * (n_features * np.log(2. * np.pi) +
-                       np.sum(np.log(covariances), 1) +
-                       np.sum((means ** 2 / covariances), 1) -
-                       2. * np.dot(X, (means / covariances).T) +
-                       np.dot(X ** 2, (1. / covariances).T))
+    precisions = precisions_chol ** 2
+    log_prob = -.5 * (n_features * np.log(2. * np.pi) -
+                      np.sum(np.log(precisions), 1) +
+                      np.sum((means ** 2 * precisions), 1) -
+                      2. * np.dot(X, (means * precisions).T) +
+                      np.dot(X ** 2, precisions.T))
     return log_prob
 
 
-def _estimate_log_gaussian_prob_spherical(X, means, covariances):
-    """Estimate the log Gaussian probability for 'spherical' covariance.
+def _estimate_log_gaussian_prob_spherical(X, means, precisions_chol):
+    """Estimate the log Gaussian probability for 'spherical' precision.
 
     Parameters
     ----------
@@ -396,23 +404,20 @@ def _estimate_log_gaussian_prob_spherical(X, means, covariances):
 
     means : array-like, shape (n_components, n_features)
 
-    covariances : array-like, shape (n_components, )
+    precisions_chol : array-like, shape (n_components, )
+        Cholesky decompositions of the precision matrices.
 
     Returns
     -------
     log_prob : array-like, shape (n_samples, n_components)
     """
-    if np.any(np.less_equal(covariances, 0.0)):
-        raise ValueError("The algorithm has diverged because of too "
-                         "few samples per components. "
-                         "Try to decrease the number of components, or "
-                         "increase reg_covar.")
     n_samples, n_features = X.shape
-    log_prob = - .5 * (n_features * np.log(2 * np.pi) +
-                       n_features * np.log(covariances) +
-                       np.sum(means ** 2, 1) / covariances -
-                       2 * np.dot(X, means.T / covariances) +
-                       np.outer(np.sum(X ** 2, axis=1), 1. / covariances))
+    precisions = precisions_chol ** 2
+    log_prob = -.5 * (n_features * np.log(2 * np.pi) -
+                      n_features * np.log(precisions) +
+                      np.sum(means ** 2, 1) * precisions -
+                      2 * np.dot(X, means.T * precisions) +
+                      np.outer(np.sum(X ** 2, axis=1), precisions))
     return log_prob
 
 
@@ -453,7 +458,7 @@ class GaussianMixture(BaseMixture):
 
     init_params : {'kmeans', 'random'}, defaults to 'kmeans'.
         The method used to initialize the weights, the means and the
-        covariances.
+        precisions.
         Must be one of::
         'kmeans' : responsibilities are initialized using kmeans.
         'random' : responsibilities are initialized randomly.
@@ -466,9 +471,10 @@ class GaussianMixture(BaseMixture):
         The user-provided initial means, defaults to None,
         If it None, means are initialized using the `init_params` method.
 
-    covariances_init: array-like, optional.
-        The user-provided initial covariances, defaults to None.
-        If it None, covariances are initialized using the 'init_params' method.
+    precisions_init : array-like, optional.
+        The user-provided initial precisions (inverse of the covariance
+        matrices), defaults to None. If it is None, precisions are
+        initialized using the 'init_params' method.
         The shape depends on 'covariance_type'::
             (n_components,)                        if 'spherical',
             (n_features, n_features)               if 'tied',
@@ -493,11 +499,9 @@ class GaussianMixture(BaseMixture):
     ----------
     weights_ : array, shape (n_components,)
         The weights of each mixture components.
-        `weights_` will not exist before a call to fit.
 
     means_ : array, shape (n_components, n_features)
         The mean of each mixture component.
-        `means_` will not exist before a call to fit.
 
     covariances_ : array
         The covariance of each mixture component.
@@ -506,20 +510,43 @@ class GaussianMixture(BaseMixture):
             (n_features, n_features)               if 'tied',
             (n_components, n_features)             if 'diag',
             (n_components, n_features, n_features) if 'full'
-        `covariances_` will not exist before a call to fit.
+
+    precisions_ : array
+        The precision matrices for each component in the mixture. A precision
+        matrix is the inverse of a covariance matrix. A covariance matrix is
+        symmetric positive definite so the mixture of Gaussians can be
+        equivalently parameterized by the precision matrices. Storing the
+        precision matrices instead of the covariance matrices makes it more
+        efficient to compute the log-likelihood of new samples at test time.
+        The shape depends on `covariance_type`::
+            (n_components,)                        if 'spherical',
+            (n_features, n_features)               if 'tied',
+            (n_components, n_features)             if 'diag',
+            (n_components, n_features, n_features) if 'full'
+
+    precisions_cholesky_ : array
+        The Cholesky decomposition of the precision matrices of each mixture
+        component. A precision matrix is the inverse of a covariance matrix.
+        A covariance matrix is symmetric positive definite so the mixture of
+        Gaussians can be equivalently parameterized by the precision matrices.
+        Storing the precision matrices instead of the covariance matrices makes
+        it more efficient to compute the log-likelihood of new samples at test
+        time. The shape depends on `covariance_type`::
+            (n_components,)                        if 'spherical',
+            (n_features, n_features)               if 'tied',
+            (n_components, n_features)             if 'diag',
+            (n_components, n_features, n_features) if 'full'
 
     converged_ : bool
         True when convergence was reached in fit(), False otherwise.
-        `converged_` will not exist before a call to fit.
 
     n_iter_ : int
         Number of step used by the best fit of EM to reach the convergence.
-        `n_iter_`  will not exist before a call to fit.
     """
 
     def __init__(self, n_components=1, covariance_type='full', tol=1e-3,
                  reg_covar=1e-6, max_iter=100, n_init=1, init_params='kmeans',
-                 weights_init=None, means_init=None, covariances_init=None,
+                 weights_init=None, means_init=None, precisions_init=None,
                  random_state=None, warm_start=False,
                  verbose=0, verbose_interval=10):
         super(GaussianMixture, self).__init__(
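Reviewer note: with this patch a fitted estimator exposes `covariances_`, `precisions_` and `precisions_cholesky_`, which `_set_parameters` keeps mutually consistent; a sketch verifying that on toy data, assuming a scikit-learn build that includes this change:

```python
# Sketch: the three fitted attributes are consistent with each other.
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.RandomState(0)
X = np.vstack([rng.randn(200, 2), rng.randn(200, 2) + 6])
gm = GaussianMixture(n_components=2, covariance_type='full',
                     random_state=0).fit(X)

for k in range(2):
    prec_chol = gm.precisions_cholesky_[k]
    assert np.allclose(np.dot(prec_chol, prec_chol.T), gm.precisions_[k])
    assert np.allclose(np.linalg.inv(gm.covariances_[k]), gm.precisions_[k])
```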
@@ -531,10 +558,11 @@ class GaussianMixture(BaseMixture):
         self.covariance_type = covariance_type
         self.weights_init = weights_init
         self.means_init = means_init
-        self.covariances_init = covariances_init
+        self.precisions_init = precisions_init
 
     def _check_parameters(self, X):
         """Check the Gaussian mixture parameters are well defined."""
+        _, n_features = X.shape
         if self.covariance_type not in ['spherical', 'tied', 'diag', 'full']:
             raise ValueError("Invalid value for 'covariance_type': %s "
                              "'covariance_type' should be in "
@@ -547,13 +575,13 @@ class GaussianMixture(BaseMixture):
 
         if self.means_init is not None:
             self.means_init = _check_means(self.means_init,
-                                           self.n_components, X.shape[1])
+                                           self.n_components, n_features)
 
-        if self.covariances_init is not None:
-            self.covariances_init = _check_covariances(self.covariances_init,
-                                                       self.covariance_type,
-                                                       self.n_components,
-                                                       X.shape[1])
+        if self.precisions_init is not None:
+            self.precisions_init = _check_precisions(self.precisions_init,
+                                                     self.covariance_type,
+                                                     self.n_components,
+                                                     n_features)
 
     def _initialize(self, X, resp):
         """Initialization of the Gaussian mixture parameters.
@@ -564,60 +592,92 @@ class GaussianMixture(BaseMixture):
 
         resp : array-like, shape (n_samples, n_components)
         """
-        weights, means, covariances = _estimate_gaussian_parameters(
+        n_samples, _ = X.shape
+
+        weights, means, precisions_cholesky = _estimate_gaussian_parameters(
             X, resp, self.reg_covar, self.covariance_type)
-        weights /= X.shape[0]
+        weights /= n_samples
 
         self.weights_ = (weights if self.weights_init is None
                          else self.weights_init)
         self.means_ = means if self.means_init is None else self.means_init
-        self.covariances_ = (covariances if self.covariances_init is None
-                             else self.covariances_init)
+
+        if self.precisions_init is None:
+            self.precisions_cholesky_ = precisions_cholesky
+        elif self.covariance_type == 'full':
+            self.precisions_cholesky_ = np.array(
+                [linalg.cholesky(prec_init, lower=True)
+                 for prec_init in self.precisions_init])
+        elif self.covariance_type == 'tied':
+            self.precisions_cholesky_ = linalg.cholesky(self.precisions_init,
+                                                        lower=True)
+        else:
+            self.precisions_cholesky_ = self.precisions_init
 
     def _e_step(self, X):
         log_prob_norm, _, log_resp = self._estimate_log_prob_resp(X)
         return np.mean(log_prob_norm), np.exp(log_resp)
 
     def _m_step(self, X, resp):
-        self.weights_, self.means_, self.covariances_ = (
+        self.weights_, self.means_, self.precisions_cholesky_ = (
             _estimate_gaussian_parameters(X, resp, self.reg_covar,
                                           self.covariance_type))
         self.weights_ /= X.shape[0]
 
     def _estimate_log_prob(self, X):
-        estimate_log_prob_functions = {
-            "full": _estimate_log_gaussian_prob_full,
-            "tied": _estimate_log_gaussian_prob_tied,
-            "diag": _estimate_log_gaussian_prob_diag,
-            "spherical": _estimate_log_gaussian_prob_spherical
-        }
-        return estimate_log_prob_functions[self.covariance_type](
-            X, self.means_, self.covariances_)
+        return {"full": _estimate_log_gaussian_prob_full,
+                "tied": _estimate_log_gaussian_prob_tied,
+                "diag": _estimate_log_gaussian_prob_diag,
+                "spherical": _estimate_log_gaussian_prob_spherical
+                }[self.covariance_type](X, self.means_,
+                                        self.precisions_cholesky_)
 
     def _estimate_log_weights(self):
         return np.log(self.weights_)
 
     def _check_is_fitted(self):
-        check_is_fitted(self, ['weights_', 'means_', 'covariances_'])
+        check_is_fitted(self, ['weights_', 'means_', 'precisions_cholesky_'])
 
     def _get_parameters(self):
-        return self.weights_, self.means_, self.covariances_
+        return self.weights_, self.means_, self.precisions_cholesky_
 
     def _set_parameters(self, params):
-        self.weights_, self.means_, self.covariances_ = params
+        self.weights_, self.means_, self.precisions_cholesky_ = params
+
+        # Derive precisions_ and covariances_ from the Cholesky factors
+        _, n_features = self.means_.shape
+
+        if self.covariance_type == 'full':
+            self.precisions_ = np.empty(self.precisions_cholesky_.shape)
+            self.covariances_ = np.empty(self.precisions_cholesky_.shape)
+            for k, prec_chol in enumerate(self.precisions_cholesky_):
+                self.precisions_[k] = np.dot(prec_chol, prec_chol.T)
+                cov_chol = linalg.solve_triangular(prec_chol,
+                                                   np.eye(n_features))
+                self.covariances_[k] = np.dot(cov_chol.T, cov_chol)
+
+        elif self.covariance_type == 'tied':
+            self.precisions_ = np.dot(self.precisions_cholesky_,
+                                      self.precisions_cholesky_.T)
+            cov_chol = linalg.solve_triangular(self.precisions_cholesky_,
+                                               np.eye(n_features))
+            self.covariances_ = np.dot(cov_chol.T, cov_chol)
+        else:
+            self.precisions_ = self.precisions_cholesky_ ** 2
+            self.covariances_ = 1. / self.precisions_
 
     def _n_parameters(self):
         """Return the number of free parameters in the model."""
-        ndim = self.means_.shape[1]
+        _, n_features = self.means_.shape
         if self.covariance_type == 'full':
-            cov_params = self.n_components * ndim * (ndim + 1) / 2.
+            cov_params = self.n_components * n_features * (n_features + 1) / 2.
         elif self.covariance_type == 'diag':
-            cov_params = self.n_components * ndim
+            cov_params = self.n_components * n_features
         elif self.covariance_type == 'tied':
-            cov_params = ndim * (ndim + 1) / 2.
+            cov_params = n_features * (n_features + 1) / 2.
         elif self.covariance_type == 'spherical':
             cov_params = self.n_components
-        mean_params = ndim * self.n_components
+        mean_params = n_features * self.n_components
         return int(cov_params + mean_params + self.n_components - 1)
 
     def bic(self, X):
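Reviewer note: the `_n_parameters` counts above feed `bic`/`aic`; a sketch of the free-parameter count per covariance type, mirroring the formulas in this hunk (`n_components` and `n_features` are illustrative):

```python
# Free parameters per covariance_type, mirroring _n_parameters:
# means contribute n_components * n_features, weights n_components - 1.
n_components, n_features = 3, 4            # illustrative sizes

cov_params = {
    'full': n_components * n_features * (n_features + 1) / 2.,
    'tied': n_features * (n_features + 1) / 2.,
    'diag': n_components * n_features,
    'spherical': n_components,
}
mean_params = n_features * n_components
for cov_type, p in sorted(cov_params.items()):
    print(cov_type, int(p + mean_params + n_components - 1))
```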
diff --git a/sklearn/mixture/tests/test_gaussian_mixture.py b/sklearn/mixture/tests/test_gaussian_mixture.py
index 64cdbe54c9f30992855612f551f573df54c07776..8e3e5516d7d279f4f42dbcb0b76a65fac223d405 100644
--- a/sklearn/mixture/tests/test_gaussian_mixture.py
+++ b/sklearn/mixture/tests/test_gaussian_mixture.py
@@ -3,18 +3,18 @@ import warnings
 
 import numpy as np
 
-from scipy import stats
+from scipy import stats, linalg
 
 from sklearn.covariance import EmpiricalCovariance
 from sklearn.datasets.samples_generator import make_spd_matrix
 from sklearn.externals.six.moves import cStringIO as StringIO
 from sklearn.metrics.cluster import adjusted_rand_score
 from sklearn.mixture.gaussian_mixture import GaussianMixture
-from sklearn.mixture.gaussian_mixture import _estimate_gaussian_covariance_diag
-from sklearn.mixture.gaussian_mixture import _estimate_gaussian_covariance_full
 from sklearn.mixture.gaussian_mixture import (
-    _estimate_gaussian_covariance_spherical)
-from sklearn.mixture.gaussian_mixture import _estimate_gaussian_covariance_tied
+    _estimate_gaussian_precisions_cholesky_full,
+    _estimate_gaussian_precisions_cholesky_tied,
+    _estimate_gaussian_precisions_cholesky_diag,
+    _estimate_gaussian_precisions_cholesky_spherical)
 from sklearn.exceptions import ConvergenceWarning, NotFittedError
 from sklearn.utils.extmath import fast_logdet
 from sklearn.utils.testing import assert_allclose
@@ -32,28 +32,28 @@ from sklearn.utils.testing import assert_warns_message
 COVARIANCE_TYPE = ['full', 'tied', 'diag', 'spherical']
 
 
-def generate_data(n_samples, n_features, weights, means, covariances,
+def generate_data(n_samples, n_features, weights, means, precisions,
                   covariance_type):
     rng = np.random.RandomState(0)
 
     X = []
     if covariance_type == 'spherical':
         for _, (w, m, c) in enumerate(zip(weights, means,
-                                          covariances['spherical'])):
+                                          precisions['spherical'])):
             X.append(rng.multivariate_normal(m, c * np.eye(n_features),
                                              int(np.round(w * n_samples))))
     if covariance_type == 'diag':
         for _, (w, m, c) in enumerate(zip(weights, means,
-                                          covariances['diag'])):
+                                          precisions['diag'])):
             X.append(rng.multivariate_normal(m, np.diag(c),
                                              int(np.round(w * n_samples))))
     if covariance_type == 'tied':
         for _, (w, m) in enumerate(zip(weights, means)):
-            X.append(rng.multivariate_normal(m, covariances['tied'],
+            X.append(rng.multivariate_normal(m, precisions['tied'],
                                              int(np.round(w * n_samples))))
     if covariance_type == 'full':
         for _, (w, m, c) in enumerate(zip(weights, means,
-                                          covariances['full'])):
+                                          precisions['full'])):
             X.append(rng.multivariate_normal(m, c,
                                              int(np.round(w * n_samples))))
 
@@ -75,13 +75,19 @@ class RandomData(object):
             'spherical': .5 + rng.rand(n_components),
             'diag': (.5 + rng.rand(n_components, n_features)) ** 2,
             'tied': make_spd_matrix(n_features, random_state=rng),
-            'full': np.array([make_spd_matrix(
-                n_features, random_state=rng) * .5
+            'full': np.array([
+                make_spd_matrix(n_features, random_state=rng) * .5
                 for _ in range(n_components)])}
+        self.precisions = {
+            'spherical': 1. / self.covariances['spherical'],
+            'diag': 1. / self.covariances['diag'],
+            'tied': linalg.inv(self.covariances['tied']),
+            'full': np.array([linalg.inv(covariance)
+                              for covariance in self.covariances['full']])}
 
         self.X = dict(zip(COVARIANCE_TYPE, [generate_data(
             n_samples, n_features, self.weights, self.means, self.covariances,
-            cov_type) for cov_type in COVARIANCE_TYPE]))
+            covar_type) for covar_type in COVARIANCE_TYPE]))
         self.Y = np.hstack([k * np.ones(int(np.round(w * n_samples)))
                             for k, w in enumerate(self.weights)])
 
@@ -198,9 +204,8 @@ def test_check_weights():
     g.weights_init = weights_bad_shape
     assert_raise_message(ValueError,
                          "The parameter 'weights' should have the shape of "
-                         "(%d,), "
-                         "but got %s" % (n_components,
-                                         str(weights_bad_shape.shape)),
+                         "(%d,), but got %s" %
+                         (n_components, str(weights_bad_shape.shape)),
                          g.fit, X)
 
     # Check bad range
@@ -253,27 +258,27 @@ def test_check_means():
     assert_array_equal(means, g.means_init)
 
 
-def test_check_covariances():
+def test_check_precisions():
     rng = np.random.RandomState(0)
     rand_data = RandomData(rng)
 
     n_components, n_features = rand_data.n_components, rand_data.n_features
 
-    # Define the bad covariances for each covariance_type
-    covariances_bad_shape = {
-        'full': rng.rand(n_components + 1, n_features, n_features),
-        'tied': rng.rand(n_features + 1, n_features + 1),
-        'diag': rng.rand(n_components + 1, n_features),
-        'spherical': rng.rand(n_components + 1)}
-
-    # Define not positive-definite covariances
-    covariances_not_pos = rng.rand(n_components, n_features, n_features)
-    covariances_not_pos[0] = np.eye(n_features)
-    covariances_not_pos[0, 0, 0] = -1.
-
-    covariances_not_positive = {
-        'full': covariances_not_pos,
-        'tied': covariances_not_pos[0],
+    # Define the bad precisions for each covariance_type
+    precisions_bad_shape = {
+        'full': np.ones((n_components + 1, n_features, n_features)),
+        'tied': np.ones((n_features + 1, n_features + 1)),
+        'diag': np.ones((n_components + 1, n_features)),
+        'spherical': np.ones((n_components + 1))}
+
+    # Define not positive-definite precisions
+    precisions_not_pos = np.ones((n_components, n_features, n_features))
+    precisions_not_pos[0] = np.eye(n_features)
+    precisions_not_pos[0, 0, 0] = -1.
+
+    precisions_not_positive = {
+        'full': precisions_not_pos,
+        'tied': precisions_not_pos[0],
         'diag': -1. * np.ones((n_components, n_features)),
         'spherical': -1. * np.ones(n_components)}
 
@@ -283,33 +288,35 @@ def test_check_covariances():
         'diag': 'positive',
         'spherical': 'positive'}
 
-    for cov_type in ['full', 'tied', 'diag', 'spherical']:
-        X = rand_data.X[cov_type]
+    for covar_type in COVARIANCE_TYPE:
+        X = RandomData(rng).X[covar_type]
         g = GaussianMixture(n_components=n_components,
-                            covariance_type=cov_type)
+                            covariance_type=covar_type,
+                            random_state=rng)
 
-        # Check covariance with bad shapes
-        g.covariances_init = covariances_bad_shape[cov_type]
+        # Check precisions with bad shapes
+        g.precisions_init = precisions_bad_shape[covar_type]
         assert_raise_message(ValueError,
-                             "The parameter '%s covariance' should have "
-                             "the shape of" % cov_type,
+                             "The parameter '%s precision' should have "
+                             "the shape of" % covar_type,
                              g.fit, X)
 
-        # Check not positive covariances
-        g.covariances_init = covariances_not_positive[cov_type]
+        # Check not positive precisions
+        g.precisions_init = precisions_not_positive[covar_type]
         assert_raise_message(ValueError,
-                             "'%s covariance' should be %s"
-                             % (cov_type, not_positive_errors[cov_type]),
+                             "'%s precision' should be %s"
+                             % (covar_type, not_positive_errors[covar_type]),
                              g.fit, X)
 
-        # Check the correct init of covariances_init
-        g.covariances_init = rand_data.covariances[cov_type]
+        # Check the correct init of precisions_init
+        g.precisions_init = rand_data.precisions[covar_type]
         g.fit(X)
-        assert_array_equal(rand_data.covariances[cov_type], g.covariances_init)
+        assert_array_equal(rand_data.precisions[covar_type], g.precisions_init)
 
 
 def test_suffstat_sk_full():
-    # compare the EmpiricalCovariance.covariance fitted on X*sqrt(resp)
+    # compare the precision matrix computed from the
+    # EmpiricalCovariance.covariance fitted on X*sqrt(resp)
     # with _sufficient_sk_full, n_components=1
     rng = np.random.RandomState(0)
     n_samples, n_features = 500, 2
@@ -320,21 +327,25 @@ def test_suffstat_sk_full():
     X_resp = np.sqrt(resp) * X
     nk = np.array([n_samples])
     xk = np.zeros((1, n_features))
-    covars_pred = _estimate_gaussian_covariance_full(resp, X, nk, xk, 0)
+    precs_pred = _estimate_gaussian_precisions_cholesky_full(resp, X,
+                                                             nk, xk, 0)
+    covars_pred = linalg.inv(np.dot(precs_pred[0], precs_pred[0].T))
     ecov = EmpiricalCovariance(assume_centered=True)
     ecov.fit(X_resp)
-    assert_almost_equal(ecov.error_norm(covars_pred[0], norm='frobenius'), 0)
-    assert_almost_equal(ecov.error_norm(covars_pred[0], norm='spectral'), 0)
+    assert_almost_equal(ecov.error_norm(covars_pred, norm='frobenius'), 0)
+    assert_almost_equal(ecov.error_norm(covars_pred, norm='spectral'), 0)
 
     # special case 2, assuming resp are all ones
     resp = np.ones((n_samples, 1))
     nk = np.array([n_samples])
-    xk = X.mean().reshape((1, -1))
-    covars_pred = _estimate_gaussian_covariance_full(resp, X, nk, xk, 0)
+    xk = X.mean(axis=0).reshape((1, -1))
+    precs_pred = _estimate_gaussian_precisions_cholesky_full(resp, X,
+                                                             nk, xk, 0)
+    covars_pred = linalg.inv(np.dot(precs_pred[0], precs_pred[0].T))
     ecov = EmpiricalCovariance(assume_centered=False)
     ecov.fit(X)
-    assert_almost_equal(ecov.error_norm(covars_pred[0], norm='frobenius'), 0)
-    assert_almost_equal(ecov.error_norm(covars_pred[0], norm='spectral'), 0)
+    assert_almost_equal(ecov.error_norm(covars_pred, norm='frobenius'), 0)
+    assert_almost_equal(ecov.error_norm(covars_pred, norm='spectral'), 0)
 
 
 def test_suffstat_sk_tied():
@@ -347,11 +358,18 @@ def test_suffstat_sk_tied():
     X = rng.rand(n_samples, n_features)
     nk = resp.sum(axis=0)
     xk = np.dot(resp.T, X) / nk[:, np.newaxis]
-    covars_pred_full = _estimate_gaussian_covariance_full(resp, X, nk, xk, 0)
+
+    precs_pred_full = _estimate_gaussian_precisions_cholesky_full(resp, X,
+                                                                  nk, xk, 0)
+    covars_pred_full = [linalg.inv(np.dot(precision_chol, precision_chol.T))
+                        for precision_chol in precs_pred_full]
     covars_pred_full = np.sum(nk[:, np.newaxis, np.newaxis] * covars_pred_full,
                               0) / n_samples
 
-    covars_pred_tied = _estimate_gaussian_covariance_tied(resp, X, nk, xk, 0)
+    precs_pred_tied = _estimate_gaussian_precisions_cholesky_tied(resp, X,
+                                                                  nk, xk, 0)
+    covars_pred_tied = linalg.inv(np.dot(precs_pred_tied, precs_pred_tied.T))
+
     ecov = EmpiricalCovariance()
     ecov.covariance_ = covars_pred_full
     assert_almost_equal(ecov.error_norm(covars_pred_tied, norm='frobenius'), 0)
@@ -368,14 +386,19 @@ def test_suffstat_sk_diag():
     X = rng.rand(n_samples, n_features)
     nk = resp.sum(axis=0)
     xk = np.dot(resp.T, X) / nk[:, np.newaxis]
-    covars_pred_full = _estimate_gaussian_covariance_full(resp, X, nk, xk, 0)
-    covars_pred_full = np.array([np.diag(np.diag(d)) for d in
-                                 covars_pred_full])
-    covars_pred_diag = _estimate_gaussian_covariance_diag(resp, X, nk, xk, 0)
-    covars_pred_diag = np.array([np.diag(d) for d in covars_pred_diag])
+    precs_pred_full = _estimate_gaussian_precisions_cholesky_full(resp, X,
+                                                                  nk, xk, 0)
+    covars_pred_full = [linalg.inv(np.dot(precision_chol, precision_chol.T))
+                        for precision_chol in precs_pred_full]
+
+    precs_pred_diag = _estimate_gaussian_precisions_cholesky_diag(resp, X,
+                                                                  nk, xk, 0)
+    covars_pred_diag = np.array([np.diag(1. / d) ** 2
+                                 for d in precs_pred_diag])
+
     ecov = EmpiricalCovariance()
     for (cov_full, cov_diag) in zip(covars_pred_full, covars_pred_diag):
-        ecov.covariance_ = cov_full
+        ecov.covariance_ = np.diag(np.diag(cov_full))
         assert_almost_equal(ecov.error_norm(cov_diag, norm='frobenius'), 0)
         assert_almost_equal(ecov.error_norm(cov_diag, norm='spectral'), 0)
 
@@ -391,11 +414,11 @@ def test_gaussian_suffstat_sk_spherical():
     resp = np.ones((n_samples, 1))
     nk = np.array([n_samples])
     xk = X.mean()
-    covars_pred_spherical = _estimate_gaussian_covariance_spherical(resp, X,
-                                                                    nk, xk, 0)
-    covars_pred_spherical2 = (np.dot(X.flatten().T, X.flatten()) /
-                              (n_features * n_samples))
-    assert_almost_equal(covars_pred_spherical, covars_pred_spherical2)
+    precs_pred_spherical = _estimate_gaussian_precisions_cholesky_spherical(
+        resp, X, nk, xk, 0)
+    covars_pred_spherical = (np.dot(X.flatten().T, X.flatten()) /
+                             (n_features * n_samples))
+    assert_almost_equal(1. / precs_pred_spherical ** 2, covars_pred_spherical)
 
 
 def _naive_lmvnpdf_diag(X, means, covars):
@@ -426,29 +449,33 @@ def test_gaussian_mixture_log_probabilities():
     log_prob_naive = _naive_lmvnpdf_diag(X, means, covars_diag)
 
     # full covariances
-    covars_full = np.array([np.diag(x) for x in covars_diag])
+    precs_full = np.array([np.diag(1. / np.sqrt(x)) for x in covars_diag])
 
-    log_prob = _estimate_log_gaussian_prob_full(X, means, covars_full)
+    log_prob = _estimate_log_gaussian_prob_full(X, means, precs_full)
     assert_array_almost_equal(log_prob, log_prob_naive)
 
     # diag covariances
-    log_prob = _estimate_log_gaussian_prob_diag(X, means, covars_diag)
+    precs_chol_diag = 1. / np.sqrt(covars_diag)
+    log_prob = _estimate_log_gaussian_prob_diag(X, means, precs_chol_diag)
     assert_array_almost_equal(log_prob, log_prob_naive)
 
     # tied
-    covars_tied = covars_full.mean(axis=0)
+    covars_tied = np.mean(covars_diag, axis=0)
+    precs_tied = np.diag(np.sqrt(1. / covars_tied))
+
     log_prob_naive = _naive_lmvnpdf_diag(X, means,
-                                         [np.diag(covars_tied)] * n_components)
-    log_prob = _estimate_log_gaussian_prob_tied(X, means, covars_tied)
+                                         [covars_tied] * n_components)
+    log_prob = _estimate_log_gaussian_prob_tied(X, means, precs_tied)
+
     assert_array_almost_equal(log_prob, log_prob_naive)
 
     # spherical
     covars_spherical = covars_diag.mean(axis=1)
+    precs_spherical = 1. / np.sqrt(covars_spherical)
     log_prob_naive = _naive_lmvnpdf_diag(X, means,
                                          [[k] * n_features for k in
                                           covars_spherical])
-    log_prob = _estimate_log_gaussian_prob_spherical(X, means,
-                                                     covars_spherical)
+    log_prob = _estimate_log_gaussian_prob_spherical(X, means, precs_spherical)
     assert_array_almost_equal(log_prob, log_prob_naive)
 
 # skip tests on weighted_log_probabilities, log_weights
@@ -463,33 +490,33 @@ def test_gaussian_mixture_estimate_log_prob_resp():
     n_components = rand_data.n_components
 
     X = rng.rand(n_samples, n_features)
-    for cov_type in COVARIANCE_TYPE:
+    for covar_type in COVARIANCE_TYPE:
         weights = rand_data.weights
         means = rand_data.means
-        covariances = rand_data.covariances[cov_type]
+        precisions = rand_data.precisions[covar_type]
         g = GaussianMixture(n_components=n_components, random_state=rng,
                             weights_init=weights, means_init=means,
-                            covariances_init=covariances,
-                            covariance_type=cov_type)
+                            precisions_init=precisions,
+                            covariance_type=covar_type)
         g.fit(X)
         resp = g.predict_proba(X)
         assert_array_almost_equal(resp.sum(axis=1), np.ones(n_samples))
         assert_array_equal(g.weights_init, weights)
         assert_array_equal(g.means_init, means)
-        assert_array_equal(g.covariances_init, covariances)
+        assert_array_equal(g.precisions_init, precisions)
 
 
 def test_gaussian_mixture_predict_predict_proba():
     rng = np.random.RandomState(0)
     rand_data = RandomData(rng)
-    for cov_type in COVARIANCE_TYPE:
-        X = rand_data.X[cov_type]
+    for covar_type in COVARIANCE_TYPE:
+        X = rand_data.X[covar_type]
         Y = rand_data.Y
         g = GaussianMixture(n_components=rand_data.n_components,
                             random_state=rng, weights_init=rand_data.weights,
                             means_init=rand_data.means,
-                            covariances_init=rand_data.covariances[cov_type],
-                            covariance_type=cov_type)
+                            precisions_init=rand_data.precisions[covar_type],
+                            covariance_type=covar_type)
 
         # Check a warning message arrive if we don't do fit
         assert_raise_message(NotFittedError,
@@ -511,12 +538,13 @@ def test_gaussian_mixture_fit():
     n_features = rand_data.n_features
     n_components = rand_data.n_components
 
-    for cov_type in COVARIANCE_TYPE:
-        X = rand_data.X[cov_type]
-        g = GaussianMixture(n_components=n_components, n_init=20, max_iter=100,
+    for covar_type in COVARIANCE_TYPE:
+        X = rand_data.X[covar_type]
+        g = GaussianMixture(n_components=n_components, n_init=20,
                             reg_covar=0, random_state=rng,
-                            covariance_type=cov_type)
+                            covariance_type=covar_type)
         g.fit(X)
+
         # needs more data to pass the test with rtol=1e-7
         assert_allclose(np.sort(g.weights_), np.sort(rand_data.weights),
                         rtol=0.1, atol=1e-2)
@@ -526,28 +554,29 @@ def test_gaussian_mixture_fit():
         assert_allclose(g.means_[arg_idx1], rand_data.means[arg_idx2],
                         rtol=0.1, atol=1e-2)
 
-        if cov_type == 'spherical':
-            cov_pred = np.array([np.eye(n_features) * c
-                                 for c in g.covariances_])
-            cov_test = np.array([np.eye(n_features) * c for c in
-                                 rand_data.covariances['spherical']])
-        elif cov_type == 'diag':
-            cov_pred = np.array([np.diag(d) for d in g.covariances_])
-            cov_test = np.array([np.diag(d) for d in
-                                 rand_data.covariances['diag']])
-        elif cov_type == 'tied':
-            cov_pred = np.array([g.covariances_] * n_components)
-            cov_test = np.array([rand_data.covariances['tied']] * n_components)
-        elif cov_type == 'full':
-            cov_pred = g.covariances_
-            cov_test = rand_data.covariances['full']
-        arg_idx1 = np.trace(cov_pred, axis1=1, axis2=2).argsort()
-        arg_idx2 = np.trace(cov_test, axis1=1, axis2=2).argsort()
+        if covar_type == 'full':
+            prec_pred = g.precisions_
+            prec_test = rand_data.precisions['full']
+        elif covar_type == 'tied':
+            prec_pred = np.array([g.precisions_] * n_components)
+            prec_test = np.array([rand_data.precisions['tied']] * n_components)
+        elif covar_type == 'spherical':
+            prec_pred = np.array([np.eye(n_features) * c
+                                 for c in g.precisions_])
+            prec_test = np.array([np.eye(n_features) * c for c in
+                                 rand_data.precisions['spherical']])
+        elif covar_type == 'diag':
+            prec_pred = np.array([np.diag(d) for d in g.precisions_])
+            prec_test = np.array([np.diag(d) for d in
+                                 rand_data.precisions['diag']])
+
+        arg_idx1 = np.trace(prec_pred, axis1=1, axis2=2).argsort()
+        arg_idx2 = np.trace(prec_test, axis1=1, axis2=2).argsort()
         for k, h in zip(arg_idx1, arg_idx2):
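+            # By default EmpiricalCovariance.error_norm computes a scaled,
+            # squared Frobenius distance between the matrix stored in
+            # covariance_ and its argument; load the expected precision into
+            # covariance_ and compare it against the fitted one.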
             ecov = EmpiricalCovariance()
-            ecov.covariance_ = cov_test[h]
+            ecov.covariance_ = prec_test[h]
             # the accuracy depends on the amount of data and the randomness of rng
-            assert_allclose(ecov.error_norm(cov_pred[k]), 0, atol=0.1)
+            assert_allclose(ecov.error_norm(prec_pred[k]), 0, atol=0.1)
 
 
 def test_gaussian_mixture_fit_best_params():
@@ -555,19 +584,18 @@ def test_gaussian_mixture_fit_best_params():
     rand_data = RandomData(rng)
     n_components = rand_data.n_components
     n_init = 10
-    for cov_type in COVARIANCE_TYPE:
-        X = rand_data.X[cov_type]
-        g = GaussianMixture(n_components=n_components, n_init=1,
-                            max_iter=100, reg_covar=0, random_state=rng,
-                            covariance_type=cov_type)
+    for covar_type in COVARIANCE_TYPE:
+        X = rand_data.X[covar_type]
+        g = GaussianMixture(n_components=n_components, n_init=1, reg_covar=0,
+                            random_state=rng, covariance_type=covar_type)
         ll = []
         for _ in range(n_init):
             g.fit(X)
             ll.append(g.score(X))
         ll = np.array(ll)
         g_best = GaussianMixture(n_components=n_components,
-                                 n_init=n_init, max_iter=100, reg_covar=0,
-                                 random_state=rng, covariance_type=cov_type)
+                                 n_init=n_init, reg_covar=0, random_state=rng,
+                                 covariance_type=covar_type)
         g_best.fit(X)
         assert_almost_equal(ll.min(), g_best.score(X))
 
@@ -577,11 +605,11 @@ def test_gaussian_mixture_fit_convergence_warning():
     rand_data = RandomData(rng, scale=1)
     n_components = rand_data.n_components
     max_iter = 1
-    for cov_type in COVARIANCE_TYPE:
-        X = rand_data.X[cov_type]
+    for covar_type in COVARIANCE_TYPE:
+        X = rand_data.X[covar_type]
         g = GaussianMixture(n_components=n_components, n_init=1,
                             max_iter=max_iter, reg_covar=0, random_state=rng,
-                            covariance_type=cov_type)
+                            covariance_type=covar_type)
         assert_warns_message(ConvergenceWarning,
                              'Initialization %d did not converged. '
                              'Try different init parameters, '
@@ -659,14 +687,14 @@ def test_gaussian_mixture_verbose():
     rng = np.random.RandomState(0)
     rand_data = RandomData(rng)
     n_components = rand_data.n_components
-    for cov_type in COVARIANCE_TYPE:
-        X = rand_data.X[cov_type]
-        g = GaussianMixture(n_components=n_components, n_init=1,
-                            max_iter=100, reg_covar=0, random_state=rng,
-                            covariance_type=cov_type, verbose=1)
-        h = GaussianMixture(n_components=n_components, n_init=1,
-                            max_iter=100, reg_covar=0, random_state=rng,
-                            covariance_type=cov_type, verbose=2)
+    for covar_type in COVARIANCE_TYPE:
+        X = rand_data.X[covar_type]
+        g = GaussianMixture(n_components=n_components, n_init=1, reg_covar=0,
+                            random_state=rng, covariance_type=covar_type,
+                            verbose=1)
+        h = GaussianMixture(n_components=n_components, n_init=1, reg_covar=0,
+                            random_state=rng, covariance_type=covar_type,
+                            verbose=2)
         old_stdout = sys.stdout
         sys.stdout = StringIO()
         try:
@@ -699,7 +727,7 @@ def test_warm_start():
 
     assert_almost_equal(g.weights_, h.weights_)
     assert_almost_equal(g.means_, h.means_)
-    assert_almost_equal(g.covariances_, h.covariances_)
+    assert_almost_equal(g.precisions_, h.precisions_)
     assert_greater(score2, score1)
 
     # Assert that by using warm_start we can converge to a good solution
@@ -720,16 +748,16 @@ def test_warm_start():
 
 
 def test_score():
-    cov_type = 'full'
+    covar_type = 'full'
     rng = np.random.RandomState(0)
     rand_data = RandomData(rng, scale=7)
     n_components = rand_data.n_components
-    X = rand_data.X[cov_type]
+    X = rand_data.X[covar_type]
 
     # Check the error message if we don't call fit
     gmm1 = GaussianMixture(n_components=n_components, n_init=1,
                            max_iter=1, reg_covar=0, random_state=rng,
-                           covariance_type=cov_type)
+                           covariance_type=covar_type)
     assert_raise_message(NotFittedError,
                          "This GaussianMixture instance is not fitted "
                          "yet. Call 'fit' with appropriate arguments "
@@ -744,23 +772,22 @@ def test_score():
     assert_almost_equal(gmm_score, gmm_score_proba)
 
     # Check that the score increases after training for more iterations
-    gmm2 = GaussianMixture(n_components=n_components, n_init=1,
-                           max_iter=1000, reg_covar=0, random_state=rng,
-                           covariance_type=cov_type).fit(X)
+    gmm2 = GaussianMixture(n_components=n_components, n_init=1, reg_covar=0,
+                           random_state=rng,
+                           covariance_type=covar_type).fit(X)
     assert_greater(gmm2.score(X), gmm1.score(X))
 
 
 def test_score_samples():
-    cov_type = 'full'
+    covar_type = 'full'
     rng = np.random.RandomState(0)
     rand_data = RandomData(rng, scale=7)
     n_components = rand_data.n_components
-    X = rand_data.X[cov_type]
+    X = rand_data.X[covar_type]
 
     # Check the error message if we don't call fit
-    gmm = GaussianMixture(n_components=n_components, n_init=1,
-                          max_iter=100, reg_covar=0, random_state=rng,
-                          covariance_type=cov_type)
+    gmm = GaussianMixture(n_components=n_components, n_init=1, reg_covar=0,
+                          random_state=rng, covariance_type=covar_type)
     assert_raise_message(NotFittedError,
                          "This GaussianMixture instance is not fitted "
                          "yet. Call 'fit' with appropriate arguments "
@@ -777,10 +804,10 @@ def test_monotonic_likelihood():
     rand_data = RandomData(rng, scale=7)
     n_components = rand_data.n_components
 
-    for cov_type in COVARIANCE_TYPE:
-        X = rand_data.X[cov_type]
+    for covar_type in COVARIANCE_TYPE:
+        X = rand_data.X[covar_type]
         gmm = GaussianMixture(n_components=n_components,
-                              covariance_type=cov_type, reg_covar=0,
+                              covariance_type=covar_type, reg_covar=0,
                               warm_start=True, max_iter=1, random_state=rng,
                               tol=1e-7)
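+        # warm_start=True with max_iter=1 makes every call to fit run a
+        # single EM step from the previous parameters, so the log-likelihood
+        # can be recorded after each step and checked to be non-decreasing.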
         current_log_likelihood = -np.infty
@@ -810,9 +837,9 @@ def test_regularisation():
     X = np.vstack((np.ones((n_samples // 2, n_features)),
                    np.zeros((n_samples // 2, n_features))))
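+    # With n_components == n_samples and duplicated points, each component
+    # can collapse onto identical samples, driving the fitted covariances to
+    # singularity; fitting with reg_covar=0 must therefore raise an error.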
 
-    for cov_type in COVARIANCE_TYPE:
-        gmm = GaussianMixture(n_components=n_samples, covariance_type=cov_type,
-                              reg_covar=0, random_state=rng)
+    for covar_type in COVARIANCE_TYPE:
+        gmm = GaussianMixture(n_components=n_samples, reg_covar=0,
+                              covariance_type=covar_type, random_state=rng)
 
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", RuntimeWarning)
@@ -823,3 +850,25 @@ def test_regularisation():
                                  "or increase reg_covar.", gmm.fit, X)
 
             gmm.set_params(reg_covar=1e-6).fit(X)
+
+
+def test_property():
+    rng = np.random.RandomState(0)
+    rand_data = RandomData(rng, scale=7)
+    n_components = rand_data.n_components
+
+    for covar_type in COVARIANCE_TYPE:
+        X = rand_data.X[covar_type]
+        gmm = GaussianMixture(n_components=n_components,
+                              covariance_type=covar_type, random_state=rng)
+        gmm.fit(X)
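+        # precisions_ must be the matrix inverse of covariances_ for the
+        # matrix-valued storage formats ('full', 'tied') and the element-wise
+        # reciprocal for the vector-valued ones ('diag', 'spherical').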
+        if covar_type == 'full':
+            for prec, covar in zip(gmm.precisions_, gmm.covariances_):
+                assert_array_almost_equal(linalg.inv(prec), covar)
+        elif covar_type == 'tied':
+            assert_array_almost_equal(linalg.inv(gmm.precisions_),
+                                      gmm.covariances_)
+        else:
+            assert_array_almost_equal(gmm.precisions_, 1. / gmm.covariances_)
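+
+
+def test_precisions_init_round_trip():
+    # A minimal sketch (not part of the original change-set) illustrating the
+    # precisions_init parameter that replaces covariances_init: the value
+    # passed at construction should be stored on the estimator unchanged.
+    # RandomData is the helper defined earlier in this module, assumed to
+    # expose a ``precisions`` dict keyed by covariance_type.
+    rng = np.random.RandomState(0)
+    rand_data = RandomData(rng, scale=7)
+    gmm = GaussianMixture(n_components=rand_data.n_components,
+                          covariance_type='full', random_state=rng,
+                          weights_init=rand_data.weights,
+                          means_init=rand_data.means,
+                          precisions_init=rand_data.precisions['full'])
+    gmm.fit(rand_data.X['full'])
+    assert_array_equal(gmm.precisions_init, rand_data.precisions['full'])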