diff --git a/examples/mixture/gmm_model_selection.py b/examples/mixture/gmm_model_selection.py
index 635add5e34a9ac1776f633a8f17a3b7534ec7a5d..8ea0ec61f2b5700b0b625a460602ce2f727376cd 100644
--- a/examples/mixture/gmm_model_selection.py
+++ b/examples/mixture/gmm_model_selection.py
@@ -64,7 +64,7 @@ spl.legend([b[0] for b in bars], cv_types)
 # Plot the winner
 splot = pl.subplot(2, 1, 2)
 Y_ = clf.predict(X)
-for i, (mean, covar, color) in enumerate(zip(clf.means, clf.covars,
+for i, (mean, covar, color) in enumerate(zip(clf.means_, clf.covars_,
                                              color_iter)):
     v, w = linalg.eigh(covar)
     if not np.any(Y_ == i):
diff --git a/examples/mixture/plot_gmm.py b/examples/mixture/plot_gmm.py
index eeebe2fa0835a386b567f88b28cd99732379ab35..2d97a3b1677415ed6c0ba29baa435b7e061a5964 100644
--- a/examples/mixture/plot_gmm.py
+++ b/examples/mixture/plot_gmm.py
@@ -52,8 +52,8 @@ for i, (clf, title) in enumerate([(gmm, 'GMM'),
                                   (dpgmm, 'Dirichlet Process GMM')]):
     splot = pl.subplot(2, 1, 1 + i)
     Y_ = clf.predict(X)
-    for i, (mean, covar, color) in enumerate(zip(clf.means_, clf.covars_,
-                                                 color_iter)):
+    for i, (mean, covar, color) in enumerate(zip(
+            clf._get_means(), clf._get_covars(), color_iter)):
         v, w = linalg.eigh(covar)
         u = w[0] / linalg.norm(w[0])
         # as the DP will not use every component it has access to
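
A note on the accessor change above: DPGMM does not necessarily expose a covars_ attribute (it keeps per-component precision information internally), so the example now goes through the _get_means()/_get_covars() accessors, which work for both estimators. A minimal usage sketch, assuming this branch's GMM/DPGMM API; the data X is only a stand-in:

    # Minimal sketch (assumes this branch's API): both estimators answer the
    # private accessors, whereas only GMM is guaranteed to carry covars_.
    import numpy as np
    from sklearn.mixture import GMM, DPGMM

    X = np.random.randn(500, 2)                  # stand-in data
    for clf in (GMM(n_components=3), DPGMM(n_components=5)):
        clf.fit(X)
        means = clf._get_means()                 # (n_components, n_features)
        covars = clf._get_covars()               # one matrix per component
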
diff --git a/examples/mixture/plot_gmm_classifier.py b/examples/mixture/plot_gmm_classifier.py
index f7aa1585d7ddd369d7c28994e03d23e3880e4eb1..835df829f69a99fa64e09d608053d255fb80aa31 100644
--- a/examples/mixture/plot_gmm_classifier.py
+++ b/examples/mixture/plot_gmm_classifier.py
@@ -37,13 +37,13 @@ from sklearn.mixture import GMM
 
 def make_ellipses(gmm, ax):
     for n, color in enumerate('rgb'):
-        v, w = np.linalg.eigh(gmm.covars[n][:2, :2])
+        v, w = np.linalg.eigh(gmm._get_covars()[n][:2, :2])
         u = w[0] / np.linalg.norm(w[0])
         angle = np.arctan2(u[1], u[0])
         angle = 180 * angle / np.pi  # convert to degrees
         v *= 9
-        ell = mpl.patches.Ellipse(gmm.means[n, :2], v[0], v[1], 180 + angle,
-                                  color=color)
+        ell = mpl.patches.Ellipse(gmm._get_means()[n, :2], v[0], v[1],
+                                  180 + angle, color=color)
         ell.set_clip_box(ax.bbox)
         ell.set_alpha(0.5)
         ax.add_artist(ell)
@@ -78,9 +78,9 @@ pl.subplots_adjust(bottom=.01, top=0.95, hspace=.15, wspace=.05,
 for index, (name, classifier) in enumerate(classifiers.iteritems()):
     # Since we have class labels for the training data, we can
     # initialize the GMM parameters in a supervised manner.
-    classifier.means = [X_train[y_train == i, :].mean(axis=0)
-                        for i in xrange(n_classes)]
-
+    classifier.means_ = np.array([X_train[y_train == i, :].mean(axis=0)
+                                  for i in xrange(n_classes)])
+
     # Train the other parameters using the EM algorithm.
     classifier.fit(X_train, init_params='wc', n_iter=20)
 
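
The hunk above now assigns the per-class means to the trailing-underscore attribute the estimator actually reads, and wraps them in an ndarray rather than a plain list. As a self-contained sketch of the same supervised-initialization pattern (the toy data below merely stands in for the iris split used in the example):

    # Sketch of supervised initialization, mirroring plot_gmm_classifier.py.
    import numpy as np
    from sklearn.mixture import GMM

    rng = np.random.RandomState(0)
    n_classes = 2
    X_train = np.vstack([rng.randn(50, 2), rng.randn(50, 2) + 3])
    y_train = np.repeat([0, 1], 50)

    clf = GMM(n_components=n_classes)
    # Seed each component's mean with the corresponding class mean ...
    clf.means_ = np.array([X_train[y_train == i].mean(axis=0)
                           for i in range(n_classes)])
    # ... then let EM estimate only the weights ('w') and covariances ('c').
    clf.fit(X_train, init_params='wc', n_iter=20)
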
diff --git a/examples/mixture/plot_gmm_sin.py b/examples/mixture/plot_gmm_sin.py
index 31773855f261079ef9bf24432afb84af6cd84a2c..dc65dcadbb8643bd052fc0e3f9171c36ac0e187c 100644
--- a/examples/mixture/plot_gmm_sin.py
+++ b/examples/mixture/plot_gmm_sin.py
@@ -52,8 +52,8 @@ for i, (clf, title) in enumerate([
     clf.fit(X, n_iter=100)
     splot = pl.subplot(3, 1, 1 + i)
     Y_ = clf.predict(X)
-    for i, (mean, covar, color) in enumerate(zip(clf.means, clf.covars,
-                                                 color_iter)):
+    for i, (mean, covar, color) in enumerate(zip(
+            clf._get_means(), clf._get_covars(), color_iter)):
         v, w = linalg.eigh(covar)
         u = w[0] / linalg.norm(w[0])
         # as the DP will not use every component it has access to
diff --git a/sklearn/mixture/dpgmm.py b/sklearn/mixture/dpgmm.py
index 12328b0e814a7c934ab53b550ee15630d5468420..23b132d6cc731e0db66d87b9c1ad2f9ddb0fedf5 100644
--- a/sklearn/mixture/dpgmm.py
+++ b/sklearn/mixture/dpgmm.py
@@ -57,8 +57,7 @@ def wishart_logz(v, s, dets, n_features):
     "The logarithm of the normalization constant for the wishart distribution"
     z = 0.
     z += 0.5 * v * n_features * np.log(2)
-    z += (0.25 * (n_features * (n_features - 1))
-          * np.log(np.pi))
+    z += (0.25 * (n_features * (n_features - 1)) * np.log(np.pi))
     z += 0.5 * v * np.log(dets)
     z += np.sum(gammaln(0.5 * (v - np.arange(n_features) + 1)))
     return z
@@ -499,7 +498,7 @@ class DPGMM(GMM):
         z = np.ones((X.shape[0], self.n_components))
         z /= self.n_components
 
         self._initial_bound = -0.5 * n_features * np.log(2 * np.pi)
         self._initial_bound -= np.log(2 * np.pi * np.e)
 
         if (init_params != '') or not hasattr(self, 'gamma_'):
@@ -510,7 +509,7 @@ class DPGMM(GMM):
                 k=self.n_components, random_state=self.random_state
             ).fit(X).cluster_centers_[::-1]
 
-        if 'w' in init_params or not hasattr(self, 'log_weights_'): # fixme
+        if 'w' in init_params or not hasattr(self, 'weights_'):
             self._set_weights(np.tile(1.0 / self.n_components, 
                                       self.n_components))
 
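
With the guard above keyed on the public weights_ attribute, the weights are only (re)initialized when 'w' is requested or when the estimator has never been fitted, so a second call with an empty init_params string should reuse the previously fitted weights. A rough sketch of that warm-start pattern, assuming DPGMM.fit forwards init_params and n_iter as the example scripts do:

    # Rough sketch of the warm-start behaviour the hasattr() guard permits
    # (assumes fit() accepts the init_params/n_iter keywords used above).
    import numpy as np
    from sklearn.mixture import DPGMM

    X = np.random.randn(300, 2)                 # stand-in data
    dpgmm = DPGMM(n_components=5)
    dpgmm.fit(X, n_iter=50)                     # first fit sets weights_ etc.
    dpgmm.fit(X, n_iter=50, init_params='')     # keeps the fitted weights_
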
diff --git a/sklearn/mixture/gmm.py b/sklearn/mixture/gmm.py
index 5388976323f61efcc77c43d51c1ffad9539ed23e..87166d5a41c94076814b5f638e245e76defefdd1 100644
--- a/sklearn/mixture/gmm.py
+++ b/sklearn/mixture/gmm.py
@@ -139,8 +139,8 @@ class GMM(BaseEstimator):
     covariance_type : string (read-only)
         String describing the type of covariance parameters used by
         the GMM.  Must be one of 'spherical', 'tied', 'diag', 'full'.
-    log_weights_ : array, shape (`n_components`,)
-        log of mixing weights for each mixture component.
+    weights_ : array, shape (`n_components`,)
+        Mixing weights for each mixture component.
     means_ : array, shape (`n_components`, `n_features`)
         Mean parameters for each mixture component.
     covars_ : array
@@ -212,8 +212,7 @@ class GMM(BaseEstimator):
         if not covariance_type in ['spherical', 'tied', 'diag', 'full']:
             raise ValueError('bad covariance_type: ' + str(covariance_type))
 
-        self.log_weights_ = - np.ones(self.n_components) * \
-            np.log(self.n_components)
+        self.weights_ = np.ones(self.n_components) / self.n_components
 
         # flag to indicate exit status of fit() method: converged (True) or
         # n_iter reached (False)
@@ -272,7 +271,7 @@ class GMM(BaseEstimator):
     def _get_weights(self):
         """Mixing weights for each mixture component.      
         array, shape ``(n_states,)``"""
-        return np.exp(self.log_weights_)
+        return self.weights_
 
     def _set_weights(self, weights):
         """Provide value for micture weights"""
@@ -280,7 +279,7 @@ class GMM(BaseEstimator):
             raise ValueError('weights must have length n_components')
         if not np.allclose(np.sum(weights), 1.0):
             raise ValueError('weights must sum to 1.0')
-        self.log_weights_ = np.log(np.asarray(weights).copy())
+        self.weights_ = np.asarray(weights).copy()
 
     def eval(self, X):
         """Evaluate the model on data
@@ -310,10 +309,10 @@ class GMM(BaseEstimator):
             return np.array([]), np.empty((0, self.n_components))
         if X.shape[1] != self.means_.shape[1]:
             raise ValueError('the shape of X is not compatible with self')
-
+
         lpr = (log_multivariate_normal_density(
                 X, self.means_, self.covars_, self._covariance_type)
-               + self.log_weights_)
+               + np.log(self.weights_))
         logprob = logsumexp(lpr, axis=1)
         responsibilities = np.exp(lpr - logprob[:, np.newaxis])
         return logprob, responsibilities
@@ -407,8 +406,7 @@ class GMM(BaseEstimator):
         if random_state is None:
             random_state = self.random_state
         random_state = check_random_state(random_state)
-        weight_pdf = np.exp(self.log_weights_)
-        weight_cdf = np.cumsum(weight_pdf)
+        weight_cdf = np.cumsum(self.weights_)
 
         X = np.empty((n_samples, self.means_.shape[1]))
         rand = random_state.rand(n_samples)
@@ -484,8 +482,8 @@ class GMM(BaseEstimator):
                     k=self.n_components).fit(X).cluster_centers_
 
             if 'w' in init_params or not hasattr(self, 'weights_'):
-                self._set_weights(np.tile(1.0 / self.n_components,
-                                          self.n_components))
+                self.weights_ = np.tile(1.0 / self.n_components,
+                                        self.n_components)
 
             if 'c' in init_params or not hasattr(self, 'covars_'):
                 cv = np.cov(X.T) + self.min_covar * np.eye(X.shape[1])
@@ -517,13 +515,13 @@ class GMM(BaseEstimator):
             if n_iter:
                 if log_likelihood[-1] > max_log_prob:
                     max_log_prob = log_likelihood[-1]
-                    best_params = {'weights': self._get_weights(),
+                    best_params = {'weights': self.weights_,
                                    'means': self.means_,
                                    'covars': self.covars_}
         if n_iter:
             self.covars_ = best_params['covars']
             self.means_ = best_params['means']
-            self._set_weights(best_params['weights'])
+            self.weights_ = best_params['weights']
         return self
 
     def _do_mstep(self, X, responsibilities, params, min_covar=0):
@@ -534,8 +532,8 @@ class GMM(BaseEstimator):
         inverse_weights = 1.0 / (weights[:, np.newaxis] + 10 * INF_EPS)
 
         if 'w' in params:
-            self.log_weights_ = np.log(
-                weights / (weights.sum() + 10 * INF_EPS) + INF_EPS)
+            self.weights_ = (weights / (weights.sum() + 10 * INF_EPS) +
+                             INF_EPS)
         if 'm' in params:
             self.means_ = weighted_X_sum * inverse_weights
         if 'c' in params:
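
Taken together, the gmm.py changes replace the internal log_weights_ storage with a public weights_ array holding plain mixing proportions; eval() and the sampling code now take the log or the cumulative sum at the point of use. A minimal sketch of what user code can rely on after this change, assuming this branch's GMM API:

    # Minimal sketch (assumes this branch's GMM API): weights_ holds plain
    # probabilities that sum to one; log_weights_ no longer exists.
    import numpy as np
    from sklearn.mixture import GMM

    X = np.random.randn(400, 2)                 # stand-in data
    g = GMM(n_components=3)
    g.fit(X)
    assert np.allclose(g.weights_.sum(), 1.0)   # proportions, not logs
    log_w = np.log(g.weights_)                  # what eval() now computes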