diff --git a/examples/gpml/plot_gpml_regression.py b/examples/gpml/plot_gpml_regression.py
index cf24567fe58ae6675742520544518b8485cbc81c..3d2fc342cf4d487acbd9541ccce559111a819d85 100644
--- a/examples/gpml/plot_gpml_regression.py
+++ b/examples/gpml/plot_gpml_regression.py
@@ -48,11 +48,11 @@ for k in range(len(corrs)):
     y, MSE = aGaussianProcessModel.predict(x, eval_MSE=True)
     sigma = np.sqrt(MSE)
     
-    # Compute the score function on a grid of the autocorrelation parameter space
+    # Compute the reduced likelihood function on a grid of the autocorrelation parameter space
     theta_values = np.logspace(np.log10(aGaussianProcessModel.thetaL[0,0]), np.log10(aGaussianProcessModel.thetaU[0,0]), 100)
-    score_values = []
+    reduced_likelihood_function_values = []
     for t in theta_values:
-        score_values.append(aGaussianProcessModel.score(theta=t)[0])
+        reduced_likelihood_function_values.append(aGaussianProcessModel.reduced_likelihood_function(theta=t)[0])
     
     fig = pl.figure()
     
@@ -67,9 +67,9 @@ for k in range(len(corrs)):
     pl.ylim(-10, 20)
     pl.legend(loc='upper left')
     
-    # Plot the score function
+    # Plot the reduced likelihood function
     ax = fig.add_subplot(212)
-    pl.plot(theta_values, score_values, colors[k]+'-')
+    pl.plot(theta_values, reduced_likelihood_function_values, colors[k]+'-')
     pl.xlabel(u'$\\theta$')
-    pl.ylabel(u'Score')
+    pl.ylabel(u'Reduced likelihood')
     pl.xscale('log')
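
For reference, a minimal sketch (not part of the patch) of the renamed API in use once applied. It assumes thetaL and thetaU are also accepted as constructor arguments, as the attribute access above suggests:

    import numpy as np
    from scikits.learn import gpml

    # Fit a 1D model on y = x*sin(x), as in the x_sinx regression test
    X = np.atleast_2d(np.linspace(0.1, 9.9, 20)).T
    y = np.ravel(X * np.sin(X))
    model = gpml.GaussianProcessModel(corr=gpml.correxp2, theta0=1e-1,
                                      thetaL=1e-3, thetaU=1e0).fit(X, y)

    # Scan the reduced likelihood function (formerly score) over a theta grid
    theta_values = np.logspace(-3, 0, 100)
    rlf_values = [model.reduced_likelihood_function(theta=t)[0]
                  for t in theta_values]
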
diff --git a/scikits/learn/gpml/gaussian_process_model.py b/scikits/learn/gpml/gaussian_process_model.py
index 368dbbb40d0d8fdc9a3d3fec084eba02c1b6fe69..66ad5610ebd23882b147e94053a00e7810bf43c2 100644
--- a/scikits/learn/gpml/gaussian_process_model.py
+++ b/scikits/learn/gpml/gaussian_process_model.py
@@ -70,10 +70,10 @@ class GaussianProcessModel(BaseEstimator):
 
     pl.figure(2)
     theta_values = np.logspace(np.log10(aGaussianProcessModel.thetaL),np.log10(aGaussianProcessModel.thetaU),100)
-    score_values = []
+    reduced_likelihood_function_values = []
     for t in theta_values:
-            score_values.append(aGaussianProcessModel.score(theta=t)[0])
-    pl.plot(theta_values, score_values)
+            reduced_likelihood_function_values.append(aGaussianProcessModel.reduced_likelihood_function(theta=t)[0])
+    pl.plot(theta_values, reduced_likelihood_function_values)
     pl.xlabel(u'$\\theta$')
-    pl.ylabel(u'Score')
+    pl.ylabel(u'Reduced likelihood')
     pl.xscale('log')
@@ -317,16 +317,16 @@ class GaussianProcessModel(BaseEstimator):
             # Maximum Likelihood Estimation of the parameters
             if self.verbose:
                 print "Performing Maximum Likelihood Estimation of the autocorrelation parameters..."
-            self.theta, self.score_value, self.par  = self.__arg_max_score__()
-            if np.isinf(self.score_value) :
+            self.theta, self.reduced_likelihood_function_value, self.par = self.arg_max_reduced_likelihood_function()
+            if np.isinf(self.reduced_likelihood_function_value):
                 raise Exception, "Bad parameter region. Try increasing upper bound"
         else:
             # Given parameters
             if self.verbose:
                 print "Given autocorrelation parameters. Computing Gaussian Process Model parameters..."
             self.theta = self.theta0
-            self.score_value, self.par = self.score()
-            if np.isinf(self.score_value):
+            self.reduced_likelihood_function_value, self.par = self.reduced_likelihood_function()
+            if np.isinf(self.reduced_likelihood_function_value):
                 raise Exception, "Bad point. Try increasing theta0"
         
         if self.storage_mode == 'light':
@@ -424,7 +424,7 @@ class GaussianProcessModel(BaseEstimator):
                     # Light storage mode (need to recompute C, F, Ft and G)
                     if self.verbose:
                         print "This GaussianProcessModel used light storage mode at instanciation. Need to recompute autocorrelation matrix..."
-                    score_value, par = self.score()
+                    reduced_likelihood_function_value, par = self.reduced_likelihood_function()
                 
                 rt = linalg.solve(np.matrix(par['C']), np.matrix(r))
                 if self.beta0 is None:
@@ -471,10 +471,10 @@ class GaussianProcessModel(BaseEstimator):
     
     
     
-    def score(self, theta = None):
+    def reduced_likelihood_function(self, theta = None):
         """
-        This function determines the BLUP parameters and evaluates the score function for the given autocorrelation parameters theta.
-        Maximizing this score function wrt the autocorrelation parameters theta is equivalent to maximizing the likelihood of the
-        assumed joint Gaussian distribution of the observations y evaluated onto the design of experiments X.
+        This function determines the BLUP parameters and evaluates the reduced likelihood function for the given autocorrelation parameters theta.
+        Maximizing this function with respect to the autocorrelation parameters theta is equivalent to maximizing the likelihood of the
+        assumed joint Gaussian distribution of the observations y evaluated on the design of experiments X.
         
         Parameters
@@ -494,15 +494,15 @@ class GaussianProcessModel(BaseEstimator):
             par['Ft'] : Solution of the linear equation system : [R] x Ft = F
             par['G'] : QR decomposition of the matrix Ft.
-            par['detR'] : Determinant of the correlation matrix raised at power 1/n_samples.
+            par['detR'] : Determinant of the correlation matrix raised to the power 1/n_samples.
-        score_value : double
-            The value of the score function associated to the given autocorrelation parameters theta.
+        reduced_likelihood_function_value : double
+            The value of the reduced likelihood function associated with the given autocorrelation parameters theta.
         """
         
         if theta is None:
             # Use built-in autocorrelation parameters
             theta = self.theta
         
-        score_value = -np.inf
+        reduced_likelihood_function_value = -np.inf
         par = { }
         
         # Retrieve data
@@ -539,7 +539,7 @@ class GaussianProcessModel(BaseEstimator):
         try:
             C = linalg.cholesky(R, lower=True)
         except linalg.LinAlgError:
-            return score_value, par
+            return reduced_likelihood_function_value, par
         
         # Get generalized least squares solution
         Ft = linalg.solve(C, F)
@@ -558,7 +558,7 @@ class GaussianProcessModel(BaseEstimator):
                 raise Exception, "F is too ill conditioned. Poor combination of regression model and observations."
             else:
                 # Ft is too ill conditioned
-                return score_value, par
+                return reduced_likelihood_function_value, par
         
         Yt = linalg.solve(C,self.y_)
         if self.beta0 is None:
@@ -572,7 +572,7 @@ class GaussianProcessModel(BaseEstimator):
         normalized_sigma2 = (np.array(rho)**2.).sum(axis=0)/n_samples
         # The determinant of R is equal to the squared product of the diagonal elements of its Cholesky decomposition C
         detR = (np.array(np.diag(C))**(2./n_samples)).prod()
-        score_value = - normalized_sigma2.sum() * detR
+        reduced_likelihood_function_value = - normalized_sigma2.sum() * detR
         par = { 'sigma2':normalized_sigma2 * self.y_sc[1]**2, \
                 'beta':beta, \
                 'gamma':linalg.solve(C.T,rho).T, \
@@ -580,23 +580,23 @@ class GaussianProcessModel(BaseEstimator):
                 'Ft':Ft, \
                 'G':G }
         
-        return score_value, par
+        return reduced_likelihood_function_value, par
 
-    def __arg_max_score__(self):
+    def arg_max_reduced_likelihood_function(self):
         """
-        This function estimates the autocorrelation parameters theta as the maximizer of the score function.
-        (Minimization of the opposite score function is used for convenience)
+        This function estimates the autocorrelation parameters theta as the maximizer of the reduced likelihood function.
+        (Minimization of the negative reduced likelihood function is used for convenience.)
         
         Parameters
         ----------
-        self: All parameters are stored in the Gaussian Process Model object.
+        self : All parameters are stored in the Gaussian Process Model object.
         
         Returns
         -------
         optimal_theta : array_like
-            The best set of autocorrelation parameters (the maximizer of the score function).
-        optimal_score_value : double
-            The optimal score function value.
+            The best set of autocorrelation parameters (the maximizer of the reduced likelihood function).
+        optimal_reduced_likelihood_function_value : double
+            The optimal reduced likelihood function value.
         optimal_par : dict
-            The BLUP parameters associated to thetaOpt.
+            The BLUP parameters associated with optimal_theta.
         """
@@ -610,7 +610,7 @@ class GaussianProcessModel(BaseEstimator):
         
         if self.optimizer == 'fmin_cobyla':
             
-            minus_score = lambda log10t: - self.score(theta = 10.**log10t)[0]
+            minus_reduced_likelihood_function = lambda log10t: - self.reduced_likelihood_function(theta = 10.**log10t)[0]
             
             constraints = []
             for i in range(self.theta0.size):
@@ -627,18 +627,18 @@ class GaussianProcessModel(BaseEstimator):
                     log10theta0 = np.log10(self.thetaL) + random.rand(self.theta0.size).reshape(self.theta0.shape) * np.log10(self.thetaU/self.thetaL)
                     theta0 = 10.**log10theta0
                 
-                log10_optimal_theta = optimize.fmin_cobyla(minus_score, np.log10(theta0), constraints, iprint=0)
+                log10_optimal_theta = optimize.fmin_cobyla(minus_reduced_likelihood_function, np.log10(theta0), constraints, iprint=0)
                 optimal_theta = 10.**log10_optimal_theta
-                optimal_minus_score_value, optimal_par = self.score(theta = optimal_theta)
-                optimal_score_value = - optimal_minus_score_value
+                optimal_minus_reduced_likelihood_function_value, optimal_par = self.reduced_likelihood_function(theta = optimal_theta)
+                optimal_reduced_likelihood_function_value = - optimal_minus_reduced_likelihood_function_value
                 
                 if k > 0:
-                    if optimal_score_value > best_optimal_score_value:
-                        best_optimal_score_value = optimal_score_value
+                    if optimal_reduced_likelihood_function_value > best_optimal_reduced_likelihood_function_value:
+                        best_optimal_reduced_likelihood_function_value = optimal_reduced_likelihood_function_value
                         best_optimal_par = optimal_par
                         best_optimal_theta = optimal_theta
                 else:
-                    best_optimal_score_value = optimal_score_value
+                    best_optimal_reduced_likelihood_function_value = optimal_reduced_likelihood_function_value
                     best_optimal_par = optimal_par
                     best_optimal_theta = optimal_theta
                 if self.verbose and self.random_start > 1:
@@ -650,4 +650,24 @@ class GaussianProcessModel(BaseEstimator):
             
             raise NotImplementedError, "This optimizer ('%s') is not implemented yet. Please contribute!" % self.optimizer
     
-        return best_optimal_theta, best_optimal_score_value, best_optimal_par
\ No newline at end of file
+        return best_optimal_theta, best_optimal_reduced_likelihood_function_value, best_optimal_par
+    
+    def score(self, X_test, y_test):
+        """
+        This score function returns the deviations between the Gaussian Process Model predictions and the targets of a test dataset.
+        
+        Parameters
+        ----------
+        X_test : array_like
+            The feature test dataset with shape (n_tests, n_features).
+            
+        y_test : array_like
+            The target test dataset with shape (n_tests, ).
+        
+        Returns
+        -------
+        score_values : array_like
+            The deviations between the predictions and the targets: y_pred - y_test.
+        """
+        
+        return np.ravel(self.predict(X_test, eval_MSE=False)) - y_test
\ No newline at end of file
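
With score freed up for the estimator interface, hold-out evaluation of a fitted model becomes a one-liner. A hypothetical sketch, where model, X_test and y_test are placeholder names rather than objects defined in the patch:

    import numpy as np

    # score() returns the signed deviations y_pred - y_test with shape (n_tests,),
    # so scalar error metrics follow directly from it
    deviations = model.score(X_test, y_test)
    rmse = np.sqrt(np.mean(deviations ** 2.))
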
diff --git a/scikits/learn/tests/test_gpml.py b/scikits/learn/tests/test_gpml.py
index 7327f81f521c5edf71c54ba96492b44c0b66d504..a7963e65e8601cb7d2cc80912840043b5dd31e7e 100644
--- a/scikits/learn/tests/test_gpml.py
+++ b/scikits/learn/tests/test_gpml.py
@@ -26,7 +26,7 @@ def test_regression_1d_x_sinx():
     
     assert (np.all(np.abs((y_pred - y) / y)  < 1e-6) and np.all(MSE  < 1e-6))
 
-def test_regression_diabetes():
+def test_regression_diabetes(n_jobs = 1, verbose = 0):
     """
     MLE estimation of a Gaussian Process model with an anisotropic squared exponential
     correlation model.
@@ -43,8 +43,8 @@
     gpm = gpml.GaussianProcessModel(corr=gpml.correxp2, theta0=1e-4, nugget=1e-2, verbose=False).fit(X, y)
 
     gpm.thetaL, gpm.thetaU = None, None
-    score_func = lambda self, X_test, y_test: self.predict(X_test)[0]
-    y_pred = cross_val.cross_val_score(gpm, X, y=y, score_func=score_func, cv=cross_val.LeaveOneOut(y.size), n_jobs=1, verbose=0)
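+    # gpm.score now returns the deviations y_pred - y_test, so adding y recovers the predictions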
+    y_pred = cross_val.cross_val_score(gpm, X, y=y, cv=cross_val.LeaveOneOut(y.size), n_jobs=n_jobs, verbose=verbose) + y
     Q2 = metrics.explained_variance(y_pred, y)
     
-    assert Q2 > 0.
+    assert Q2 > 0.45