diff --git a/scikits/learn/gmm.py b/scikits/learn/gmm.py index 1fb5f4f1081e39d07c94c767a9d405a84eb6ce07..1ce12b6e4f4fe75fdbce6e93b7c4a8d94b3e4044 100644 --- a/scikits/learn/gmm.py +++ b/scikits/learn/gmm.py @@ -84,7 +84,7 @@ def sample_gaussian(mean, covar, cvtype='diag', n=1): Parameters ---------- - mean : array_like, shape (n_dim,) + mean : array_like, shape (n_features,) Mean of the distribution. covars : array_like Covariance of the distribution. The shape depends on `cvtype`: @@ -99,7 +99,7 @@ def sample_gaussian(mean, covar, cvtype='diag', n=1): Returns ------- - obs : array, shape (n, n_dim) + obs : array, shape (n, n_features) Randomly generated sample """ ndim = len(mean) @@ -133,21 +133,21 @@ class GMM(BaseEstimator): cvtype : string (read-only) String describing the type of covariance parameters used by the GMM. Must be one of 'spherical', 'tied', 'diag', 'full'. - n_dim : int + n_features : int Dimensionality of the Gaussians. n_states : int (read-only) Number of mixture components. weights : array, shape (`n_states`,) Mixing weights for each mixture component. - means : array, shape (`n_states`, `n_dim`) + means : array, shape (`n_states`, `n_features`) Mean parameters for each mixture component. covars : array Covariance parameters for each mixture component. The shape depends on `cvtype`: - (`n_states`,) if 'spherical', - (`n_dim`, `n_dim`) if 'tied', - (`n_states`, `n_dim`) if 'diag', - (`n_states`, `n_dim`, `n_dim`) if 'full' + (`n_states`,) if 'spherical', + (`n_features`, `n_features`) if 'tied', + (`n_states`, `n_features`) if 'diag', + (`n_states`, `n_features`, `n_features`) if 'full' Methods ------- @@ -255,11 +255,11 @@ class GMM(BaseEstimator): elif self.cvtype == 'tied': return [self._covars] * self._n_states elif self.cvtype == 'spherical': - return [np.eye(self.n_dim) * f for f in self._covars] + return [np.eye(self.n_features) * f for f in self._covars] def _set_covars(self, covars): covars = np.asanyarray(covars) - _validate_covars(covars, self._cvtype, self._n_states, self.n_dim) + _validate_covars(covars, self._cvtype, self._n_states, self.n_features) self._covars = covars covars = property(_get_covars, _set_covars) @@ -270,11 +270,11 @@ class GMM(BaseEstimator): def _set_means(self, means): means = np.asarray(means) - if hasattr(self, 'n_dim') and \ - means.shape != (self._n_states, self.n_dim): - raise ValueError('means must have shape (n_states, n_dim)') + if hasattr(self, 'n_features') and \ + means.shape != (self._n_states, self.n_features): + raise ValueError('means must have shape (n_states, n_features)') self._means = means.copy() - self.n_dim = self._means.shape[1] + self.n_features = self._means.shape[1] means = property(_get_means, _set_means) @@ -301,9 +301,9 @@ class GMM(BaseEstimator): Parameters ---------- - obs : array_like, shape (n, n_dim) - List of n_dim-dimensional data points. Each row corresponds to a - single data point. + obs : array_like, shape (n, n_features) + List of n_features-dimensional data points. Each row + corresponds to a single data point. Returns ------- @@ -325,9 +325,9 @@ class GMM(BaseEstimator): Parameters ---------- - obs : array_like, shape (n, n_dim) - List of n_dim-dimensional data points. Each row corresponds to a - single data point. + obs : array_like, shape (n, n_features) + List of n_features-dimensional data points. Each row + corresponds to a single data point. Returns ------- @@ -342,9 +342,9 @@ class GMM(BaseEstimator): Parameters ---------- - obs : array_like, shape (n, n_dim) - List of n_dim-dimensional data points. Each row corresponds to a - single data point. + obs : array_like, shape (n, n_features) + List of n_features-dimensional data points. Each row + corresponds to a single data point. Returns ------- @@ -397,13 +397,13 @@ class GMM(BaseEstimator): Returns ------- - obs : array_like, shape (n, n_dim) + obs : array_like, shape (n, n_features) List of samples """ weight_pdf = self.weights weight_cdf = np.cumsum(weight_pdf) - obs = np.empty((n, self.n_dim)) + obs = np.empty((n, self.n_features)) for x in xrange(n): rand = np.random.rand() c = (weight_cdf > rand).argmax() @@ -427,9 +427,9 @@ class GMM(BaseEstimator): Parameters ---------- - X : array_like, shape (n, n_dim) - List of n_dim-dimensional data points. Each row corresponds to a - single data point. + X : array_like, shape (n, n_features) + List of n_features-dimensional data points. Each row + corresponds to a single data point. n_iter : int, optional Number of EM iterations to perform. min_covar : float, optional @@ -451,17 +451,17 @@ class GMM(BaseEstimator): X = np.asanyarray(X) - if hasattr(self, 'n_dim') and self.n_dim != X.shape[1]: + if hasattr(self, 'n_features') and self.n_features != X.shape[1]: raise ValueError('Unexpected number of dimensions, got %s but ' - 'expected %s' % (X.shape[1], self.n_dim)) + 'expected %s' % (X.shape[1], self.n_features)) - self.n_dim = X.shape[1] + self.n_features = X.shape[1] if 'm' in init_params: self._means = cluster.KMeans( k=self._n_states).fit(X).cluster_centers_ elif not hasattr(self, 'means'): - self._means = np.zeros((self.n_states, self.n_dim)) + self._means = np.zeros((self.n_states, self.n_features)) if 'w' in init_params or not hasattr(self, 'weights'): self.weights = np.tile(1.0 / self._n_states, self._n_states) @@ -474,7 +474,7 @@ class GMM(BaseEstimator): cv, self._cvtype, self._n_states) elif not hasattr(self, 'covars'): self.covars = _distribute_covar_matrix_to_match_cvtype( - np.eye(self.n_dim), cvtype, n_states) + np.eye(self.n_features), cvtype, n_states) # EM algorithm logprob = [] @@ -637,13 +637,13 @@ def _covar_mstep_spherical(*args): def _covar_mstep_full(gmm, obs, posteriors, avg_obs, norm, min_covar): # Eq. 12 from K. Murphy, "Fitting a Conditional Linear Gaussian # Distribution" - cv = np.empty((gmm._n_states, gmm.n_dim, gmm.n_dim)) + cv = np.empty((gmm._n_states, gmm.n_features, gmm.n_features)) for c in xrange(gmm._n_states): post = posteriors[:,c] avg_cv = np.dot(post * obs.T, obs) / post.sum() mu = gmm._means[c][np.newaxis] cv[c] = (avg_cv - np.dot(mu.T, mu) - + min_covar * np.eye(gmm.n_dim)) + + min_covar * np.eye(gmm.n_features)) return cv @@ -656,7 +656,7 @@ def _covar_mstep_tied(gmm, obs, posteriors, avg_obs, norm, min_covar): # Eq. 15 from K. Murphy, "Fitting a Conditional Linear Gaussian avg_obs2 = np.dot(obs.T, obs) avg_means2 = np.dot(gmm._means.T, gmm._means) - return (avg_obs2 - avg_means2 + min_covar * np.eye(gmm.n_dim)) + return (avg_obs2 - avg_means2 + min_covar * np.eye(gmm.n_features)) def _covar_mstep_slow(gmm, obs, posteriors, avg_obs, norm, min_covar): @@ -665,13 +665,13 @@ def _covar_mstep_slow(gmm, obs, posteriors, avg_obs, norm, min_covar): for c in xrange(gmm._n_states): mu = gmm._means[c] #cv = np.dot(mu.T, mu) - avg_obs2 = np.zeros((gmm.n_dim, gmm.n_dim)) + avg_obs2 = np.zeros((gmm.n_features, gmm.n_features)) for t,o in enumerate(obs): avg_obs2 += posteriors[t,c] * np.outer(o, o) cv = (avg_obs2 / w[c] - 2 * np.outer(avg_obs[c] / w[c], mu) + np.outer(mu, mu) - + min_covar * np.eye(gmm.n_dim)) + + min_covar * np.eye(gmm.n_features)) if gmm.cvtype == 'spherical': covars[c] = np.diag(cv).mean() elif gmm.cvtype == 'diag': diff --git a/scikits/learn/hmm.py b/scikits/learn/hmm.py index b851762384bc2620c84047a040112daf0a08fd85..1fffddcda54e2c849bfd14d864298082df90bf14 100644 --- a/scikits/learn/hmm.py +++ b/scikits/learn/hmm.py @@ -992,9 +992,9 @@ class GMMHMM(_BaseHMM): gmm_logprob, gmm_posteriors = g.eval(obs) gmm_posteriors *= posteriors[:,state][:,np.newaxis] tmpgmm = GMM(g.n_states, cvtype=g.cvtype) - tmpgmm.n_dim = g.n_dim + tmpgmm.n_features = g.n_features tmpgmm.covars = _distribute_covar_matrix_to_match_cvtype( - np.eye(g.n_dim), g.cvtype, g.n_states) + np.eye(g.n_features), g.cvtype, g.n_states) norm = tmpgmm._do_mstep(obs, gmm_posteriors, params) stats['norm'][state] += norm @@ -1022,8 +1022,9 @@ class GMMHMM(_BaseHMM): g.means = stats['means'][state] / norm[:,np.newaxis] if 'c' in params: if g.cvtype == 'tied': - g.covars = (stats['covars'][state] - + covars_prior * np.eye(g.n_dim)) / norm.sum() + g.covars = ((stats['covars'][state] + + covars_prior * np.eye(g.n_features)) + / norm.sum()) else: cvnorm = np.copy(norm) shape = np.ones(g._covars.ndim) @@ -1033,7 +1034,7 @@ class GMMHMM(_BaseHMM): g.covars = (stats['covars'][state] + covars_prior) / cvnorm elif g.cvtype == 'full': - eye = np.eye(g.n_dim) + eye = np.eye(g.n_features) g.covars = ((stats['covars'][state] + covars_prior * eye[np.newaxis,:,:]) / cvnorm) diff --git a/scikits/learn/tests/test_gmm.py b/scikits/learn/tests/test_gmm.py index 215944efe5c26f30ef87596d5f3f2d81ef59b5e5..10f14f94ee3c5d25cdc5f5b461d67fc73d7b11f5 100644 --- a/scikits/learn/tests/test_gmm.py +++ b/scikits/learn/tests/test_gmm.py @@ -55,10 +55,10 @@ def test_sample_gaussian(): is diagonal, spherical and full """ - n_dim, n_samples = 2, 300 + n_features, n_samples = 2, 300 axis = 1 - mu = np.random.randint(10) * np.random.rand(n_dim) - cv = (np.random.rand(n_dim) + 1.0) ** 2 + mu = np.random.randint(10) * np.random.rand(n_features) + cv = (np.random.rand(n_features) + 1.0) ** 2 samples = gmm.sample_gaussian(mu, cv, cvtype='diag', n=n_samples) @@ -70,11 +70,11 @@ def test_sample_gaussian(): samples = gmm.sample_gaussian(mu, cv, cvtype='spherical', n=n_samples) assert np.allclose(samples.mean(axis), mu, atol=0.3) - assert np.allclose(samples.var(axis), np.repeat(cv, n_dim), atol=0.5) + assert np.allclose(samples.var(axis), np.repeat(cv, n_features), atol=0.5) # and for full covariances - A = np.random.randn(n_dim, n_dim) - cv = np.dot(A.T, A) + np.eye(n_dim) + A = np.random.randn(n_features, n_features) + cv = np.dot(A.T, A) + np.eye(n_features) samples = gmm.sample_gaussian(mu, cv, cvtype='full', n=n_samples) assert np.allclose(samples.mean(axis), mu, atol=0.3) assert np.allclose(np.cov(samples), cv, atol=0.7) @@ -95,10 +95,10 @@ def test_lmvnpdf_diag(): compare it to the vectorized version (gmm.lmvnpdf) to test for correctness """ - n_dim, n_states, n_obs = 2, 3, 10 - mu = np.random.randint(10) * np.random.rand(n_states, n_dim) - cv = (np.random.rand(n_states, n_dim) + 1.0) ** 2 - obs = np.random.randint(10) * np.random.rand(n_obs, n_dim) + n_features, n_states, n_obs = 2, 3, 10 + mu = np.random.randint(10) * np.random.rand(n_states, n_features) + cv = (np.random.rand(n_states, n_features) + 1.0) ** 2 + obs = np.random.randint(10) * np.random.rand(n_obs, n_features) ref = _naive_lmvnpdf_diag(obs, mu, cv) lpr = gmm.lmvnpdf(obs, mu, cv, 'diag') @@ -106,13 +106,13 @@ def test_lmvnpdf_diag(): def test_lmvnpdf_spherical(): - n_dim, n_states, n_obs = 2, 3, 10 + n_features, n_states, n_obs = 2, 3, 10 - mu = np.random.randint(10) * np.random.rand(n_states, n_dim) + mu = np.random.randint(10) * np.random.rand(n_states, n_features) spherecv = np.random.rand(n_states, 1) ** 2 + 1 - obs = np.random.randint(10) * np.random.rand(n_obs, n_dim) + obs = np.random.randint(10) * np.random.rand(n_obs, n_features) - cv = np.tile(spherecv, (n_dim, 1)) + cv = np.tile(spherecv, (n_features, 1)) reference = _naive_lmvnpdf_diag(obs, mu, cv) lpr = gmm.lmvnpdf(obs, mu, spherecv, 'spherical') assert_array_almost_equal(lpr, reference) @@ -120,11 +120,11 @@ def test_lmvnpdf_spherical(): def test_lmvnpdf_full(): - n_dim, n_states, n_obs = 2, 3, 10 + n_features, n_states, n_obs = 2, 3, 10 - mu = np.random.randint(10) * np.random.rand(n_states, n_dim) - cv = (np.random.rand(n_states, n_dim) + 1.0) ** 2 - obs = np.random.randint(10) * np.random.rand(n_obs, n_dim) + mu = np.random.randint(10) * np.random.rand(n_states, n_features) + cv = (np.random.rand(n_states, n_features) + 1.0) ** 2 + obs = np.random.randint(10) * np.random.rand(n_obs, n_features) fullcv = np.array([np.diag(x) for x in cv]) @@ -135,12 +135,12 @@ def test_lmvnpdf_full(): def test_GMM_attributes(): - n_states, n_dim = 10, 4 + n_states, n_features = 10, 4 cvtype = 'diag' g = gmm.GMM(n_states, cvtype) weights = np.random.rand(n_states) weights = weights / weights.sum() - means = np.random.randint(-20, 20, (n_states, n_dim)) + means = np.random.randint(-20, 20, (n_states, n_features)) assert g.n_states == n_states assert g.cvtype == cvtype @@ -151,34 +151,34 @@ def test_GMM_attributes(): 2 * weights) assert_raises(ValueError, g.__setattr__, 'weights', []) assert_raises(ValueError, g.__setattr__, 'weights', - np.zeros((n_states - 2, n_dim))) + np.zeros((n_states - 2, n_features))) g.means = means assert_array_almost_equal(g.means, means) assert_raises(ValueError, g.__setattr__, 'means', []) assert_raises(ValueError, g.__setattr__, 'means', - np.zeros((n_states - 2, n_dim))) + np.zeros((n_states - 2, n_features))) - covars = (0.1 + 2 * np.random.rand(n_states, n_dim)) ** 2 + covars = (0.1 + 2 * np.random.rand(n_states, n_features)) ** 2 g._covars = covars assert_array_almost_equal(g._covars, covars) assert_raises(ValueError, g.__setattr__, 'covars', []) assert_raises(ValueError, g.__setattr__, 'covars', - np.zeros((n_states - 2, n_dim))) + np.zeros((n_states - 2, n_features))) assert_raises(ValueError, gmm.GMM, n_states=20, cvtype='badcvtype') class GMMTester(): n_states = 10 - n_dim = 4 + n_features = 4 weights = np.random.rand(n_states) weights = weights / weights.sum() - means = np.random.randint(-20, 20, (n_states, n_dim)) - I = np.eye(n_dim) + means = np.random.randint(-20, 20, (n_states, n_features)) + I = np.eye(n_features) covars = {'spherical': (0.1 + 2 * np.random.rand(n_states)) ** 2, - 'tied': _generate_random_spd_matrix(n_dim) + 5 * I, - 'diag': (0.1 + 2 * np.random.rand(n_states, n_dim)) ** 2, - 'full': np.array([_generate_random_spd_matrix(n_dim) + 5 * I + 'tied': _generate_random_spd_matrix(n_features) + 5 * I, + 'diag': (0.1 + 2 * np.random.rand(n_states, n_features)) ** 2, + 'full': np.array([_generate_random_spd_matrix(n_features) + 5 * I for x in xrange(n_states)])} @@ -192,7 +192,7 @@ class GMMTester(): gaussidx = np.repeat(range(self.n_states), 5) nobs = len(gaussidx) - obs = np.random.randn(nobs, self.n_dim) + g.means[gaussidx] + obs = np.random.randn(nobs, self.n_features) + g.means[gaussidx] ll, posteriors = g.eval(obs) @@ -210,7 +210,7 @@ class GMMTester(): g.weights = self.weights samples = g.rvs(n) - self.assertEquals(samples.shape, (n, self.n_dim)) + self.assertEquals(samples.shape, (n, self.n_features)) def test_train(self, params='wmc'): g = gmm.GMM(self.n_states, self.cvtype)