From 1418c7b5f1604fdd80bf235cdbc17e1a702d7fba Mon Sep 17 00:00:00 2001 From: Ron Weiss <ronweiss@gmail.com> Date: Sun, 7 Nov 2010 16:15:55 -0500 Subject: [PATCH] add HMM.predict_proba --- scikits/learn/hmm.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/scikits/learn/hmm.py b/scikits/learn/hmm.py index 1fffddcda5..1191e5932d 100644 --- a/scikits/learn/hmm.py +++ b/scikits/learn/hmm.py @@ -231,6 +231,25 @@ class _BaseHMM(BaseEstimator): logprob, state_sequence = self.decode(obs, **kwargs) return state_sequence + def predict_proba(self, obs, **kwargs): + """Compute the posterior probability for each state in the model + + Parameters + ---------- + obs : array_like, shape (n, n_dim) + List of n_dim-dimensional data points. Each row corresponds to a + single data point. + + See eval() for a list of accepted keyword arguments. + + Returns + ------- + T : array-like, shape (n, n_states) + Returns the probability of the sample for each state in the model. + """ + logprob, posteriors = self.eval(obs, **kwargs) + return posteriors + def rvs(self, n=1): """Generate random samples from the model. -- GitLab