diff --git a/scikits/learn/mixture.py b/scikits/learn/mixture.py index b5482d8973d225ce35ec3173947272cfaace1eaa..2f66c780f9ad17a53e88abeacd4ddf68c5f3bad5 100644 --- a/scikits/learn/mixture.py +++ b/scikits/learn/mixture.py @@ -134,9 +134,10 @@ class GMM(BaseEstimator): Parameters ---------- - n_states : int - Number of mixture components. - cvtype : string (read-only) + n_states : int, optional + Number of mixture components. Defaults to 1. + + cvtype : string (read-only), optional String describing the type of covariance parameters to use. Must be one of 'spherical', 'tied', 'diag', 'full'. Defaults to 'diag'. @@ -386,36 +387,39 @@ class GMM(BaseEstimator): logprob, posteriors = self.eval(X) return posteriors - def rvs(self, n=1): + def rvs(self, n_samples=1): """Generate random samples from the model. Parameters ---------- - n : int - Number of samples to generate. + n_samples : int, optional + Number of samples to generate. Defaults to 1. Returns ------- - obs : array_like, shape (n, n_features) + obs : array_like, shape (n_samples, n_features) List of samples """ weight_pdf = self.weights weight_cdf = np.cumsum(weight_pdf) - obs = np.empty((n, self.n_features)) - rand = np.random.rand(n) + obs = np.empty((n_samples, self.n_features)) + rand = np.random.rand(n_samples) # decide which component to use for each sample comps = weight_cdf.searchsorted(rand) # for each component, generate all needed samples for comp in xrange(self._n_states): - comp_in_obs = (comp==comps) # occurrences of current component in obs - num_comp_in_obs = comp_in_obs.sum() # number of those occurrences + # occurrences of current component in obs + comp_in_obs = (comp==comps) + # number of those occurrences + num_comp_in_obs = comp_in_obs.sum() if num_comp_in_obs > 0: if self._cvtype == 'tied': cv = self._covars else: cv = self._covars[comp] - obs[comp_in_obs] = sample_gaussian(self._means[comp], cv, self._cvtype, num_comp_in_obs).T + obs[comp_in_obs] = sample_gaussian( + self._means[comp], cv, self._cvtype, num_comp_in_obs).T return obs def fit(self, X, n_iter=10, min_covar=1e-3, thresh=1e-2, params='wmc',