diff --git a/scikits/learn/hmm.py b/scikits/learn/hmm.py index 1fffddcda54e2c849bfd14d864298082df90bf14..1191e5932dafc9c9f11c66a1e3bcde795d25025b 100644 --- a/scikits/learn/hmm.py +++ b/scikits/learn/hmm.py @@ -231,6 +231,25 @@ class _BaseHMM(BaseEstimator): logprob, state_sequence = self.decode(obs, **kwargs) return state_sequence + def predict_proba(self, obs, **kwargs): + """Compute the posterior probability for each state in the model + + Parameters + ---------- + obs : array_like, shape (n, n_dim) + List of n_dim-dimensional data points. Each row corresponds to a + single data point. + + See eval() for a list of accepted keyword arguments. + + Returns + ------- + T : array-like, shape (n, n_states) + Returns the probability of the sample for each state in the model. + """ + logprob, posteriors = self.eval(obs, **kwargs) + return posteriors + def rvs(self, n=1): """Generate random samples from the model.