Skip to content
Snippets Groups Projects
Commit 98ed29fb authored by Alexandre Gramfort's avatar Alexandre Gramfort
Browse files

adding "proba_predict" method in LDA (should be done for SVC too)

git-svn-id: https://scikit-learn.svn.sourceforge.net/svnroot/scikit-learn/trunk@618 22fbfee3-77ab-4535-9bad-27d1bd3bc7d8
parent 69d1b085
No related branches found
No related tags found
No related merge requests found
...@@ -134,7 +134,12 @@ class LDA(object): ...@@ -134,7 +134,12 @@ class LDA(object):
V = V[:X.shape[0], :] V = V[:X.shape[0], :]
return S_sort, V return S_sort, V
def predict(self, X):
    """Classify each sample in X by maximum a posteriori probability.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Samples to classify.

    Returns
    -------
    y_pred : ndarray, shape (n_samples,)
        Predicted class label for each sample, drawn from
        ``self.classes``.
    """
    # Posterior class probabilities, one row per sample
    # (delegates to proba_predict, added in this same commit).
    probas = self.proba_predict(X)
    # MAP decision rule: pick the class with the largest posterior.
    y_pred = self.classes[probas.argmax(1)]
    return y_pred
def proba_predict(self, X):
#Ensure X is an array #Ensure X is an array
X = np.asarray(X) X = np.asarray(X)
scaling = self.scaling scaling = self.scaling
...@@ -146,13 +151,9 @@ class LDA(object): ...@@ -146,13 +151,9 @@ class LDA(object):
# for each class k, compute the linear discriminant function (p. 87 Hastie) # for each class k, compute the linear discriminant function (p. 87 Hastie)
# of sphered (scaled data) # of sphered (scaled data)
dist = 0.5*np.sum(dm**2, 1) - np.log(self.priors) - np.dot(X,dm.T) dist = 0.5*np.sum(dm**2, 1) - np.log(self.priors) - np.dot(X,dm.T)
self.dist = dist
# take exp of min dist # take exp of min dist
dist = np.exp(-dist + dist.min(1).reshape(X.shape[0],1)) dist = np.exp(-dist + dist.min(1).reshape(X.shape[0],1))
# normalize by p(x)=sum_k p(x|k) # normalize by p(x)=sum_k p(x|k)
self.posteriors = dist / dist.sum(1).reshape(X.shape[0],1) probas = dist / dist.sum(1).reshape(X.shape[0],1)
# classify according to the maximum a posteriori # classify according to the maximum a posteriori
y_pred = self.classes[self.posteriors.argmax(1)] return probas
if posterior is True:
return y_pred, self.posteriors
return y_pred
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment