From bbb60337c26403b946fb41ca3ecd804ccd916855 Mon Sep 17 00:00:00 2001 From: Fabian Pedregosa <fabian.pedregosa@inria.fr> Date: Mon, 21 Feb 2011 10:29:05 +0100 Subject: [PATCH] Improve performance of GMM sampling Patch by f0k. --- scikits/learn/mixture.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/scikits/learn/mixture.py b/scikits/learn/mixture.py index 9798f8552b..af843a06b0 100644 --- a/scikits/learn/mixture.py +++ b/scikits/learn/mixture.py @@ -99,7 +99,7 @@ def sample_gaussian(mean, covar, cvtype='diag', n=1): Returns ------- - obs : array, shape (n, n_features) + obs : array, shape (n_features, n) Randomly generated sample """ ndim = len(mean) @@ -403,14 +403,19 @@ class GMM(BaseEstimator): weight_cdf = np.cumsum(weight_pdf) obs = np.empty((n, self.n_features)) - for x in xrange(n): - rand = np.random.rand() - c = (weight_cdf > rand).argmax() - if self._cvtype == 'tied': - cv = self._covars - else: - cv = self._covars[c] - obs[x] = sample_gaussian(self._means[c], cv, self._cvtype) + rand = np.random.rand(n) + # decide which component to use for each sample + c = weight_cdf.searchsorted(rand) + # for each component, generate all needed samples + for cc in xrange(self._n_states): + ccc = (c==cc) # occurrences of current component in obs + nccc = ccc.sum() # number of those occurrences + if nccc > 0: + if self._cvtype == 'tied': + cv = self._covars + else: + cv = self._covars[cc] + obs[ccc] = sample_gaussian(self._means[cc], cv, self._cvtype, nccc).T return obs def fit(self, X, n_iter=10, min_covar=1e-3, thresh=1e-2, params='wmc', -- GitLab