diff --git a/scikits/learn/mixture.py b/scikits/learn/mixture.py
index 9798f8552b12a16c2bc98649b2e18816acd313ed..af843a06b06b8e9ea7e50db2afe238c3be2e8830 100644
--- a/scikits/learn/mixture.py
+++ b/scikits/learn/mixture.py
@@ -99,7 +99,7 @@ def sample_gaussian(mean, covar, cvtype='diag', n=1):
 
     Returns
     -------
-    obs : array, shape (n, n_features)
+    obs : array, shape (n_features, n)
         Randomly generated sample
     """
     ndim = len(mean)
@@ -403,14 +403,19 @@ class GMM(BaseEstimator):
         weight_cdf = np.cumsum(weight_pdf)
 
         obs = np.empty((n, self.n_features))
-        for x in xrange(n):
-            rand = np.random.rand()
-            c = (weight_cdf > rand).argmax()
-            if self._cvtype == 'tied':
-                cv = self._covars
-            else:
-                cv = self._covars[c]
-            obs[x] = sample_gaussian(self._means[c], cv, self._cvtype)
+        rand = np.random.rand(n)
+        # decide which component to use for each sample
+        c = weight_cdf.searchsorted(rand)
+        # for each component, generate all needed samples
+        for cc in xrange(self._n_states):
+            ccc = (c == cc)     # occurrences of the current component in obs
+            nccc = ccc.sum()    # number of those occurrences
+            if nccc > 0:
+                if self._cvtype == 'tied':
+                    cv = self._covars
+                else:
+                    cv = self._covars[cc]
+                obs[ccc] = sample_gaussian(self._means[cc], cv, self._cvtype, nccc).T
         return obs
 
     def fit(self, X, n_iter=10, min_covar=1e-3, thresh=1e-2, params='wmc',
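
For context, a minimal standalone sketch (not part of the patch; the `weights` array and sample count `n` below are made-up values) of the technique the new hunk uses: `searchsorted` on the cumulative weight vector assigns a mixture component to every sample in one vectorized call, replacing the per-sample argmax loop, after which each component can be sampled in a single batch.

    import numpy as np

    # illustrative mixture weights (assumed already normalized); not taken from the patch
    weights = np.array([0.2, 0.5, 0.3])
    n = 10

    weight_cdf = np.cumsum(weights)              # cumulative distribution over components
    rand = np.random.rand(n)                     # one uniform draw per requested sample
    components = weight_cdf.searchsorted(rand)   # component index for every sample at once

    # per-component sample counts, so each component can be sampled in one batch
    counts = np.bincount(components, minlength=len(weights))

This is the same idea as the `c = weight_cdf.searchsorted(rand)` line in the diff: instead of drawing one uniform number and scanning the CDF per sample, all component assignments are computed up front and `sample_gaussian` is called once per non-empty component.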