From 1418c7b5f1604fdd80bf235cdbc17e1a702d7fba Mon Sep 17 00:00:00 2001
From: Ron Weiss <ronweiss@gmail.com>
Date: Sun, 7 Nov 2010 16:15:55 -0500
Subject: [PATCH] add HMM.predict_proba

---
 scikits/learn/hmm.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/scikits/learn/hmm.py b/scikits/learn/hmm.py
index 1fffddcda5..1191e5932d 100644
--- a/scikits/learn/hmm.py
+++ b/scikits/learn/hmm.py
@@ -231,6 +231,25 @@ class _BaseHMM(BaseEstimator):
         logprob, state_sequence = self.decode(obs, **kwargs)
         return state_sequence
 
+    def predict_proba(self, obs, **kwargs):
+        """Compute the posterior probability for each state in the model
+
+        Parameters
+        ----------
+        obs : array_like, shape (n, n_dim)
+            List of n_dim-dimensional data points.  Each row corresponds to a
+            single data point.
+
+        See eval() for a list of accepted keyword arguments.
+
+        Returns
+        -------
+        T : array-like, shape (n, n_states)
+            Returns the probability of the sample for each state in the model.
+        """
+        logprob, posteriors = self.eval(obs, **kwargs)
+        return posteriors
+
     def rvs(self, n=1):
         """Generate random samples from the model.
 
-- 
GitLab