diff --git a/scikits/learn/bayes/bayes.py b/scikits/learn/bayes/bayes.py index 01e21fe86b590af5c523c891ff6f075317b223c8..d75492b413584f4add3c27d6cbdc06679c5b4ff8 100755 --- a/scikits/learn/bayes/bayes.py +++ b/scikits/learn/bayes/bayes.py @@ -25,7 +25,7 @@ def fast_logdet(A): -def bayesian_ridge( X , Y, step_th=300,th_w = 1.e-6,ll_bool=True) : +def bayesian_ridge( X , Y, step_th=300,th_w = 1.e-12,ll_bool=False) : """ Bayesian ridge regression. Optimize the regularization parameter alpha within a simple bayesian framework (MAP). @@ -39,15 +39,17 @@ def bayesian_ridge( X , Y, step_th=300,th_w = 1.e-6,ll_bool=True) : target step_th : int (defaut is 300) Stop the algorithm after a given number of steps. - th_w : float (defaut is 1.e-6) + th_w : float (defaut is 1.e-12) Stop the algorithm if w has converged. - ll_bool : boolean (default is True). + ll_bool : boolean (default is False). If True, compute the log-likelihood at each step of the model. Returns ------- w : numpy array of shape (dim) mean of the weights distribution. + log_likelihood : list of float of size steps. + Compute (if asked) the log-likelihood of the model. Examples -------- @@ -64,7 +66,9 @@ def bayesian_ridge( X , Y, step_th=300,th_w = 1.e-6,ll_bool=True) : beta = 1./np.var(Y) alpha = 1.0 - log_likelihood = [] + log_likelihood = None + if ll_bool : + log_likelihood = [] has_converged = False gram = np.dot(X.T, X) ones = np.eye(gram.shape[1]) @@ -89,7 +93,6 @@ def bayesian_ridge( X , Y, step_th=300,th_w = 1.e-6,ll_bool=True) : step_th -= 1 - # convergence : compare w has_converged = (np.sum(np.abs(w-old_w))<th_w) old_w = w @@ -104,11 +107,11 @@ def bayesian_ridge( X , Y, step_th=300,th_w = 1.e-6,ll_bool=True) : ll -= X.shape[0]*np.log(2*np.pi) log_likelihood.append(ll) - return w,log_likelihood[1:] + return w,log_likelihood -class BayessianRegression(object): +class BayesianRegression(object): """ Encapsulate various bayesian regression algorithms """ @@ -119,9 +122,9 @@ class BayessianRegression(object): def fit(self, X, Y): X = np.asanyarray(X, dtype=np.float) Y = np.asanyarray(Y, dtype=np.float) - self.w = bayesian_ridge(X, Y) + self.w,self.log_likelihood = bayesian_ridge(X, Y) def predict(self, T): - T = np.asanyarray(T) - # I think this is wrong return np.dot(T, self.w) + + diff --git a/scikits/learn/bayes/tests/test_bayes.py b/scikits/learn/bayes/tests/test_bayes.py index 64879acde3f175b6656d3c99fd24979511e37671..5c1eaa7511e95839ad2301c19d65ac629af69339 100644 --- a/scikits/learn/bayes/tests/test_bayes.py +++ b/scikits/learn/bayes/tests/test_bayes.py @@ -1,19 +1,33 @@ import numpy as np -from scikits.learn.bayes.bayes import bayesian_ridge, BayessianRegression +from scikits.learn.bayes.bayes import bayesian_ridge, BayesianRegression from numpy.testing import assert_array_almost_equal - -X = np.array([[1], [2]]) -Y = np.array([1, 2]) +from scikits.learn.datasets.samples_generator import linear def test_toy(): - w = bayesian_ridge(X, Y) + X = np.array([[1], [2]]) + Y = np.array([1, 2]) + w ,log_likelihood = bayesian_ridge(X, Y) assert_array_almost_equal(w, [1]) def test_toy_object(): """ - Test BayessianRegression classifier + Test BayesianRegression classifier """ - clf = BayessianRegression() + X = np.array([[1], [2]]) + Y = np.array([1, 2]) + clf = BayesianRegression() clf.fit(X, Y) Test = [[1], [2], [3], [4]] assert_array_almost_equal(clf.predict(Test), [1, 2, 3, 4]) # identity + +def test_simu_object(): + """ + Test BayesianRegression classifier with simulated linear data + """ + X,Y = linear.sparse_uncorrelated(nb_samples=100,nb_features=10) + clf = BayesianRegression() + clf.fit(X, Y) + Xtest,Ytest = linear.sparse_uncorrelated(nb_samples=100,nb_features=10) + mse = np.mean((clf.predict(Xtest)-Ytest)**2) + assert(mse<2.) + diff --git a/scikits/learn/datasets/samples_generator/linear.py b/scikits/learn/datasets/samples_generator/linear.py index 72fc93c5369c97451965947bd405c4ead6dfc5e6..cf649ea5a06a19fd4d5c5ab0865eff8214a4ddda 100755 --- a/scikits/learn/datasets/samples_generator/linear.py +++ b/scikits/learn/datasets/samples_generator/linear.py @@ -1,12 +1,12 @@ import numpy as np - +import numpy.random as nr def sparse_uncorrelated(nb_samples=100,nb_features=10): """ Function creating simulated data with sparse uncorrelated design. (cf.Celeux et al. 2009, Bayesian regularization in regression) X = NR.normal(0,1) - Y = NR.normal(2+X[:,2]+2*X[:,3]-2*X[:,6]-1.5*X[:,7]) + Y = NR.normal(X[:,2]+2*X[:,3]-2*X[:,6]-1.5*X[:,7]) The number of features is at least 10. Parameters @@ -23,6 +23,6 @@ def sparse_uncorrelated(nb_samples=100,nb_features=10): Y : numpy array of shape (nb_samples) """ X = nr.normal(loc=0,scale=1,size=(nb_samples,nb_features)) - Y = nr.normal(loc=2+x[:,2]+2*x[:,3]-2*x[:,6]-1.5*x[:,7], + Y = nr.normal(loc=X[:,2]+2*X[:,3]-2*X[:,6]-1.5*X[:,7], scale=np.ones(nb_samples)) return X,Y \ No newline at end of file