diff --git a/examples/mixture/plot_gmm_sin.py b/examples/mixture/plot_gmm_sin.py
index 501736a46da22135976e0f29e7b66fee374c2b4d..31773855f261079ef9bf24432afb84af6cd84a2c 100644
--- a/examples/mixture/plot_gmm_sin.py
+++ b/examples/mixture/plot_gmm_sin.py
@@ -41,9 +41,9 @@
 color_iter = itertools.cycle(['r', 'g', 'b', 'c', 'm'])
 
 for i, (clf, title) in enumerate([
-        (mixture.GMM(n_components=10, covariance_type='diag'), \
+        (mixture.GMM(n_components=10, covariance_type='full'), \
            "Expectation-maximization"),
-        (mixture.DPGMM(n_components=10, covariance_type='diag', alpha=0.01),
+        (mixture.DPGMM(n_components=10, covariance_type='full', alpha=0.01),
            "Dirichlet Process,alpha=0.01"),
         (mixture.DPGMM(n_components=10, covariance_type='diag', alpha=100.),
            "Dirichlet Process,alpha=100.")
diff --git a/sklearn/mixture/dpgmm.py b/sklearn/mixture/dpgmm.py
index e5066a6e84af1ee99a48b3b956d998049bc49b7d..1bc3a2070a65789f6da78ef8910a44a499c2f42d 100644
--- a/sklearn/mixture/dpgmm.py
+++ b/sklearn/mixture/dpgmm.py
@@ -43,7 +43,7 @@ def log_normalize(v, axis=0):
     return np.swapaxes(v, 0, axis)
 
 
-def detlog_wishart(a, b, detB, n_features):
+def wishart_log_det(a, b, detB, n_features):
     """Expected value of the log of the determinant of a Wishart
 
     The expected value of the logarithm of the determinant of a
@@ -66,65 +66,40 @@ def wishart_logz(v, s, dets, n_features):
 ##############################################################################
 # Variational bound on the log likelihood of each class
-
-def _bound_state_loglik_spherical(X, initial_bound, bound_prec, precs, means):
-    n_components, n_features = means.shape
-    n_samples = X.shape[0]
-    bound = np.empty((n_samples, n_components))
-    bound[:] = bound_prec + initial_bound
-    for k in xrange(n_components):
-        bound[:, k] -= 0.5 * precs[k] * (((X - means[k]) ** 2).sum(axis=-1)
-                                         + n_features)
-    return bound
-
-
-def _bound_state_loglik_diag(X, initial_bound, bound_prec, precs, means):
-    n_components, n_features = means.shape
-    n_samples = X.shape[0]
-    bound = np.empty((n_samples, n_components))
-    bound[:] = bound_prec + initial_bound
-    for k in xrange(n_components):
-        d = X - means[k]
-        d **= 2
-        bound[:, k] -= 0.5 * np.sum(d * precs[k], axis=1)
-    return bound
-
-
-def _bound_state_loglik_tied(X, initial_bound, bound_prec, precs, means):
-    n_components, n_features = means.shape
-    n_samples = X.shape[0]
-    bound = np.empty((n_samples, n_components))
-    bound[:] = bound_prec + initial_bound
-    # Transform the data to be able to apply standard Euclidean distance,
-    # rather than Mahlanobis distance
-    sqrt_cov = linalg.cholesky(precs)
-    means = np.dot(means, sqrt_cov.T)
-    X = np.dot(X, sqrt_cov.T)
-    bound -= 0.5 * euclidean_distances(X, means, squared=True)
-    return bound
+##############################################################################
 
 
-def _bound_state_loglik_full(X, initial_bound, bound_prec, precs, means):
+def _bound_state_log_lik(X, initial_bound, precs, means, covariance_type):
+    """Update the bound with likelihood terms, for standard covariance types"""
     n_components, n_features = means.shape
     n_samples = X.shape[0]
     bound = np.empty((n_samples, n_components))
-    bound[:] = bound_prec + initial_bound
-    for k in xrange(n_components):
-        d = X - means[k]
-        sqrt_cov = linalg.cholesky(precs[k])
-        d = np.dot(d, sqrt_cov.T)
-        d **= 2
-        bound[:, k] -= 0.5 * d.sum(axis=-1)
+    bound[:] = initial_bound
+    if covariance_type == 'diag':
+        for k in xrange(n_components):
+            d = X - means[k]
+            bound[:, k] -= 0.5 * np.sum(d * d * precs[k], axis=1)
+    elif covariance_type == 'spherical':
+        for k in xrange(n_components):
+            bound[:, k] -= 0.5 * precs[k] * (((X - means[k]) ** 2).sum(axis=-1)
+                                             + n_features)
+    elif covariance_type == 'tied':
+        sqrt_cov = linalg.cholesky(precs)
+        means = np.dot(means, sqrt_cov.T)
+        X = np.dot(X, sqrt_cov.T)
+        bound -= 0.5 * euclidean_distances(X, means, squared=True)
+    elif covariance_type == 'full':
+        for k in xrange(n_components):
+            d = X - means[k]
+            # note: cholesky is useless here
+            sqrt_cov = linalg.cholesky(precs[k])
+            d = np.dot(d, sqrt_cov.T)
+            d **= 2
+            bound[:, k] -= 0.5 * d.sum(axis=-1)
+
     return bound
 
 
-_BOUND_STATE_LOGLIK_DICT = dict(
-    spherical=_bound_state_loglik_spherical,
-    diag=_bound_state_loglik_diag,
-    tied=_bound_state_loglik_tied,
-    full=_bound_state_loglik_full)
-
-
 class DPGMM(GMM):
     """Variational Inference for the Infinite Gaussian Mixture Model.
@@ -260,26 +235,23 @@ class DPGMM(GMM):
         if X.ndim == 1:
             X = X[:, np.newaxis]
         z = np.zeros((X.shape[0], self.n_components))
-        sd = digamma(self._gamma.T[1] + self._gamma.T[2])
-        dgamma1 = digamma(self._gamma.T[1]) - sd
+        sd = digamma(self.gamma_.T[1] + self.gamma_.T[2])
+        dgamma1 = digamma(self.gamma_.T[1]) - sd
         dgamma2 = np.zeros(self.n_components)
-        dgamma2[0] = digamma(self._gamma[0, 2]) - digamma(self._gamma[0, 1] +
-                                                          self._gamma[0, 2])
+        dgamma2[0] = digamma(self.gamma_[0, 2]) - digamma(self.gamma_[0, 1] +
+                                                          self.gamma_[0, 2])
         for j in xrange(1, self.n_components):
-            dgamma2[j] = dgamma2[j - 1] + digamma(self._gamma[j - 1, 2])
+            dgamma2[j] = dgamma2[j - 1] + digamma(self.gamma_[j - 1, 2])
             dgamma2[j] -= sd[j - 1]
         dgamma = dgamma1 + dgamma2
         # Free memory and developers cognitive load:
         del dgamma1, dgamma2, sd
-        try:
-            _bound_state_loglik = _BOUND_STATE_LOGLIK_DICT[self.covariance_type]
-        except KeyError:
+        if self.covariance_type not in ['full', 'tied', 'diag', 'spherical']:
             raise NotImplementedError("This ctype is not implemented: %s"
                                       % self.covariance_type)
-
-        p = _bound_state_loglik(X, self._initial_bound,
-                                self._bound_prec, self.precs_, self.means_)
+        p = _bound_state_log_lik(X, self._initial_bound + self.bound_prec_,
+                                 self.precs_, self.means_, self.covariance_type)
         z = p + dgamma
         z = log_normalize(z, axis=-1)
         bound = np.sum(z * p, axis=-1)
@@ -288,11 +260,11 @@ class DPGMM(GMM):
     def _update_concentration(self, z):
         """Update the concentration parameters for each cluster"""
         sz = np.sum(z, axis=0)
-        self._gamma.T[1] = 1. + sz
-        self._gamma.T[2].fill(0)
+        self.gamma_.T[1] = 1. + sz
+        self.gamma_.T[2].fill(0)
         for i in xrange(self.n_components - 2, -1, -1):
-            self._gamma[i, 2] = self._gamma[i + 1, 2] + sz[i]
-        self._gamma.T[2] += self.alpha
+            self.gamma_[i, 2] = self.gamma_[i + 1, 2] + sz[i]
+        self.gamma_.T[2] += self.alpha
 
     def _update_means(self, X, z):
         """Update the variational distributions for the means"""
@@ -317,63 +289,63 @@ class DPGMM(GMM):
         """Update the variational distributions for the precisions"""
         n_features = X.shape[1]
         if self.covariance_type == 'spherical':
-            self._a = 0.5 * n_features * np.sum(z, axis=0)
+            self.dof_ = 0.5 * n_features * np.sum(z, axis=0)
             for k in xrange(self.n_components):
                 # XXX: how to avoid this huge temporary matrix in memory
                 dif = (X - self.means_[k])
-                self._b[k] = 1.
+                self.scale_[k] = 1.
                 d = np.sum(dif * dif, axis=1)
-                self._b[k] += 0.5 * np.sum(z.T[k] * (d + n_features))
-                self._bound_prec[k] = (
+                self.scale_[k] += 0.5 * np.sum(z.T[k] * (d + n_features))
+                self.bound_prec_[k] = (
                     0.5 * n_features * (
-                        digamma(self._a[k]) - np.log(self._b[k])))
-            self.precs_ = self._a / self._b
+                        digamma(self.dof_[k]) - np.log(self.scale_[k])))
+            self.precs_ = self.dof_ / self.scale_
         elif self.covariance_type == 'diag':
             for k in xrange(self.n_components):
-                self._a[k].fill(1. + 0.5 * np.sum(z.T[k], axis=0))
+                self.dof_[k].fill(1. + 0.5 * np.sum(z.T[k], axis=0))
                 ddif = (X - self.means_[k])  # see comment above
                 for d in xrange(n_features):
-                    self._b[k, d] = 1.
+                    self.scale_[k, d] = 1.
                     dd = ddif.T[d] * ddif.T[d]
-                    self._b[k, d] += 0.5 * np.sum(z.T[k] * (dd + 1))
-                self.precs_[k] = self._a[k] / self._b[k]
-                self._bound_prec[k] = 0.5 * np.sum(digamma(self._a[k])
-                                                   - np.log(self._b[k]))
-                self._bound_prec[k] -= 0.5 * np.sum(self.precs_[k])
+                    self.scale_[k, d] += 0.5 * np.sum(z.T[k] * (dd + 1))
+                self.precs_[k] = self.dof_[k] / self.scale_[k]
+                self.bound_prec_[k] = 0.5 * np.sum(digamma(self.dof_[k])
+                                                   - np.log(self.scale_[k]))
+                self.bound_prec_[k] -= 0.5 * np.sum(self.precs_[k])
         elif self.covariance_type == 'tied':
-            self._a = 2 + X.shape[0] + n_features
-            self._B = (X.shape[0] + 1) * np.identity(n_features)
+            self.dof_ = 2 + X.shape[0] + n_features
+            self.scale_ = (X.shape[0] + 1) * np.identity(n_features)
             for i in xrange(X.shape[0]):
                 for k in xrange(self.n_components):
                     dif = X[i] - self.means_[k]
-                    self._B += z[i, k] * np.dot(dif.reshape((-1, 1)),
-                                                dif.reshape((1, -1)))
-            self._B = linalg.pinv(self._B)
-            self.precs_ = self._a * self._B
-            self._detB = linalg.det(self._B)
-            self._bound_prec = 0.5 * detlog_wishart(
-                self._a, self._B, self._detB, n_features)
-            self._bound_prec -= 0.5 * self._a * np.trace(self._B)
+                    self.scale_ += z[i, k] * np.dot(dif.reshape((-1, 1)),
+                                                    dif.reshape((1, -1)))
+            self.scale_ = linalg.pinv(self.scale_)
+            self.precs_ = self.dof_ * self.scale_
+            self.det_scale_ = linalg.det(self.scale_)
+            self.bound_prec_ = 0.5 * wishart_log_det(
+                self.dof_, self.scale_, self.det_scale_, n_features)
+            self.bound_prec_ -= 0.5 * self.dof_ * np.trace(self.scale_)
         elif self.covariance_type == 'full':
             for k in xrange(self.n_components):
                 T = np.sum(z.T[k])
-                self._a[k] = 2 + T + n_features
-                self._B[k] = (T + 1) * np.identity(n_features)
+                self.dof_[k] = 2 + T + n_features
+                self.scale_[k] = (T + 1) * np.identity(n_features)
                 for i in xrange(X.shape[0]):
                     dif = X[i] - self.means_[k]
-                    self._B[k] += z[i, k] * np.dot(dif.reshape((-1, 1)),
-                                                   dif.reshape((1, -1)))
-                self._B[k] = linalg.pinv(self._B[k])
-                self.precs_[k] = self._a[k] * self._B[k]
-                self._detB[k] = linalg.det(self._B[k])
-                self._bound_prec[k] = 0.5 * detlog_wishart(self._a[k],
-                                                           self._B[k],
-                                                           self._detB[k],
-                                                           n_features)
-                self._bound_prec[k] -= 0.5 * self._a[k] * np.trace(self._B[k])
+                    self.scale_[k] += z[i, k] * np.dot(dif.reshape((-1, 1)),
+                                                       dif.reshape((1, -1)))
+                self.scale_[k] = linalg.pinv(self.scale_[k])
+                self.precs_[k] = self.dof_[k] * self.scale_[k]
+                self.det_scale_[k] = linalg.det(self.scale_[k])
+                self.bound_prec_[k] = 0.5 * wishart_log_det(self.dof_[k],
+                                                            self.scale_[k],
+                                                            self.det_scale_[k],
+                                                            n_features)
+                self.bound_prec_[k] -= 0.5 * self.dof_[k] * np.trace(self.scale_[k])
 
     def _monitor(self, X, z, n, end=False):
         """Monitor the lower bound during iteration
@@ -385,7 +357,7 @@ class DPGMM(GMM):
         if self.verbose:
             print "Bound after updating %8s: %f" % (n, self.lower_bound(X, z))
             if end == True:
-                print "Cluster proportions:", self._gamma.T[1]
+                print "Cluster proportions:", self.gamma_.T[1]
                 print "covariance_type:", self._covariance_type
 
     def _do_mstep(self, X, z, params):
@@ -404,24 +376,24 @@
     def _initialize_gamma(self):
         "Initializes the concentration parameters"
-        self._gamma = self.alpha * np.ones((self.n_components, 3))
+        self.gamma_ = self.alpha * np.ones((self.n_components, 3))
 
     def _bound_concentration(self):
         "The variational lower bound for the concentration parameter."
         logprior = 0.
         for k in xrange(self.n_components):
             logprior = gammaln(self.alpha)
-            logprior += (self.alpha - 1) * (digamma(self._gamma[k, 2]) -
-                                            digamma(self._gamma[k, 1] +
-                                                    self._gamma[k, 2]))
-            logprior += -gammaln(self._gamma[k, 1] + self._gamma[k, 2])
-            logprior += gammaln(self._gamma[k, 1]) + gammaln(self._gamma[k, 2])
-            logprior -= (self._gamma[k, 1] - 1) * (digamma(self._gamma[k, 1]) -
-                                                   digamma(self._gamma[k, 1] +
-                                                           self._gamma[k, 2]))
-            logprior -= (self._gamma[k, 2] - 1) * (digamma(self._gamma[k, 2]) -
-                                                   digamma(self._gamma[k, 1] +
-                                                           self._gamma[k, 2]))
+            logprior += (self.alpha - 1) * (digamma(self.gamma_[k, 2]) -
+                                            digamma(self.gamma_[k, 1] +
+                                                    self.gamma_[k, 2]))
+            logprior += -gammaln(self.gamma_[k, 1] + self.gamma_[k, 2])
+            logprior += gammaln(self.gamma_[k, 1]) + gammaln(self.gamma_[k, 2])
+            logprior -= (self.gamma_[k, 1] - 1) * (digamma(self.gamma_[k, 1]) -
+                                                   digamma(self.gamma_[k, 1] +
+                                                           self.gamma_[k, 2]))
+            logprior -= (self.gamma_[k, 2] - 1) * (digamma(self.gamma_[k, 2]) -
+                                                   digamma(self.gamma_[k, 1] +
+                                                           self.gamma_[k, 2]))
         return logprior
 
     def _bound_means(self):
@@ -437,7 +409,7 @@
         logprior -= wishart_logz(n_features,
                                  np.identity(n_features),
                                  1, n_features)
-        logprior += 0.5 * (a - 1) * detlog_wishart(a, B, detB, n_features)
+        logprior += 0.5 * (a - 1) * wishart_log_det(a, B, detB, n_features)
         logprior += 0.5 * a * np.trace(B)
         return logprior
 
@@ -445,29 +417,29 @@
         logprior = 0.
         if self.covariance_type == 'spherical':
             for k in xrange(self.n_components):
-                logprior += gammaln(self._a[k])
-                logprior -= (self._a[k] - 1) * digamma(max(0.5, self._a[k]))
-                logprior += - np.log(self._b[k]) + self._a[k] - self.precs_[k]
+                logprior += gammaln(self.dof_[k])
+                logprior -= (self.dof_[k] - 1) * digamma(max(0.5, self.dof_[k]))
+                logprior += - np.log(self.scale_[k]) + self.dof_[k] - self.precs_[k]
         elif self.covariance_type == 'diag':
             for k in xrange(self.n_components):
                 for d in xrange(self.means.shape[1]):
-                    logprior += gammaln(self._a[k, d])
-                    logprior -= (self._a[k, d] - 1) * digamma(self._a[k, d])
-                    logprior -= np.log(self._b[k, d])
-                    logprior += self._a[k, d] - self.precs_[k, d]
+                    logprior += gammaln(self.dof_[k, d])
+                    logprior -= (self.dof_[k, d] - 1) * digamma(self.dof_[k, d])
+                    logprior -= np.log(self.scale_[k, d])
+                    logprior += self.dof_[k, d] - self.precs_[k, d]
         elif self.covariance_type == 'tied':
-            logprior += self._bound_wishart(self._a, self._B, self._detB)
+            logprior += self._bound_wishart(self.dof_, self.scale_, self.det_scale_)
         elif self.covariance_type == 'full':
             for k in xrange(self.n_components):
-                logprior += self._bound_wishart(self._a[k],
-                                                self._B[k],
-                                                self._detB[k])
+                logprior += self._bound_wishart(self.dof_[k],
+                                                self.scale_[k],
+                                                self.det_scale_[k])
         return logprior
 
     def _bound_proportions(self, z):
-        dg12 = digamma(self._gamma.T[1] + self._gamma.T[2])
-        dg1 = digamma(self._gamma.T[1]) - dg12
-        dg2 = digamma(self._gamma.T[2]) - dg12
+        dg12 = digamma(self.gamma_.T[1] + self.gamma_.T[2])
+        dg1 = digamma(self.gamma_.T[1]) - dg12
+        dg2 = digamma(self.gamma_.T[2]) - dg12
         cz = np.cumsum(z[:, ::-1], axis=-1)[:, -2::-1]
         logprior = np.sum(cz * dg2[:-1]) + np.sum(z * dg1)
@@ -484,17 +456,16 @@ class DPGMM(GMM):
         return logprior
 
     def lower_bound(self, X, z):
-        try:
-            _bound_state_loglik = _BOUND_STATE_LOGLIK_DICT[self.covariance_type]
-        except KeyError:
+        if self.covariance_type not in ['full', 'tied', 'diag', 'spherical']:
             raise NotImplementedError("This ctype is not implemented: %s"
                                       % self.covariance_type)
+
+        X = np.asarray(X)
         if X.ndim == 1:
             X = X[:, np.newaxis]
-        c = np.sum(z * _bound_state_loglik(
-            X, self._initial_bound, self._bound_prec, self.precs_,
-            self.means_))
+        c = np.sum(z * _bound_state_log_lik(
+            X, self._initial_bound + self.bound_prec_,
+            self.precs_, self.means_, self.covariance_type))
 
         return c + self._logprior(z)
 
@@ -554,44 +525,47 @@ class DPGMM(GMM):
         if 'c' in init_params or not hasattr(self, 'covars'):
             if self.covariance_type == 'spherical':
-                self._a = np.ones(self.n_components)
-                self._b = np.ones(self.n_components)
+                self.dof_ = np.ones(self.n_components)
+                self.scale_ = np.ones(self.n_components)
                 self.precs_ = np.ones(self.n_components)
-                self._bound_prec = (0.5 * n_features *
-                                    (digamma(self._a) -
-                                     np.log(self._b)))
+                self.bound_prec_ = (
+                    0.5 * n_features * (digamma(self.dof_)
+                                        - np.log(self.scale_)))
             elif self.covariance_type == 'diag':
-                self._a = 1 + 0.5 * n_features
-                self._a *= np.ones((self.n_components, n_features))
-                self._b = np.ones((self.n_components, n_features))
+                self.dof_ = 1 + 0.5 * n_features
+                self.dof_ *= np.ones((self.n_components, n_features))
+                self.scale_ = np.ones((self.n_components, n_features))
                 self.precs_ = np.ones((self.n_components, n_features))
-                self._bound_prec = np.zeros(self.n_components)
+                self.bound_prec_ = np.zeros(self.n_components)
                 for k in xrange(self.n_components):
-                    self._bound_prec[k] = 0.5 * np.sum(digamma(self._a[k])
-                                                       - np.log(self._b[k]))
-                    self._bound_prec[k] -= 0.5 * np.sum(self.precs_[k])
+                    self.bound_prec_[k] = (
+                        0.5 * np.sum(digamma(self.dof_[k])
+                                     - np.log(self.scale_[k])))
+                    self.bound_prec_[k] -= 0.5 * np.sum(self.precs_[k])
             elif self.covariance_type == 'tied':
-                self._a = 1.
-                self._B = np.identity(n_features)
+                self.dof_ = 1.
+                self.scale_ = np.identity(n_features)
                 self.precs_ = np.identity(n_features)
-                self._detB = 1.
-                self._bound_prec = 0.5 * detlog_wishart(
-                    self._a, self._B, self._detB, n_features)
-                self._bound_prec -= 0.5 * self._a * np.trace(self._B)
+                self.det_scale_ = 1.
+                self.bound_prec_ = 0.5 * wishart_log_det(
+                    self.dof_, self.scale_, self.det_scale_, n_features)
+                self.bound_prec_ -= 0.5 * self.dof_ * np.trace(self.scale_)
             elif self.covariance_type == 'full':
-                self._a = (1 + self.n_components + X.shape[0])
-                self._a *= np.ones(self.n_components)
-                self._B = [2 * np.identity(n_features)
+                self.dof_ = (1 + self.n_components + X.shape[0])
+                self.dof_ *= np.ones(self.n_components)
+                self.scale_ = [2 * np.identity(n_features)
                            for i in xrange(self.n_components)]
                 self.precs_ = [np.identity(n_features)
                                for i in xrange(self.n_components)]
-                self._detB = np.ones(self.n_components)
-                self._bound_prec = np.zeros(self.n_components)
+                self.det_scale_ = np.ones(self.n_components)
+                self.bound_prec_ = np.zeros(self.n_components)
                 for k in xrange(self.n_components):
-                    self._bound_prec[k] = detlog_wishart(
-                        self._a[k], self._B[k], self._detB[k], n_features)
-                    self._bound_prec[k] -= self._a[k] * np.trace(self._B[k])
-                    self._bound_prec[k] *= 0.5
+                    self.bound_prec_[k] = wishart_log_det(
+                        self.dof_[k], self.scale_[k], self.det_scale_[k],
+                        n_features)
+                    self.bound_prec_[k] -= (self.dof_[k] *
+                                            np.trace(self.scale_[k]))
+                    self.bound_prec_[k] *= 0.5
 
         logprob = []
         # reset self.converged_ to False
@@ -677,8 +651,8 @@ class VBGMM(DPGMM):
                  random_state=None, thresh=1e-2, verbose=False,
                  min_covar=None):
         super(VBGMM, self).__init__(
-            n_components, covariance_type, random_state=random_state, thresh=thresh,
-            verbose=verbose, min_covar=min_covar)
+            n_components, covariance_type, random_state=random_state,
+            thresh=thresh, verbose=verbose, min_covar=min_covar)
         self.alpha = float(alpha) / n_components
 
     def eval(self, X):
@@ -711,15 +685,15 @@ class VBGMM(DPGMM):
         z = np.zeros((X.shape[0], self.n_components))
         p = np.zeros(self.n_components)
         bound = np.zeros(X.shape[0])
-        dg = digamma(self._gamma) - digamma(np.sum(self._gamma))
-        try:
-            _bound_state_loglik = _BOUND_STATE_LOGLIK_DICT[self.covariance_type]
-        except KeyError:
+        dg = digamma(self.gamma_) - digamma(np.sum(self.gamma_))
+
+        if self.covariance_type not in ['full', 'tied', 'diag', 'spherical']:
             raise NotImplementedError("This ctype is not implemented: %s"
                                       % self.covariance_type)
-
-        p = _bound_state_loglik(X, self._initial_bound,
-                                self._bound_prec, self.precs_, self.means_)
+        p = _bound_state_log_lik(
+            X, self._initial_bound + self.bound_prec_,
+            self.precs_, self.means_, self.covariance_type)
+
         z = p + dg
         z = log_normalize(z, axis=-1)
         bound = np.sum(z * p, axis=-1)
@@ -727,15 +701,15 @@ class VBGMM(DPGMM):
     def _update_concentration(self, z):
         for i in xrange(self.n_components):
-            self._gamma[i] = self.alpha + np.sum(z.T[i])
+            self.gamma_[i] = self.alpha + np.sum(z.T[i])
 
     def _initialize_gamma(self):
-        self._gamma = self.alpha * np.ones(self.n_components)
+        self.gamma_ = self.alpha * np.ones(self.n_components)
 
     def _bound_proportions(self, z):
         logprior = 0.
-        dg = digamma(self._gamma)
-        dg -= digamma(np.sum(self._gamma))
+        dg = digamma(self.gamma_)
+        dg -= digamma(np.sum(self.gamma_))
         logprior += np.sum(dg.reshape((-1, 1)) * z.T)
         z_non_zeros = z[z > np.finfo(np.float32).eps]
         logprior -= np.sum(z_non_zeros * np.log(z_non_zeros))
@@ -743,12 +717,12 @@ class VBGMM(DPGMM):
     def _bound_concentration(self):
         logprior = 0.
-        logprior = gammaln(np.sum(self._gamma)) - gammaln(self.n_components
+        logprior = gammaln(np.sum(self.gamma_)) - gammaln(self.n_components
                                                           * self.alpha)
-        logprior -= np.sum(gammaln(self._gamma) - gammaln(self.alpha))
-        sg = digamma(np.sum(self._gamma))
-        logprior += np.sum((self._gamma - self.alpha)
-                           * (digamma(self._gamma) - sg))
+        logprior -= np.sum(gammaln(self.gamma_) - gammaln(self.alpha))
+        sg = digamma(np.sum(self.gamma_))
+        logprior += np.sum((self.gamma_ - self.alpha)
+                           * (digamma(self.gamma_) - sg))
         return logprior
 
     def _monitor(self, X, z, n, end=False):
@@ -761,5 +735,5 @@ class VBGMM(DPGMM):
         if self.verbose:
             print "Bound after updating %8s: %f" % (n, self.lower_bound(X, z))
             if end == True:
-                print "Cluster proportions:", self._gamma
+                print "Cluster proportions:", self.gamma_
                 print "covariance_type:", self._covariance_type
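
Usage note (not part of the patch): a minimal sketch of how the renamed public
attributes would be read after fitting, assuming the sklearn.mixture.DPGMM API
of this era as used in plot_gmm_sin.py above; the toy data below is made up
purely for illustration.

    import numpy as np
    from sklearn import mixture

    # Two well-separated Gaussian blobs, illustrative only.
    np.random.seed(0)
    X = np.vstack([np.random.randn(100, 2), 10 + np.random.randn(100, 2)])

    # Same constructor call as in the example script above.
    clf = mixture.DPGMM(n_components=10, covariance_type='full', alpha=0.01)
    clf.fit(X)

    # The variational parameters now follow the trailing-underscore convention
    # instead of the old private names:
    print clf.gamma_.shape       # was clf._gamma: stick-breaking parameters
    print clf.dof_.shape         # was clf._a: Wishart degrees of freedom
    print clf.bound_prec_.shape  # was clf._bound_prec: per-component bound term
    print len(clf.precs_)        # expected precisions, one matrix per component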