diff --git a/scikits/learn/bayes/bayes.py b/scikits/learn/bayes/bayes.py
index 01e21fe86b590af5c523c891ff6f075317b223c8..d75492b413584f4add3c27d6cbdc06679c5b4ff8 100755
--- a/scikits/learn/bayes/bayes.py
+++ b/scikits/learn/bayes/bayes.py
@@ -25,7 +25,7 @@ def fast_logdet(A):
 
 
 
-def bayesian_ridge( X , Y, step_th=300,th_w = 1.e-6,ll_bool=True) :
+def bayesian_ridge(X, Y, step_th=300, th_w=1.e-12, ll_bool=False):
     """
     Bayesian ridge regression. Optimize the regularization parameter alpha
     within a simple bayesian framework (MAP).
@@ -39,15 +39,17 @@ def bayesian_ridge( X , Y, step_th=300,th_w = 1.e-6,ll_bool=True) :
 	target
     step_th : int (defaut is 300)
 	      Stop the algorithm after a given number of steps.
-    th_w : float (defaut is 1.e-6)
+    th_w : float (default is 1.e-12)
 	   Stop the algorithm if w has converged.
-    ll_bool  : boolean (default is True).
+    ll_bool : boolean (default is False).
 	       If True, compute the log-likelihood at each step of the model.
 
     Returns
     -------
     w : numpy array of shape (dim)
          mean of the weights distribution.
+    log_likelihood : list of floats
+		     Log-likelihood of the model at each step; computed only if
+		     ll_bool is True, otherwise None is returned.
    
     Examples
     --------
@@ -64,7 +66,9 @@ def bayesian_ridge( X , Y, step_th=300,th_w = 1.e-6,ll_bool=True) :
     beta = 1./np.var(Y)
     alpha = 1.0
 
-    log_likelihood = []
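+    # Keep a log-likelihood trace only when it is requested; otherwise return None.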
+    log_likelihood = None
+    if ll_bool:
+        log_likelihood = []
     has_converged = False
     gram = np.dot(X.T, X)
     ones = np.eye(gram.shape[1])
@@ -89,7 +93,6 @@ def bayesian_ridge( X , Y, step_th=300,th_w = 1.e-6,ll_bool=True) :
         step_th -= 1
 
 
-
 	# convergence : compare w
 	has_converged =  (np.sum(np.abs(w-old_w))<th_w)
         old_w = w
@@ -104,11 +107,11 @@ def bayesian_ridge( X , Y, step_th=300,th_w = 1.e-6,ll_bool=True) :
 	  ll -= X.shape[0]*np.log(2*np.pi)
 	  log_likelihood.append(ll)
 
-    return w,log_likelihood[1:]
+    return w,log_likelihood
 
 
 
-class BayessianRegression(object):
+class BayesianRegression(object):
     """
     Encapsulate various bayesian regression algorithms
     """
@@ -119,9 +122,9 @@ class BayessianRegression(object):
     def fit(self, X, Y):
         X = np.asanyarray(X, dtype=np.float)
         Y = np.asanyarray(Y, dtype=np.float)
-        self.w  = bayesian_ridge(X, Y)
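+        # bayesian_ridge returns (weights, log-likelihood trace); keep both on the estimator.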
+        self.w, self.log_likelihood = bayesian_ridge(X, Y)
 
     def predict(self, T):
-        T = np.asanyarray(T)
-        # I think this is wrong
         return np.dot(T, self.w)
diff --git a/scikits/learn/bayes/tests/test_bayes.py b/scikits/learn/bayes/tests/test_bayes.py
index 64879acde3f175b6656d3c99fd24979511e37671..5c1eaa7511e95839ad2301c19d65ac629af69339 100644
--- a/scikits/learn/bayes/tests/test_bayes.py
+++ b/scikits/learn/bayes/tests/test_bayes.py
@@ -1,19 +1,33 @@
 import numpy as np
-from scikits.learn.bayes.bayes import bayesian_ridge, BayessianRegression
+from scikits.learn.bayes.bayes import bayesian_ridge, BayesianRegression
 from numpy.testing import assert_array_almost_equal
-
-X = np.array([[1], [2]])
-Y = np.array([1, 2])
+from scikits.learn.datasets.samples_generator import linear
 
 def test_toy():
-    w = bayesian_ridge(X, Y)
+    X = np.array([[1], [2]])
+    Y = np.array([1, 2])
+    w, log_likelihood = bayesian_ridge(X, Y)
     assert_array_almost_equal(w, [1])
 
 def test_toy_object():
     """
-    Test BayessianRegression classifier
+    Test BayesianRegression classifier
     """
-    clf = BayessianRegression()
+    X = np.array([[1], [2]])
+    Y = np.array([1, 2])
+    clf = BayesianRegression()
     clf.fit(X, Y)
     Test = [[1], [2], [3], [4]]
     assert_array_almost_equal(clf.predict(Test), [1, 2, 3, 4]) # identity
+
+def test_simu_object():
+    """
+    Test BayesianRegression classifier with simulated linear data
+    """
+    X, Y = linear.sparse_uncorrelated(nb_samples=100, nb_features=10)
+    clf = BayesianRegression()
+    clf.fit(X, Y)
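+    # Evaluate on an independent draw from the same generator.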
+    Xtest, Ytest = linear.sparse_uncorrelated(nb_samples=100, nb_features=10)
+    mse = np.mean((clf.predict(Xtest) - Ytest)**2)
+    assert mse < 2.
+
diff --git a/scikits/learn/datasets/samples_generator/linear.py b/scikits/learn/datasets/samples_generator/linear.py
index 72fc93c5369c97451965947bd405c4ead6dfc5e6..cf649ea5a06a19fd4d5c5ab0865eff8214a4ddda 100755
--- a/scikits/learn/datasets/samples_generator/linear.py
+++ b/scikits/learn/datasets/samples_generator/linear.py
@@ -1,12 +1,12 @@
 import numpy as np
-
+import numpy.random as nr
 
 def sparse_uncorrelated(nb_samples=100,nb_features=10):
     """
     Function creating simulated data with sparse uncorrelated design.
     (cf.Celeux et al. 2009,  Bayesian regularization in regression)
     X = NR.normal(0,1)
-    Y = NR.normal(2+X[:,2]+2*X[:,3]-2*X[:,6]-1.5*X[:,7])
+    Y = NR.normal(X[:,2]+2*X[:,3]-2*X[:,6]-1.5*X[:,7])
     The number of features is at least 10.
 
     Parameters
@@ -23,6 +23,6 @@ def sparse_uncorrelated(nb_samples=100,nb_features=10):
     Y : numpy array of shape (nb_samples)
     """
     X = nr.normal(loc=0,scale=1,size=(nb_samples,nb_features))
-    Y = nr.normal(loc=2+x[:,2]+2*x[:,3]-2*x[:,6]-1.5*x[:,7],
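+    # Only features 2, 3, 6 and 7 carry signal; the remaining columns are pure noise.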
+    Y = nr.normal(loc=X[:,2]+2*X[:,3]-2*X[:,6]-1.5*X[:,7],
     scale=np.ones(nb_samples))
     return X,Y
\ No newline at end of file
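
For reference, a minimal usage sketch of the renamed estimator together with the fixed data generator, mirroring the new test_simu_object test above; only names that appear in this diff are used, and the sketch assumes the modules import as in the tests:

    import numpy as np
    from scikits.learn.bayes.bayes import BayesianRegression
    from scikits.learn.datasets.samples_generator import linear

    # Simulated sparse linear data: only 4 of the 10 features carry signal.
    X, Y = linear.sparse_uncorrelated(nb_samples=100, nb_features=10)

    clf = BayesianRegression()
    clf.fit(X, Y)   # stores clf.w and clf.log_likelihood (None unless ll_bool is set)

    # Prediction error on an independent draw from the same generator.
    Xtest, Ytest = linear.sparse_uncorrelated(nb_samples=100, nb_features=10)
    mse = np.mean((clf.predict(Xtest) - Ytest) ** 2)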