diff --git a/examples/gaussian_process/plot_gp_diabetes_dataset.py b/examples/gaussian_process/plot_gp_diabetes_dataset.py
index dafc0867f258fb02a3efda1b2d74d299197ecffc..219915c6c886cf0b8145a8a7a89a361a2c811e8a 100644
--- a/examples/gaussian_process/plot_gp_diabetes_dataset.py
+++ b/examples/gaussian_process/plot_gp_diabetes_dataset.py
@@ -2,69 +2,52 @@
 # -*- coding: utf-8 -*-
 
 """
-=========================================================================
+========================================================================
 Gaussian Processes regression: goodness-of-fit on the 'diabetes' dataset
-=========================================================================
+========================================================================
 
 This example consists in fitting a Gaussian Process model onto the diabetes
 dataset.
-WARNING: This is quite time consuming (about 2 minutes overall runtime).
 
-The corelation parameters are given in order to maximize the generalization
-capacity of the model. We assumed an anisotropic squared exponential
-correlation model with a constant regression model. We also used a
-nugget = 1e-2 in order to account for the (strong) noise in the targets.
+The correlation parameters are determined by means of maximum likelihood
+estimation (MLE). An anisotropic absolute exponential correlation model and a
+constant regression model are assumed. A nugget of 1e-2 is used to account
+for the (strong) noise in the targets.
 
-The figure is a goodness-of-fit plot obtained using leave-one-out predictions
-of the Gaussian Process model. Based on these predictions, we compute an
-explained variance error (Q2).
+We then compute a cross-validation estimate of the coefficient of
+determination (R2) without re-performing MLE, using the set of correlation
+parameters found on the whole dataset.
 """
+print __doc__
 
 # Author: Vincent Dubourg <vincent.dubourg@gmail.com>
 # License: BSD style
 
-from scikits.learn import datasets, cross_val, metrics
+import numpy as np
+from scikits.learn import datasets
 from scikits.learn.gaussian_process import GaussianProcess
-from matplotlib import pyplot as pl
-
-# Print the docstring
-print __doc__
+from scikits.learn.cross_val import cross_val_score, KFold
+from scikits.learn.metrics import r2_score
 
 # Load the dataset from scikits' data sets
 diabetes = datasets.load_diabetes()
-X, y = diabetes['data'], diabetes['target']
+X, y = diabetes.data, diabetes.target
 
 # Instanciate a GP model
 gp = GaussianProcess(regr='constant', corr='absolute_exponential',
                      theta0=[1e-4] * 10, thetaL=[1e-12] * 10,
-                     thetaU=[1e-2] * 10, nugget=1e-2, optimizer='Welch',
-                     verbose=False)
+                     thetaU=[1e-2] * 10, nugget=1e-2, optimizer='Welch')
 
-# Fit the GP model to the data
+# Fit the GP model to the data, performing maximum likelihood estimation
 gp.fit(X, y)
-gp.theta0 = gp.theta
-gp.thetaL = None
-gp.thetaU = None
-gp.verbose = False
-
-# Estimate the leave-one-out predictions using the cross_val module
-n_jobs = 2 # the distributing capacity available on the machine
-y_pred = y + cross_val.cross_val_score(gp, X, y=y,
-                                   cv=cross_val.LeaveOneOut(y.size),
-                                   n_jobs=n_jobs,
-                                ).ravel()
 
-# Compute the empirical explained variance
-Q2 = metrics.explained_variance_score(y, y_pred)
+# Deactivate maximum likelihood estimation for the cross-validation loop
+gp.theta0 = gp.theta  # given correlation parameters = MLE found above
+gp.thetaL, gp.thetaU = None, None  # None bounds deactivate MLE
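+# With the bounds set to None, every subsequent fit (e.g. the ones performed
+# inside cross_val_score) reuses theta0 as-is instead of re-running the
+# likelihood maximization.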
 
-# Goodness-of-fit plot
-pl.figure()
-pl.title('Goodness-of-fit plot (Q2 = %1.2e)' % Q2)
-pl.plot(y, y_pred, 'r.', label='Leave-one-out')
-pl.plot(y, gp.predict(X), 'k.', label='Whole dataset (nugget=1e-2)')
-pl.plot([y.min(), y.max()], [y.min(), y.max()], 'k--')
-pl.xlabel('Observations')
-pl.ylabel('Predictions')
-pl.legend(loc='upper left')
-pl.axis('tight')
-pl.show()
+# Perform a cross-validation estimate of the coefficient of determination
+# with the cross_val module, using all the CPUs available on the machine
+K = 20  # number of folds
+R2 = cross_val_score(gp, X, y=y, cv=KFold(y.size, K), n_jobs=-1).mean()
+print("The %d-Folds estimate of the coefficient of determination is R2 = %s"
+    % (K, R2))
diff --git a/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py b/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py
index 5c1aa06bf001e88c5e7e123143e4b3aae817414d..329f2cf9322464ca16d37451b1d376e684ecf3c2 100644
--- a/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py
+++ b/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py
@@ -14,6 +14,7 @@ respect to the remaining uncertainty in the prediction. The red and blue lines
 corresponds to the 95% confidence interval on the prediction of the zero level
 set.
 """
+print __doc__
 
 # Author: Vincent Dubourg <vincent.dubourg@gmail.com>
 # License: BSD style
@@ -24,22 +25,19 @@ from scikits.learn.gaussian_process import GaussianProcess
 from matplotlib import pyplot as pl
 from matplotlib import cm
 
-# Print the docstring
-print __doc__
-
 # Standard normal distribution functions
-Grv = stats.distributions.norm()
-phi = lambda x: Grv.pdf(x)
-PHI = lambda x: Grv.cdf(x)
-PHIinv = lambda x: Grv.ppf(x)
+phi = stats.distributions.norm().pdf
+PHI = stats.distributions.norm().cdf
+PHIinv = stats.distributions.norm().ppf
 
 # A few constants
 lim = 8
-b, kappa, e = 5, .5, .1
 
-# The function to predict (classification will then consist in predicting
-# wheter g(x) <= 0 or not)
-g = lambda x: b - x[:, 1] - kappa * (x[:, 0] - e) ** 2.
+
+def g(x):
+    """The function to predict (classification will then consist in predicting
+    whether g(x) <= 0 or not)"""
+    return 5. - x[:, 1] - .5 * x[:, 0] ** 2.
 
 # Design of experiments
 X = np.array([[-4.61611719, -6.00099547],
diff --git a/examples/gaussian_process/plot_gp_regression.py b/examples/gaussian_process/plot_gp_regression.py
index 90bf120ca998fba18dc5fc5def2d9388453940b6..66c27fe493080be1b48f0e714ebc534fbb6e7019 100644
--- a/examples/gaussian_process/plot_gp_regression.py
+++ b/examples/gaussian_process/plot_gp_regression.py
@@ -2,9 +2,9 @@
 # -*- coding: utf-8 -*-
 
 """
-=================================================================
+=========================================================
 Gaussian Processes regression: basic introductory example
-=================================================================
+=========================================================
 
 A simple one-dimensional regression exercise with a cubic correlation
 model whose parameters are estimated using the maximum likelihood principle.
@@ -13,6 +13,7 @@ The figure illustrates the interpolating property of the Gaussian Process
 model as well as its probabilistic nature in the form of a pointwise 95%
 confidence interval.
 """
+print __doc__
 
 # Author: Vincent Dubourg <vincent.dubourg@gmail.com>
 # License: BSD style
@@ -21,11 +22,10 @@ import numpy as np
 from scikits.learn.gaussian_process import GaussianProcess
 from matplotlib import pyplot as pl
 
-# Print the docstring
-print __doc__
 
-# The function to predict
-f = lambda x: x * np.sin(x)
+def f(x):
+    """The function to predict."""
+    return x * np.sin(x)
 
 # The design of experiments
 X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T
diff --git a/scikits/learn/gaussian_process/correlation_models.py b/scikits/learn/gaussian_process/correlation_models.py
index 6ad2807b21bde73038673c24ffd87506e4dd05ed..52fc008935cc90230ef69a0a1177c6fe5792880f 100644
--- a/scikits/learn/gaussian_process/correlation_models.py
+++ b/scikits/learn/gaussian_process/correlation_models.py
@@ -39,23 +39,20 @@ def absolute_exponential(theta, d):
         An array with shape (n_eval, ) containing the values of the
         autocorrelation model.
     """
-
     theta = np.asanyarray(theta, dtype=np.float)
-    d = np.asanyarray(d, dtype=np.float)
+    d = np.abs(np.asanyarray(d, dtype=np.float))
 
     if d.ndim > 1:
         n_features = d.shape[1]
     else:
         n_features = 1
+
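+    # A scalar theta is shared by all features (isotropic model); otherwise
+    # one correlation parameter per feature is expected (anisotropic model).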
     if theta.size == 1:
-        theta = np.repeat(theta, n_features)
+        return np.exp(- theta[0] * np.sum(d, axis=1))
     elif theta.size != n_features:
-        raise ValueError("Length of theta must be 1 or " + str(n_features))
-
-    td = - theta.reshape(1, n_features) * abs(d)
-    r = np.exp(np.sum(td, 1))
-
-    return r
+        raise ValueError("Length of theta must be 1 or %s" % n_features)
+    else:
+        return np.exp(- np.sum(theta.reshape(1, n_features) * d, axis=1))
 
 
 def squared_exponential(theta, d):
@@ -92,15 +89,13 @@ def squared_exponential(theta, d):
         n_features = d.shape[1]
     else:
         n_features = 1
+
     if theta.size == 1:
-        theta = np.repeat(theta, n_features)
+        return np.exp(- theta[0] * np.sum(d**2, axis=1))
     elif theta.size != n_features:
-        raise Exception("Length of theta must be 1 or " + str(n_features))
-
-    td = - theta.reshape(1, n_features) * d ** 2
-    r = np.exp(np.sum(td, 1))
-
-    return r
+        raise ValueError("Length of theta must be 1 or %s" % n_features)
+    else:
+        return np.exp(- np.sum(theta.reshape(1, n_features) * d**2, axis=1))
 
 
 def generalized_exponential(theta, d):
@@ -138,16 +133,17 @@ def generalized_exponential(theta, d):
         n_features = d.shape[1]
     else:
         n_features = 1
+
     lth = theta.size
     if n_features > 1 and lth == 2:
         theta = np.hstack([np.repeat(theta[0], n_features), theta[1]])
     elif lth != n_features + 1:
-        raise Exception("Length of theta must be 2 or " + str(n_features + 1))
+        raise Exception("Length of theta must be 2 or %s" % (n_features + 1))
     else:
         theta = theta.reshape(1, lth)
 
-    td = - theta[:, 0:-1].reshape(1, n_features) * abs(d) ** theta[:, -1]
-    r = np.exp(np.sum(td, 1))
+    td = theta[:, 0:-1].reshape(1, n_features) * np.abs(d) ** theta[:, -1]
+    r = np.exp(- np.sum(td, 1))
 
     return r
 
@@ -184,9 +180,7 @@ def pure_nugget(theta, d):
 
     n_eval = d.shape[0]
     r = np.zeros(n_eval)
-    # The ones on the diagonal of the correlation matrix are enforced within
-    # the KrigingModel instanciation to allow multiple design sites in this
-    # ordinary least squares context.
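+    # The correlation is one only where all componentwise distances vanish
+    # (i.e. for coincident points) and zero everywhere else.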
+    r[np.all(d == 0., axis=1)] = 1.
 
     return r
 
@@ -225,15 +219,15 @@ def cubic(theta, d):
         n_features = d.shape[1]
     else:
         n_features = 1
+
     lth = theta.size
     if  lth == 1:
-        theta = np.repeat(theta, n_features)[np.newaxis][:]
+        td = np.abs(d) * theta
     elif lth != n_features:
         raise Exception("Length of theta must be 1 or " + str(n_features))
     else:
-        theta = theta.reshape(1, n_features)
+        td = np.abs(d) * theta.reshape(1, n_features)
 
-    td = abs(d) * theta
     td[td > 1.] = 1.
     ss = 1. - td ** 2. * (3. - 2. * td)
     r = np.prod(ss, 1)
@@ -275,15 +269,15 @@ def linear(theta, d):
         n_features = d.shape[1]
     else:
         n_features = 1
+
     lth = theta.size
-    if  lth == 1:
-        theta = np.repeat(theta, n_features)[np.newaxis][:]
+    if lth == 1:
+        td = np.abs(d) * theta
     elif lth != n_features:
-        raise Exception("Length of theta must be 1 or " + str(n_features))
+        raise Exception("Length of theta must be 1 or %s" % n_features)
     else:
-        theta = theta.reshape(1, n_features)
+        td = np.abs(d) * theta.reshape(1, n_features)
 
-    td = abs(d) * theta
     td[td > 1.] = 1.
     ss = 1. - td
     r = np.prod(ss, 1)
diff --git a/scikits/learn/gaussian_process/gaussian_process.py b/scikits/learn/gaussian_process/gaussian_process.py
index 0bce6699a6e97c61f136a64c8d1dc92ab1e1f232..3310fef044954070128a9c871d330fab060544ac 100644
--- a/scikits/learn/gaussian_process/gaussian_process.py
+++ b/scikits/learn/gaussian_process/gaussian_process.py
@@ -7,9 +7,11 @@
 
 import numpy as np
 from scipy import linalg, optimize, rand
-from ..base import BaseEstimator
+from ..base import BaseEstimator, RegressorMixin
 from . import regression_models as regression
 from . import correlation_models as correlation
+from ..cross_val import LeaveOneOut
+from ..externals.joblib import Parallel, delayed
 MACHINE_EPSILON = np.finfo(np.double).eps
 if hasattr(linalg, 'solve_triangular'):
     # only in scipy since 0.9
@@ -20,7 +22,79 @@ else:
         return linalg.solve(x, y)
 
 
-class GaussianProcess(BaseEstimator):
+def compute_componentwise_l1_cross_distances(X):
+    """
+    Computes the nonzero componentwise L1 cross-distances between the vectors
+    in X.
+
+    Parameters
+    ----------
+
+    X: array_like
+        An array with shape (n_samples, n_features)
+
+    Returns
+    -------
+
+    D: array with shape (n_samples * (n_samples - 1) / 2, n_features)
+        The array of componentwise L1 cross-distances.
+
+    ij: array with shape (n_samples * (n_samples - 1) / 2, 2)
+        The indices i and j of the vectors in X associated to the cross-
+        distances in D: D[k] = np.abs(X[ij[k, 0]] - X[ij[k, 1]]).
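+
+    Examples
+    --------
+    A small illustrative sketch (not part of the library's test suite)::
+
+        D, ij = compute_componentwise_l1_cross_distances([[0.], [1.], [3.]])
+        # D  -> [[1.], [3.], [2.]]
+        # ij -> [[0, 1], [0, 2], [1, 2]]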
+    """
+    X = np.atleast_2d(X)
+    n_samples, n_features = X.shape
+    n_nonzero_cross_dist = n_samples * (n_samples - 1) / 2
+    ij = np.zeros([n_nonzero_cross_dist, 2])
+    D = np.zeros([n_nonzero_cross_dist, n_features])
+    ll = np.array([-1])
+    for k in range(n_samples - 1):
+        ll = ll[-1] + 1 + range(n_samples - k - 1)
+        ij[ll] = np.concatenate([[np.repeat(k, n_samples - k - 1, 0)],
+                                 [np.array(range(k + 1, n_samples)).T]]).T
+        D[ll] = np.abs(X[k] - X[(k + 1):n_samples])
+
+    return D, ij.astype(np.int)
+
+
+def compute_componentwise_l1_pairwise_distances(X, Y):
+    """
+    Computes the componentwise L1 pairwise-distances between the vectors
+    in X and Y.
+
+    Parameters
+    ----------
+
+    X: array_like
+        An array with shape (n_samples_X, n_features)
+
+    Y: array_like, optional
+        An array with shape (n_samples_Y, n_features).
+
+    Returns
+    -------
+
+    D: array with shape (n_samples_X * n_samples_Y, n_features)
+        The array of componentwise L1 pairwise-distances.
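+
+    Examples
+    --------
+    A small illustrative sketch (not part of the library's test suite). Note
+    that the componentwise differences are returned signed; the correlation
+    models take their absolute value where needed::
+
+        X, Y = [[0.], [2.]], [[1.], [3.]]
+        D = compute_componentwise_l1_pairwise_distances(X, Y)
+        # D -> [[-1.], [-3.], [1.], [-1.]]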
+    """
+    X, Y = np.atleast_2d(X), np.atleast_2d(Y)
+    n_samples_X, n_features_X = X.shape
+    n_samples_Y, n_features_Y = Y.shape
+    if n_features_X != n_features_Y:
+        raise ValueError("X and Y should have the same number of features!")
+    else:
+        n_features = n_features_X
+    D = np.zeros([n_samples_X * n_samples_Y, n_features])
+    kk = np.arange(n_samples_Y).astype(np.int)
+    for k in range(n_samples_X):
+        D[kk] = X[k] - Y
+        kk = kk + n_samples_Y
+
+    return D
+
+
+class GaussianProcess(BaseEstimator, RegressorMixin):
     """
     The Gaussian Process model class.
 
@@ -85,7 +159,7 @@ class GaussianProcess(BaseEstimator):
         it uses theta0.
 
     normalize : boolean, optional
-        Design sites X and observations y are centered and reduced wrt
+        Input X and observations y are centered and scaled wrt
         means and standard deviations estimated from the n_samples
         observations provided.
         Default is normalize = True so that data is normalized to ease
@@ -187,8 +261,8 @@ class GaussianProcess(BaseEstimator):
         Parameters
         ----------
         X : double array_like
-            An array with shape (n_samples, n_features) with the design sites
-            at which observations were made.
+            An array with shape (n_samples, n_features) with the input at which
+            observations were made.
 
         y : double array_like
             An array with shape (n_features, ) with the observations of the
@@ -206,7 +280,7 @@ class GaussianProcess(BaseEstimator):
 
         # Force data to 2D numpy.array
         X = np.atleast_2d(X)
-        y = np.asanyarray(y)[:, np.newaxis]
+        y = np.asanyarray(y).ravel()[:, np.newaxis]
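+        # (ravel() accepts both (n_samples,) and (n_samples, 1) targets and
+        # forces y into a single column.)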
 
         # Check shapes of DOE & observations
         n_samples_X, n_features = X.shape
@@ -219,32 +293,24 @@ class GaussianProcess(BaseEstimator):
 
         # Normalize data or don't
         if self.normalize:
-            mean_X = np.mean(X, axis=0)
-            std_X = np.std(X, axis=0)
-            mean_y = np.mean(y, axis=0)
-            std_y = np.std(y, axis=0)
-            std_X[std_X == 0.] = 1.
-            std_y[std_y == 0.] = 1.
+            X_mean = np.mean(X, axis=0)
+            X_std = np.std(X, axis=0)
+            y_mean = np.mean(y, axis=0)
+            y_std = np.std(y, axis=0)
+            X_std[X_std == 0.] = 1.
+            y_std[y_std == 0.] = 1.
+            # center and scale X and y
+            X = (X - X_mean) / X_std
+            y = (y - y_mean) / y_std
         else:
-            mean_X = np.array([0.])
-            std_X = np.array([1.])
-            mean_y = np.array([0.])
-            std_y = np.array([1.])
-
-        X = (X - mean_X) / std_X
-        y = (y - mean_y) / std_y
+            X_mean = np.zeros(1)
+            X_std = np.ones(1)
+            y_mean = np.zeros(1)
+            y_std = np.ones(1)
 
         # Calculate matrix of distances D between samples
-        mzmax = n_samples * (n_samples - 1) / 2
-        ij = np.zeros([mzmax, 2])
-        D = np.zeros([mzmax, n_features])
-        ll = np.array([-1])
-        for k in range(n_samples-1):
-            ll = ll[-1] + 1 + range(n_samples - k - 1)
-            ij[ll] = np.concatenate([[np.repeat(k, n_samples - k - 1, 0)],
-                                     [np.arange(k + 1, n_samples).T]]).T
-            D[ll] = X[k] - X[(k + 1):n_samples]
-        if np.min(np.sum(np.abs(D), 1)) == 0. \
+        D, ij = compute_componentwise_l1_cross_distances(X)
+        if np.min(np.sum(np.abs(D), axis=1)) == 0. \
                                     and self.corr != correlation.pure_nugget:
             raise Exception("Multiple X are not allowed")
 
@@ -273,8 +339,8 @@ class GaussianProcess(BaseEstimator):
         self.D = D
         self.ij = ij
         self.F = F
-        self.X_sc = np.concatenate([[mean_X], [std_X]])
-        self.y_sc = np.concatenate([[mean_y], [std_y]])
+        self.X_mean, self.X_std = X_mean, X_std
+        self.y_mean, self.y_std = y_mean, y_std
 
         # Determine Gaussian Process model parameters
         if self.thetaL is not None and self.thetaU is not None:
@@ -356,7 +422,7 @@ class GaussianProcess(BaseEstimator):
         # Run input checks
         self._check_params()
 
-        # Check design & trial sites
+        # Check input shapes
         X = np.atleast_2d(X)
         n_eval, n_features_X = X.shape
         n_samples, n_features = self.X.shape
@@ -370,31 +436,26 @@ class GaussianProcess(BaseEstimator):
             # No memory management
             # (evaluates all given points in a single batch run)
 
-            # Normalize trial sites
-            X = (X - self.X_sc[0][:]) / self.X_sc[1][:]
+            # Normalize input
+            X = (X - self.X_mean) / self.X_std
 
             # Initialize output
             y = np.zeros(n_eval)
             if eval_MSE:
                 MSE = np.zeros(n_eval)
 
-            # Get distances to design sites
-            dx = np.zeros([n_eval * n_samples, n_features])
-            kk = np.arange(n_samples).astype(int)
-            for k in range(n_eval):
-                dx[kk] = X[k] - self.X
-                kk = kk + n_samples
+            # Get pairwise componentwise distances to the training inputs
+            dx = compute_componentwise_l1_pairwise_distances(X, self.X)
 
             # Get regression function and correlation
             f = self.regr(X)
             r = self.corr(self.theta, dx).reshape(n_eval, n_samples)
 
             # Scaled predictor
-            y_ = np.dot(f, self.beta) \
-               + np.dot(r, self.gamma)
+            y_ = np.dot(f, self.beta) + np.dot(r, self.gamma)
 
             # Predictor
-            y = (self.y_sc[0] + self.y_sc[1] * y_).ravel()
+            y = (self.y_mean + self.y_std * y_).ravel()
 
             # Mean Squared Error
             if eval_MSE:
@@ -412,7 +473,7 @@ class GaussianProcess(BaseEstimator):
                     self.G = par['G']
 
                 rt = solve_triangular(C, r.T, lower=True)
-                
+
                 if self.beta0 is None:
                     # Universal Kriging
                     u = solve_triangular(self.G.T,
@@ -518,21 +579,8 @@ class GaussianProcess(BaseEstimator):
 
         if D is None:
             # Light storage mode (need to recompute D, ij and F)
-            if self.X.ndim > 1:
-                n_features = self.X.shape[1]
-            else:
-                n_features = 1
-            mzmax = n_samples * (n_samples - 1) / 2
-            ij = np.zeros([mzmax, n_features])
-            D = np.zeros([mzmax, n_features])
-            ll = np.array([-1])
-            for k in range(n_samples-1):
-                ll = ll[-1] + 1 + range(n_samples - k - 1)
-                ij[ll] = \
-                    np.concatenate([[np.repeat(k, n_samples - k - 1, 0)],
-                                    [np.arange(k + 1, n_samples).T]]).T
-                D[ll] = self.X[k] - self.X[(k + 1):n_samples]
-            if min(sum(abs(D), 1)) == 0. \
+            D, ij = compute_componentwise_l1_cross_distances(self.X)
+            if np.min(np.sum(np.abs(D), axis=1)) == 0. \
                                     and self.corr != correlation.pure_nugget:
                 raise Exception("Multiple X are not allowed")
             F = self.regr(self.X)
@@ -540,8 +588,8 @@ class GaussianProcess(BaseEstimator):
         # Set up R
         r = self.corr(theta, D)
         R = np.eye(n_samples) * (1. + self.nugget)
-        R[ij.astype(int)[:, 0], ij.astype(int)[:, 1]] = r
-        R[ij.astype(int)[:, 1], ij.astype(int)[:, 0]] = r
+        R[ij[:, 0], ij[:, 1]] = r
+        R[ij[:, 1], ij[:, 0]] = r
 
         # Cholesky decomposition of R
         try:
@@ -590,7 +638,7 @@ class GaussianProcess(BaseEstimator):
 
         # Compute/Organize output
         reduced_likelihood_function_value = - sigma2.sum() * detR
-        par['sigma2'] = sigma2 * self.y_sc[1] ** 2.
+        par['sigma2'] = sigma2 * self.y_std ** 2.
         par['beta'] = beta
         par['gamma'] = solve_triangular(C.T, rho)
         par['C'] = C
@@ -641,8 +689,9 @@ class GaussianProcess(BaseEstimator):
 
         if self.optimizer == 'fmin_cobyla':
 
-            minus_reduced_likelihood_function = lambda log10t: \
-                - self.reduced_likelihood_function(theta=10. ** log10t)[0]
+            def minus_reduced_likelihood_function(log10t):
+                return - self.reduced_likelihood_function(
+                    theta=10. ** log10t)[0]
 
             constraints = []
             for i in range(self.theta0.size):
@@ -687,7 +736,7 @@ class GaussianProcess(BaseEstimator):
                 if self.verbose and self.random_start > 1:
                     if (20 * k) / self.random_start > percent_completed:
                         percent_completed = (20 * k) / self.random_start
-                        print str(5 * percent_completed) + "% completed"
+                        print "%s completed" % (5 * percent_completed)
 
             optimal_rlf_value = best_optimal_rlf_value
             optimal_par = best_optimal_par
@@ -723,11 +772,14 @@ class GaussianProcess(BaseEstimator):
                 self.theta0 = np.atleast_2d(theta_iso)
                 self.thetaL = np.atleast_2d(thetaL[0, i])
                 self.thetaU = np.atleast_2d(thetaU[0, i])
-                self.corr = lambda t, d: \
-                    corr(np.atleast_2d(np.hstack([
+
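+                # corr_cut freezes every correlation parameter at its current
+                # optimal value except the i-th one, which is optimized alone.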
+                def corr_cut(t, d):
+                    return corr(np.atleast_2d(np.hstack([
                          optimal_theta[0][0:i],
                          t[0],
                          optimal_theta[0][(i + 1)::]])), d)
+
+                self.corr = corr_cut
                 optimal_theta[0, i], optimal_rlf_value, optimal_par = \
                     self.arg_max_reduced_likelihood_function()
 
@@ -745,28 +797,6 @@ class GaussianProcess(BaseEstimator):
 
         return optimal_theta, optimal_rlf_value, optimal_par
 
-    def score(self, X_test, y_test):
-        """
-        This score function returns the mean deviation of the Gaussian Process
-        model evaluated onto a test dataset.
-
-        Parameters
-        ----------
-        X_test : array_like
-            The feature test dataset with shape (n_tests, n_features).
-
-        y_test : array_like
-            The target test dataset (n_tests, ).
-
-        Returns
-        -------
-        score_value : array_like
-            The mean of the deviations between the prediction and the targets:
-            mean(y_pred - y_test).
-        """
-
-        return (self.predict(X_test, eval_MSE=False) - y_test).mean()
-
     def _check_params(self):
 
         # Check regression model
diff --git a/scikits/learn/gaussian_process/regression_models.py b/scikits/learn/gaussian_process/regression_models.py
index 4ba931b8e3a59b66cab678a647da42df41d22c9f..ea0eda50677b04a39abd65424aa105e30a11826b 100644
--- a/scikits/learn/gaussian_process/regression_models.py
+++ b/scikits/learn/gaussian_process/regression_models.py
@@ -31,11 +31,9 @@ def constant(x):
         An array with shape (n_eval, p) with the values of the regression
         model.
     """
-
     x = np.asanyarray(x, dtype=np.float)
     n_eval = x.shape[0]
     f = np.ones([n_eval, 1])
-
     return f
 
 
@@ -57,11 +55,9 @@ def linear(x):
         An array with shape (n_eval, p) with the values of the regression
         model.
     """
-
     x = np.asanyarray(x, dtype=np.float)
     n_eval = x.shape[0]
     f = np.hstack([np.ones([n_eval, 1]), x])
-
     return f
 
 
@@ -88,7 +84,7 @@ def quadratic(x):
     x = np.asanyarray(x, dtype=np.float)
     n_eval, n_features = x.shape
     f = np.hstack([np.ones([n_eval, 1]), x])
-    for  k in range(n_features):
+    for k in range(n_features):
         f = np.hstack([f, x[:, k, np.newaxis] * x[:, k:]])
 
     return f
diff --git a/scikits/learn/gaussian_process/tests/test_gaussian_process.py b/scikits/learn/gaussian_process/tests/test_gaussian_process.py
index 2f71070fa999dbd11f61ee9a5fcf012265b58d10..c285022769e272960388abb71e34c53cf898cf16 100644
--- a/scikits/learn/gaussian_process/tests/test_gaussian_process.py
+++ b/scikits/learn/gaussian_process/tests/test_gaussian_process.py
@@ -20,7 +20,6 @@ def test_1d(regr=regression.constant, corr=correlation.squared_exponential,
 
     Test the interpolating property.
     """
-
     f = lambda x: x * np.sin(x)
     X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T
     y = f(X).ravel()
@@ -40,7 +39,6 @@ def test_2d(regr=regression.constant, corr=correlation.squared_exponential,
 
     Test the interpolating property.
     """
-
     b, kappa, e = 5., .5, .1
     g = lambda x: b - x[:, 1] - kappa * (x[:, 0] - e) ** 2.
     X = np.array([[-4.61611719, -6.00099547],
@@ -80,7 +78,6 @@ def test_ordinary_kriging():
     Repeat test_1d and test_2d with given regression weights (beta0) for
     different regression models (Ordinary Kriging).
     """
-
     test_1d(regr='linear', beta0=[0., 0.5])
     test_1d(regr='quadratic', beta0=[0., 0.5, 0.5])
     test_2d(regr='linear', beta0=[0., 0.5, 0.5])