Skip to content
Snippets Groups Projects
Commit 7c53b1ec authored by Fabian Pedregosa's avatar Fabian Pedregosa
Browse files

Second patch by f0k.

parent bbb60337
Branches
Tags
No related merge requests found
......@@ -405,17 +405,17 @@ class GMM(BaseEstimator):
obs = np.empty((n, self.n_features))
rand = np.random.rand(n)
# decide which component to use for each sample
c = weight_cdf.searchsorted(rand)
comps = weight_cdf.searchsorted(rand)
# for each component, generate all needed samples
for cc in xrange(self._n_states):
ccc = (c==cc) # occurences of current component in obs
nccc = ccc.sum() # number of those occurences
if nccc > 0:
for comp in xrange(self._n_states):
comp_in_obs = (comp==comps) # occurrences of current component in obs
num_comp_in_obs = comp_in_obs.sum() # number of those occurrences
if num_comp_in_obs > 0:
if self._cvtype == 'tied':
cv = self._covars
else:
cv = self._covars[cc]
obs[ccc] = sample_gaussian(self._means[cc], cv, self._cvtype, nccc).T
cv = self._covars[comp]
obs[comp_in_obs] = sample_gaussian(self._means[comp], cv, self._cvtype, num_comp_in_obs).T
return obs
def fit(self, X, n_iter=10, min_covar=1e-3, thresh=1e-2, params='wmc',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment