diff --git a/examples/mixture/plot_gmm_sin.py b/examples/mixture/plot_gmm_sin.py
index 501736a46da22135976e0f29e7b66fee374c2b4d..31773855f261079ef9bf24432afb84af6cd84a2c 100644
--- a/examples/mixture/plot_gmm_sin.py
+++ b/examples/mixture/plot_gmm_sin.py
@@ -41,9 +41,9 @@
 color_iter = itertools.cycle(['r', 'g', 'b', 'c', 'm'])
 
 for i, (clf, title) in enumerate([
-        (mixture.GMM(n_components=10, covariance_type='diag'), \
+        (mixture.GMM(n_components=10, covariance_type='full'), \
            "Expectation-maximization"),
-        (mixture.DPGMM(n_components=10, covariance_type='diag', alpha=0.01),
+        (mixture.DPGMM(n_components=10, covariance_type='full', alpha=0.01),
            "Dirichlet Process,alpha=0.01"),
         (mixture.DPGMM(n_components=10, covariance_type='diag', alpha=100.),
            "Dirichlet Process,alpha=100.")
diff --git a/sklearn/mixture/dpgmm.py b/sklearn/mixture/dpgmm.py
index e5066a6e84af1ee99a48b3b956d998049bc49b7d..1bc3a2070a65789f6da78ef8910a44a499c2f42d 100644
--- a/sklearn/mixture/dpgmm.py
+++ b/sklearn/mixture/dpgmm.py
@@ -43,7 +43,7 @@ def log_normalize(v, axis=0):
     return np.swapaxes(v, 0, axis)
 
 
-def detlog_wishart(a, b, detB, n_features):
+def wishart_log_det(a, b, detB, n_features):
     """Expected value of the log of the determinant of a Wishart
 
     The expected value of the logarithm of the determinant of a
@@ -66,65 +66,40 @@ def wishart_logz(v, s, dets, n_features):
 ##############################################################################
 # Variational bound on the log likelihood of each class
-
-def _bound_state_loglik_spherical(X, initial_bound, bound_prec, precs, means):
-    n_components, n_features = means.shape
-    n_samples = X.shape[0]
-    bound = np.empty((n_samples, n_components))
-    bound[:] = bound_prec + initial_bound
-    for k in xrange(n_components):
-        bound[:, k] -= 0.5 * precs[k] * (((X - means[k]) ** 2).sum(axis=-1)
-                                         + n_features)
-    return bound
-
-
-def _bound_state_loglik_diag(X, initial_bound, bound_prec, precs, means):
-    n_components, n_features = means.shape
-    n_samples = X.shape[0]
-    bound = np.empty((n_samples, n_components))
-    bound[:] = bound_prec + initial_bound
-    for k in xrange(n_components):
-        d = X - means[k]
-        d **= 2
-        bound[:, k] -= 0.5 * np.sum(d * precs[k], axis=1)
-    return bound
-
-
-def _bound_state_loglik_tied(X, initial_bound, bound_prec, precs, means):
-    n_components, n_features = means.shape
-    n_samples = X.shape[0]
-    bound = np.empty((n_samples, n_components))
-    bound[:] = bound_prec + initial_bound
-    # Transform the data to be able to apply standard Euclidean distance,
-    # rather than Mahlanobis distance
-    sqrt_cov = linalg.cholesky(precs)
-    means = np.dot(means, sqrt_cov.T)
-    X = np.dot(X, sqrt_cov.T)
-    bound -= 0.5 * euclidean_distances(X, means, squared=True)
-    return bound
+##############################################################################
 
 
-def _bound_state_loglik_full(X, initial_bound, bound_prec, precs, means):
+def _bound_state_log_lik(X, initial_bound, precs, means, covariance_type):
+    """Update the bound with likelihood terms, for standard covariance types"""
     n_components, n_features = means.shape
     n_samples = X.shape[0]
     bound = np.empty((n_samples, n_components))
-    bound[:] = bound_prec + initial_bound
-    for k in xrange(n_components):
-        d = X - means[k]
-        sqrt_cov = linalg.cholesky(precs[k])
-        d = np.dot(d, sqrt_cov.T)
-        d **= 2
-        bound[:, k] -= 0.5 * d.sum(axis=-1)
+    bound[:] = initial_bound
+    if covariance_type == 'diag':
+        for k in xrange(n_components):
+            d = X - means[k]
+            bound[:, k] -= 0.5 * np.sum(d * d * precs[k], axis=1)
+    elif covariance_type == 'spherical':
+        for k in xrange(n_components):
+            bound[:, k] -= 0.5 * precs[k] * (((X - means[k]) ** 2).sum(axis=-1)
+                                             + n_features)
+    elif covariance_type == 'tied':
+        sqrt_cov = linalg.cholesky(precs)
+        means = np.dot(means, sqrt_cov.T)
+        X = np.dot(X, sqrt_cov.T)
+        bound -= 0.5 * euclidean_distances(X, means, squared=True)
+    elif covariance_type == 'full':
+        for k in xrange(n_components):
+            d = X - means[k]
+            # note: cholesky is useless here
+            sqrt_cov = linalg.cholesky(precs[k])
+            d = np.dot(d, sqrt_cov.T)
+            d **= 2
+            bound[:, k] -= 0.5 * d.sum(axis=-1)
+
     return bound
 
 
-_BOUND_STATE_LOGLIK_DICT = dict(
-    spherical=_bound_state_loglik_spherical,
-    diag=_bound_state_loglik_diag,
-    tied=_bound_state_loglik_tied,
-    full=_bound_state_loglik_full)
-
-
 class DPGMM(GMM):
     """Variational Inference for the Infinite Gaussian Mixture Model.
@@ -260,26 +235,23 @@ class DPGMM(GMM):
         if X.ndim == 1:
             X = X[:, np.newaxis]
         z = np.zeros((X.shape[0], self.n_components))
-        sd = digamma(self._gamma.T[1] + self._gamma.T[2])
-        dgamma1 = digamma(self._gamma.T[1]) - sd
+        sd = digamma(self.gamma_.T[1] + self.gamma_.T[2])
+        dgamma1 = digamma(self.gamma_.T[1]) - sd
         dgamma2 = np.zeros(self.n_components)
-        dgamma2[0] = digamma(self._gamma[0, 2]) - digamma(self._gamma[0, 1] +
-                                                          self._gamma[0, 2])
+        dgamma2[0] = digamma(self.gamma_[0, 2]) - digamma(self.gamma_[0, 1] +
+                                                          self.gamma_[0, 2])
         for j in xrange(1, self.n_components):
-            dgamma2[j] = dgamma2[j - 1] + digamma(self._gamma[j - 1, 2])
+            dgamma2[j] = dgamma2[j - 1] + digamma(self.gamma_[j - 1, 2])
             dgamma2[j] -= sd[j - 1]
         dgamma = dgamma1 + dgamma2
         # Free memory and developers cognitive load:
         del dgamma1, dgamma2, sd
-        try:
-            _bound_state_loglik = _BOUND_STATE_LOGLIK_DICT[self.covariance_type]
-        except KeyError:
+        if self.covariance_type not in ['full', 'tied', 'diag', 'spherical']:
             raise NotImplementedError("This ctype is not implemented: %s"
                                       % self.covariance_type)
-
-        p = _bound_state_loglik(X, self._initial_bound,
-                                self._bound_prec, self.precs_, self.means_)
+        p = _bound_state_log_lik(X, self._initial_bound + self.bound_prec_,
+                                 self.precs_, self.means_, self.covariance_type)
         z = p + dgamma
         z = log_normalize(z, axis=-1)
         bound = np.sum(z * p, axis=-1)
@@ -288,11 +260,11 @@ class DPGMM(GMM):
     def _update_concentration(self, z):
         """Update the concentration parameters for each cluster"""
         sz = np.sum(z, axis=0)
-        self._gamma.T[1] = 1. + sz
-        self._gamma.T[2].fill(0)
+        self.gamma_.T[1] = 1. + sz
+        self.gamma_.T[2].fill(0)
         for i in xrange(self.n_components - 2, -1, -1):
-            self._gamma[i, 2] = self._gamma[i + 1, 2] + sz[i]
-        self._gamma.T[2] += self.alpha
+            self.gamma_[i, 2] = self.gamma_[i + 1, 2] + sz[i]
+        self.gamma_.T[2] += self.alpha
 
     def _update_means(self, X, z):
         """Update the variational distributions for the means"""
@@ -317,63 +289,63 @@ class DPGMM(GMM):
         """Update the variational distributions for the precisions"""
         n_features = X.shape[1]
         if self.covariance_type == 'spherical':
-            self._a = 0.5 * n_features * np.sum(z, axis=0)
+            self.dof_ = 0.5 * n_features * np.sum(z, axis=0)
             for k in xrange(self.n_components):
                 # XXX: how to avoid this huge temporary matrix in memory
                 dif = (X - self.means_[k])
-                self._b[k] = 1.
+                self.scale_[k] = 1.
                 d = np.sum(dif * dif, axis=1)
-                self._b[k] += 0.5 * np.sum(z.T[k] * (d + n_features))
-                self._bound_prec[k] = (
+                self.scale_[k] += 0.5 * np.sum(z.T[k] * (d + n_features))
+                self.bound_prec_[k] = (
                     0.5 * n_features * (
-                        digamma(self._a[k]) - np.log(self._b[k])))
-            self.precs_ = self._a / self._b
+                        digamma(self.dof_[k]) - np.log(self.scale_[k])))
+            self.precs_ = self.dof_ / self.scale_
         elif self.covariance_type == 'diag':
             for k in xrange(self.n_components):
-                self._a[k].fill(1. + 0.5 * np.sum(z.T[k], axis=0))
+                self.dof_[k].fill(1. + 0.5 * np.sum(z.T[k], axis=0))
                 ddif = (X - self.means_[k])  # see comment above
                 for d in xrange(n_features):
-                    self._b[k, d] = 1.
+                    self.scale_[k, d] = 1.
                     dd = ddif.T[d] * ddif.T[d]
-                    self._b[k, d] += 0.5 * np.sum(z.T[k] * (dd + 1))
-                self.precs_[k] = self._a[k] / self._b[k]
-                self._bound_prec[k] = 0.5 * np.sum(digamma(self._a[k])
-                                                   - np.log(self._b[k]))
-                self._bound_prec[k] -= 0.5 * np.sum(self.precs_[k])
+                    self.scale_[k, d] += 0.5 * np.sum(z.T[k] * (dd + 1))
+                self.precs_[k] = self.dof_[k] / self.scale_[k]
+                self.bound_prec_[k] = 0.5 * np.sum(digamma(self.dof_[k])
+                                                   - np.log(self.scale_[k]))
+                self.bound_prec_[k] -= 0.5 * np.sum(self.precs_[k])
         elif self.covariance_type == 'tied':
-            self._a = 2 + X.shape[0] + n_features
-            self._B = (X.shape[0] + 1) * np.identity(n_features)
+            self.dof_ = 2 + X.shape[0] + n_features
+            self.scale_ = (X.shape[0] + 1) * np.identity(n_features)
             for i in xrange(X.shape[0]):
                 for k in xrange(self.n_components):
                     dif = X[i] - self.means_[k]
-                    self._B += z[i, k] * np.dot(dif.reshape((-1, 1)),
-                                                dif.reshape((1, -1)))
-            self._B = linalg.pinv(self._B)
-            self.precs_ = self._a * self._B
-            self._detB = linalg.det(self._B)
-            self._bound_prec = 0.5 * detlog_wishart(
-                self._a, self._B, self._detB, n_features)
-            self._bound_prec -= 0.5 * self._a * np.trace(self._B)
+                    self.scale_ += z[i, k] * np.dot(dif.reshape((-1, 1)),
+                                                    dif.reshape((1, -1)))
+            self.scale_ = linalg.pinv(self.scale_)
+            self.precs_ = self.dof_ * self.scale_
+            self.det_scale_ = linalg.det(self.scale_)
+            self.bound_prec_ = 0.5 * wishart_log_det(
+                self.dof_, self.scale_, self.det_scale_, n_features)
+            self.bound_prec_ -= 0.5 * self.dof_ * np.trace(self.scale_)
         elif self.covariance_type == 'full':
             for k in xrange(self.n_components):
                 T = np.sum(z.T[k])
-                self._a[k] = 2 + T + n_features
-                self._B[k] = (T + 1) * np.identity(n_features)
+                self.dof_[k] = 2 + T + n_features
+                self.scale_[k] = (T + 1) * np.identity(n_features)
                 for i in xrange(X.shape[0]):
                     dif = X[i] - self.means_[k]
-                    self._B[k] += z[i, k] * np.dot(dif.reshape((-1, 1)),
-                                                   dif.reshape((1, -1)))
-                self._B[k] = linalg.pinv(self._B[k])
-                self.precs_[k] = self._a[k] * self._B[k]
-                self._detB[k] = linalg.det(self._B[k])
-                self._bound_prec[k] = 0.5 * detlog_wishart(self._a[k],
-                                                           self._B[k],
-                                                           self._detB[k],
-                                                           n_features)
-                self._bound_prec[k] -= 0.5 * self._a[k] * np.trace(self._B[k])
+                    self.scale_[k] += z[i, k] * np.dot(dif.reshape((-1, 1)),
+                                                       dif.reshape((1, -1)))
+                self.scale_[k] = linalg.pinv(self.scale_[k])
+                self.precs_[k] = self.dof_[k] * self.scale_[k]
+                self.det_scale_[k] = linalg.det(self.scale_[k])
+                self.bound_prec_[k] = 0.5 * wishart_log_det(self.dof_[k],
+                                                            self.scale_[k],
+                                                            self.det_scale_[k],
+                                                            n_features)
+                self.bound_prec_[k] -= 0.5 * self.dof_[k] * np.trace(self.scale_[k])
 
     def _monitor(self, X, z, n, end=False):
         """Monitor the lower bound during iteration
@@ -385,7 +357,7 @@ class DPGMM(GMM):
         if self.verbose:
             print "Bound after updating %8s: %f" % (n, self.lower_bound(X, z))
             if end == True:
-                print "Cluster proportions:", self._gamma.T[1]
+                print "Cluster proportions:", self.gamma_.T[1]
                 print "covariance_type:", self._covariance_type
 
     def _do_mstep(self, X, z, params):
@@ -404,24 +376,24 @@
     def _initialize_gamma(self):
         "Initializes the concentration parameters"
-        self._gamma = self.alpha * np.ones((self.n_components, 3))
+        self.gamma_ = self.alpha * np.ones((self.n_components, 3))
 
     def _bound_concentration(self):
         "The variational lower bound for the concentration parameter."
         logprior = 0.
         for k in xrange(self.n_components):
             logprior = gammaln(self.alpha)
-            logprior += (self.alpha - 1) * (digamma(self._gamma[k, 2]) -
-                                            digamma(self._gamma[k, 1] +
-                                                    self._gamma[k, 2]))
-            logprior += -gammaln(self._gamma[k, 1] + self._gamma[k, 2])
-            logprior += gammaln(self._gamma[k, 1]) + gammaln(self._gamma[k, 2])
-            logprior -= (self._gamma[k, 1] - 1) * (digamma(self._gamma[k, 1]) -
-                                                   digamma(self._gamma[k, 1] +
-                                                           self._gamma[k, 2]))
-            logprior -= (self._gamma[k, 2] - 1) * (digamma(self._gamma[k, 2]) -
-                                                   digamma(self._gamma[k, 1] +
-                                                           self._gamma[k, 2]))
+            logprior += (self.alpha - 1) * (digamma(self.gamma_[k, 2]) -
+                                            digamma(self.gamma_[k, 1] +
+                                                    self.gamma_[k, 2]))
+            logprior += -gammaln(self.gamma_[k, 1] + self.gamma_[k, 2])
+            logprior += gammaln(self.gamma_[k, 1]) + gammaln(self.gamma_[k, 2])
+            logprior -= (self.gamma_[k, 1] - 1) * (digamma(self.gamma_[k, 1]) -
+                                                   digamma(self.gamma_[k, 1] +
+                                                           self.gamma_[k, 2]))
+            logprior -= (self.gamma_[k, 2] - 1) * (digamma(self.gamma_[k, 2]) -
+                                                   digamma(self.gamma_[k, 1] +
+                                                           self.gamma_[k, 2]))
         return logprior
 
     def _bound_means(self):
@@ -437,7 +409,7 @@
         logprior -= wishart_logz(n_features,
                                  np.identity(n_features),
                                  1, n_features)
-        logprior += 0.5 * (a - 1) * detlog_wishart(a, B, detB, n_features)
+        logprior += 0.5 * (a - 1) * wishart_log_det(a, B, detB, n_features)
         logprior += 0.5 * a * np.trace(B)
         return logprior
 
@@ -445,29 +417,29 @@
         logprior = 0.
         if self.covariance_type == 'spherical':
             for k in xrange(self.n_components):
-                logprior += gammaln(self._a[k])
-                logprior -= (self._a[k] - 1) * digamma(max(0.5, self._a[k]))
-                logprior += - np.log(self._b[k]) + self._a[k] - self.precs_[k]
+                logprior += gammaln(self.dof_[k])
+                logprior -= (self.dof_[k] - 1) * digamma(max(0.5, self.dof_[k]))
+                logprior += - np.log(self.scale_[k]) + self.dof_[k] - self.precs_[k]
         elif self.covariance_type == 'diag':
             for k in xrange(self.n_components):
                 for d in xrange(self.means.shape[1]):
-                    logprior += gammaln(self._a[k, d])
-                    logprior -= (self._a[k, d] - 1) * digamma(self._a[k, d])
-                    logprior -= np.log(self._b[k, d])
-                    logprior += self._a[k, d] - self.precs_[k, d]
+                    logprior += gammaln(self.dof_[k, d])
+                    logprior -= (self.dof_[k, d] - 1) * digamma(self.dof_[k, d])
+                    logprior -= np.log(self.scale_[k, d])
+                    logprior += self.dof_[k, d] - self.precs_[k, d]
         elif self.covariance_type == 'tied':
-            logprior += self._bound_wishart(self._a, self._B, self._detB)
+            logprior += self._bound_wishart(self.dof_, self.scale_, self.det_scale_)
         elif self.covariance_type == 'full':
             for k in xrange(self.n_components):
-                logprior += self._bound_wishart(self._a[k],
-                                                self._B[k],
-                                                self._detB[k])
+                logprior += self._bound_wishart(self.dof_[k],
+                                                self.scale_[k],
+                                                self.det_scale_[k])
         return logprior
 
     def _bound_proportions(self, z):
-        dg12 = digamma(self._gamma.T[1] + self._gamma.T[2])
-        dg1 = digamma(self._gamma.T[1]) - dg12
-        dg2 = digamma(self._gamma.T[2]) - dg12
+        dg12 = digamma(self.gamma_.T[1] + self.gamma_.T[2])
+        dg1 = digamma(self.gamma_.T[1]) - dg12
+        dg2 = digamma(self.gamma_.T[2]) - dg12
         cz = np.cumsum(z[:, ::-1], axis=-1)[:, -2::-1]
         logprior = np.sum(cz * dg2[:-1]) + np.sum(z * dg1)
@@ -484,17 +456,16 @@ class DPGMM(GMM):
         return logprior
 
     def lower_bound(self, X, z):
-        try:
-            _bound_state_loglik = _BOUND_STATE_LOGLIK_DICT[self.covariance_type]
-        except KeyError:
+        if self.covariance_type not in ['full', 'tied', 'diag', 'spherical']:
             raise NotImplementedError("This ctype is not implemented: %s"
                                       % self.covariance_type)
+
+        X = np.asarray(X)
         if X.ndim == 1:
             X = X[:, np.newaxis]
-        c = np.sum(z * _bound_state_loglik(
-            X, self._initial_bound, self._bound_prec, self.precs_,
-            self.means_))
+        c = np.sum(z * _bound_state_log_lik(
+            X, self._initial_bound + self.bound_prec_,
+            self.precs_, self.means_, self.covariance_type))
 
         return c + self._logprior(z)
 
@@ -554,44 +525,47 @@ class DPGMM(GMM):
         if 'c' in init_params or not hasattr(self, 'covars'):
             if self.covariance_type == 'spherical':
-                self._a = np.ones(self.n_components)
-                self._b = np.ones(self.n_components)
+                self.dof_ = np.ones(self.n_components)
+                self.scale_ = np.ones(self.n_components)
                 self.precs_ = np.ones(self.n_components)
-                self._bound_prec = (0.5 * n_features *
-                                    (digamma(self._a) -
-                                     np.log(self._b)))
+                self.bound_prec_ = (
+                    0.5 * n_features * (digamma(self.dof_)
+                                        - np.log(self.scale_)))
             elif self.covariance_type == 'diag':
-                self._a = 1 + 0.5 * n_features
-                self._a *= np.ones((self.n_components, n_features))
-                self._b = np.ones((self.n_components, n_features))
+                self.dof_ = 1 + 0.5 * n_features
+                self.dof_ *= np.ones((self.n_components, n_features))
+                self.scale_ = np.ones((self.n_components, n_features))
                 self.precs_ = np.ones((self.n_components, n_features))
-                self._bound_prec = np.zeros(self.n_components)
+                self.bound_prec_ = np.zeros(self.n_components)
                 for k in xrange(self.n_components):
-                    self._bound_prec[k] = 0.5 * np.sum(digamma(self._a[k])
-                                                       - np.log(self._b[k]))
-                    self._bound_prec[k] -= 0.5 * np.sum(self.precs_[k])
+                    self.bound_prec_[k] = (
+                        0.5 * np.sum(digamma(self.dof_[k])
+                                     - np.log(self.scale_[k])))
+                    self.bound_prec_[k] -= 0.5 * np.sum(self.precs_[k])
             elif self.covariance_type == 'tied':
-                self._a = 1.
-                self._B = np.identity(n_features)
+                self.dof_ = 1.
+                self.scale_ = np.identity(n_features)
                 self.precs_ = np.identity(n_features)
-                self._detB = 1.
-                self._bound_prec = 0.5 * detlog_wishart(
-                    self._a, self._B, self._detB, n_features)
-                self._bound_prec -= 0.5 * self._a * np.trace(self._B)
+                self.det_scale_ = 1.
+                self.bound_prec_ = 0.5 * wishart_log_det(
+                    self.dof_, self.scale_, self.det_scale_, n_features)
+                self.bound_prec_ -= 0.5 * self.dof_ * np.trace(self.scale_)
             elif self.covariance_type == 'full':
-                self._a = (1 + self.n_components + X.shape[0])
-                self._a *= np.ones(self.n_components)
-                self._B = [2 * np.identity(n_features)
+                self.dof_ = (1 + self.n_components + X.shape[0])
+                self.dof_ *= np.ones(self.n_components)
+                self.scale_ = [2 * np.identity(n_features)
                            for i in xrange(self.n_components)]
                 self.precs_ = [np.identity(n_features)
                                for i in xrange(self.n_components)]
-                self._detB = np.ones(self.n_components)
-                self._bound_prec = np.zeros(self.n_components)
+                self.det_scale_ = np.ones(self.n_components)
+                self.bound_prec_ = np.zeros(self.n_components)
                 for k in xrange(self.n_components):
-                    self._bound_prec[k] = detlog_wishart(
-                        self._a[k], self._B[k], self._detB[k], n_features)
-                    self._bound_prec[k] -= self._a[k] * np.trace(self._B[k])
-                    self._bound_prec[k] *= 0.5
+                    self.bound_prec_[k] = wishart_log_det(
+                        self.dof_[k], self.scale_[k], self.det_scale_[k],
+                        n_features)
+                    self.bound_prec_[k] -= (self.dof_[k] *
+                                            np.trace(self.scale_[k]))
+                    self.bound_prec_[k] *= 0.5
 
         logprob = []
         # reset self.converged_ to False
@@ -677,8 +651,8 @@ class VBGMM(DPGMM):
                  random_state=None, thresh=1e-2, verbose=False,
                  min_covar=None):
         super(VBGMM, self).__init__(
-            n_components, covariance_type, random_state=random_state, thresh=thresh,
-            verbose=verbose, min_covar=min_covar)
+            n_components, covariance_type, random_state=random_state,
+            thresh=thresh, verbose=verbose, min_covar=min_covar)
         self.alpha = float(alpha) / n_components
 
     def eval(self, X):
@@ -711,15 +685,15 @@ class VBGMM(DPGMM):
         z = np.zeros((X.shape[0], self.n_components))
         p = np.zeros(self.n_components)
         bound = np.zeros(X.shape[0])
-        dg = digamma(self._gamma) - digamma(np.sum(self._gamma))
-        try:
-            _bound_state_loglik = _BOUND_STATE_LOGLIK_DICT[self.covariance_type]
-        except KeyError:
+        dg = digamma(self.gamma_) - digamma(np.sum(self.gamma_))
+
+        if self.covariance_type not in ['full', 'tied', 'diag', 'spherical']:
             raise NotImplementedError("This ctype is not implemented: %s"
                                       % self.covariance_type)
-
-        p = _bound_state_loglik(X, self._initial_bound,
-                                self._bound_prec, self.precs_, self.means_)
+        p = _bound_state_log_lik(
+            X, self._initial_bound + self.bound_prec_,
+            self.precs_, self.means_, self.covariance_type)
+
         z = p + dg
         z = log_normalize(z, axis=-1)
         bound = np.sum(z * p, axis=-1)
@@ -727,15 +701,15 @@ class VBGMM(DPGMM):
     def _update_concentration(self, z):
         for i in xrange(self.n_components):
-            self._gamma[i] = self.alpha + np.sum(z.T[i])
+            self.gamma_[i] = self.alpha + np.sum(z.T[i])
 
     def _initialize_gamma(self):
-        self._gamma = self.alpha * np.ones(self.n_components)
+        self.gamma_ = self.alpha * np.ones(self.n_components)
 
     def _bound_proportions(self, z):
         logprior = 0.
-        dg = digamma(self._gamma)
-        dg -= digamma(np.sum(self._gamma))
+        dg = digamma(self.gamma_)
+        dg -= digamma(np.sum(self.gamma_))
         logprior += np.sum(dg.reshape((-1, 1)) * z.T)
         z_non_zeros = z[z > np.finfo(np.float32).eps]
         logprior -= np.sum(z_non_zeros * np.log(z_non_zeros))
@@ -743,12 +717,12 @@ class VBGMM(DPGMM):
     def _bound_concentration(self):
         logprior = 0.
-        logprior = gammaln(np.sum(self._gamma)) - gammaln(self.n_components
+        logprior = gammaln(np.sum(self.gamma_)) - gammaln(self.n_components
                                                           * self.alpha)
-        logprior -= np.sum(gammaln(self._gamma) - gammaln(self.alpha))
-        sg = digamma(np.sum(self._gamma))
-        logprior += np.sum((self._gamma - self.alpha)
-                           * (digamma(self._gamma) - sg))
+        logprior -= np.sum(gammaln(self.gamma_) - gammaln(self.alpha))
+        sg = digamma(np.sum(self.gamma_))
+        logprior += np.sum((self.gamma_ - self.alpha)
+                           * (digamma(self.gamma_) - sg))
         return logprior
 
     def _monitor(self, X, z, n, end=False):
@@ -761,5 +735,5 @@ class VBGMM(DPGMM):
         if self.verbose:
             print "Bound after updating %8s: %f" % (n, self.lower_bound(X, z))
             if end == True:
-                print "Cluster proportions:", self._gamma
+                print "Cluster proportions:", self.gamma_
                 print "covariance_type:", self._covariance_type
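
Usage note (not part of the patch): a minimal sketch of how the renamed public
attributes would be read after fitting, assuming the sklearn.mixture.DPGMM API
of this era as used in plot_gmm_sin.py above; the toy data below is made up
purely for illustration.

    import numpy as np
    from sklearn import mixture

    # Two well-separated Gaussian blobs, illustrative only.
    np.random.seed(0)
    X = np.vstack([np.random.randn(100, 2), 10 + np.random.randn(100, 2)])

    # Same constructor call as in the example script above.
    clf = mixture.DPGMM(n_components=10, covariance_type='full', alpha=0.01)
    clf.fit(X)

    # The variational parameters now follow the trailing-underscore convention
    # instead of the old private names:
    print clf.gamma_.shape       # was clf._gamma: stick-breaking parameters
    print clf.dof_.shape         # was clf._a: Wishart degrees of freedom
    print clf.bound_prec_.shape  # was clf._bound_prec: per-component bound term
    print len(clf.precs_)        # expected precisions, one matrix per component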