diff --git a/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py b/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py
index b8943a34136c069af15a7f978c5be45c8be1707e..99ea5ede7239672312bc77e9fa50ff11779747c2 100644
--- a/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py
+++ b/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py
@@ -6,11 +6,15 @@
 Gaussian Processes for Machine Learning example
 ===============================================
 
-A two-dimensional regression exercise with a post-processing
-allowing for probabilistic classification thanks to the
-Gaussian property of the prediction.
+A two-dimensional regression exercise with a post-processing allowing for
+probabilistic classification thanks to the Gaussian property of the prediction.
+
+The figure illustrates the probability that the prediction is negative with
+respect to the remaining uncertainty in the prediction. The red and blue lines
+correspond to the 95% confidence interval on the prediction of the zero level
+set.
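+
+Since the Gaussian Process prediction at a point x is a Gaussian random
+variable with mean mu(x) and standard deviation sigma(x), this probability is
+simply PHI(- mu(x) / sigma(x)), where PHI denotes the standard normal
+cumulative distribution function (see the computation of PHI(- yy / sigma)
+below).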
 """
-print __doc__
+
 # Author: Vincent Dubourg <vincent.dubourg@gmail.com
 # License: BSD style
 
@@ -19,10 +23,9 @@ from scipy import stats
 from scikits.learn.gaussian_process import GaussianProcess
 from matplotlib import pyplot as pl
 from matplotlib import cm
-from mpl_toolkits.mplot3d import Axes3D
-class FormatFaker(object):
-    def __init__(self, str): self.str = str
-    def __mod__(self, stuff): return self.str
+
+# Print the docstring
+print __doc__
 
 # Standard normal distribution functions
 Grv = stats.distributions.norm()
@@ -34,18 +37,19 @@ PHIinv = lambda x: Grv.ppf(x)
 beta0 = 8
 b, kappa, e = 5, .5, .1
 
-# The function to predict (classification will then consist in predicting wheter g(x) <= 0 or not)
-g = lambda x: b - x[:,1] - kappa * ( x[:,0] - e )**2.
+# The function to predict (classification will then consist in predicting
+# whether g(x) <= 0 or not)
+g = lambda x: b - x[:, 1] - kappa * (x[:, 0] - e) ** 2.
 
 # Design of experiments
 X = np.array([[-4.61611719, -6.00099547],
-              [ 4.10469096,  5.32782448],
-              [ 0.00000000, -0.50000000],
-              [-6.17289014, -4.6984743 ],
-              [ 1.3109306 , -6.93271427],
-              [-5.03823144,  3.10584743],
-              [-2.87600388,  6.74310541],
-              [ 5.21301203,  4.26386883]])
+              [4.10469096, 5.32782448],
+              [0.00000000, -0.50000000],
+              [-6.17289014, -4.6984743],
+              [1.3109306, -6.93271427],
+              [-5.03823144, 3.10584743],
+              [-2.87600388, 6.74310541],
+              [5.21301203, 4.26386883]])
 
 # Observations
 Y = g(X)
@@ -58,41 +62,53 @@ gp.fit(X, Y)
 
 # Evaluate real function, the prediction and its MSE on a grid
 res = 50
-x1, x2 = np.meshgrid(np.linspace(-beta0, beta0, res), np.linspace(-beta0, beta0, res))
+x1, x2 = np.meshgrid(np.linspace(- beta0, beta0, res), \
+                     np.linspace(- beta0, beta0, res))
 xx = np.vstack([x1.reshape(x1.size), x2.reshape(x2.size)]).T
 
 YY = g(xx)
 yy, MSE = gp.predict(xx, eval_MSE=True)
 sigma = np.sqrt(MSE)
-yy = yy.reshape((res,res))
-YY = YY.reshape((res,res))
-sigma = sigma.reshape((res,res))
+yy = yy.reshape((res, res))
+YY = YY.reshape((res, res))
+sigma = sigma.reshape((res, res))
 k = PHIinv(.975)
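+# Note: PHIinv(.975) is the 97.5% quantile of the standard normal
+# distribution (about 1.96). The 2.5% and 97.5% iso-probability contours
+# plotted below thus delimit a pointwise 95% confidence interval on the
+# predicted zero level set.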
 
-# Plot the probabilistic classification iso-values using the Gaussian property of the prediction
+# Plot the probabilistic classification iso-values using the Gaussian property
+# of the prediction
 fig = pl.figure(1)
-
 ax = fig.add_subplot(111)
 ax.axes.set_aspect('equal')
-cax = pl.imshow(np.flipud(PHI(-yy/sigma)), cmap=cm.gray_r, alpha=.8, extent=(-beta0,beta0,-beta0,beta0))
-norm = pl.matplotlib.colors.Normalize(vmin=0., vmax=0.9)
-cb = pl.colorbar(cax, ticks=[0.,0.2,0.4,0.6,0.8,1.], norm=norm)
-cb.set_label('${\\rm \mathbb{P}}\left[\widehat{G}(\mathbf{x}) \leq 0\\right]$')
-pl.plot(X[Y <= 0, 0], X[Y <= 0, 1], 'r.', markersize=12)
-pl.plot(X[Y > 0, 0], X[Y > 0, 1], 'b.', markersize=12)
 pl.xticks([])
 pl.yticks([])
 ax.set_xticklabels([])
 ax.set_yticklabels([])
 pl.xlabel('$x_1$')
 pl.ylabel('$x_2$')
-cs = pl.contour(x1, x2, YY, [0.], colors='k', linestyles='dashdot')
-pl.clabel(cs,fmt=FormatFaker(u'$g(\mathbf{x})=0$'),fontsize=11)
-cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.025], colors='b', linestyles='solid')
-pl.clabel(cs,fmt=FormatFaker(u'${\\rm \mathbb{P}}\left[{\widehat{G}}(\mathbf{x}) \leq 0\\right]= 2.5\%$'),fontsize=11)
-cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.5], colors='k', linestyles='dashed')
-pl.clabel(cs,fmt=FormatFaker(u'$\mu_{\widehat{G}}(\mathbf{x})=0$'),fontsize=11)
-cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.975], colors='r', linestyles='solid')
-pl.clabel(cs,fmt=FormatFaker(u'${\\rm \mathbb{P}}\left[{\widehat{G}}(\mathbf{x}) \leq 0\\right]= 97.5\%$'),fontsize=11)
-
-pl.show()
\ No newline at end of file
+
+cax = pl.imshow(np.flipud(PHI(- yy / sigma)), cmap=cm.gray_r, alpha=0.8, \
+                extent=(- beta0, beta0, - beta0, beta0))
+norm = pl.matplotlib.colors.Normalize(vmin=0., vmax=0.9)
+cb = pl.colorbar(cax, ticks=[0., 0.2, 0.4, 0.6, 0.8, 1.], norm=norm)
+cb.set_label('${\\rm \mathbb{P}}\left[\widehat{G}(\mathbf{x}) \leq 0\\right]$')
+
+pl.plot(X[Y <= 0, 0], X[Y <= 0, 1], 'r.', markersize=12)
+
+pl.plot(X[Y > 0, 0], X[Y > 0, 1], 'b.', markersize=12)
+
+cs = pl.contour(x1, x2, YY, [0.], colors='k', \
+                linestyles='dashdot')
+
+cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.025], colors='b', \
+                linestyles='solid')
+pl.clabel(cs, fontsize=11)
+
+cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.5], colors='k', \
+                linestyles='dashed')
+pl.clabel(cs, fontsize=11)
+
+cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.975], colors='r', \
+                linestyles='solid')
+pl.clabel(cs, fontsize=11)
+
+pl.show()
diff --git a/examples/gaussian_process/plot_gp_regression.py b/examples/gaussian_process/plot_gp_regression.py
index e62207499c0330792f4f8b0a25f618811f61b848..80b5377231bca67ad05f1d3133ea3d473a386091 100644
--- a/examples/gaussian_process/plot_gp_regression.py
+++ b/examples/gaussian_process/plot_gp_regression.py
@@ -6,21 +6,26 @@
 Gaussian Processes for Machine Learning example
 ===============================================
 
-A simple one-dimensional regression exercise with a
-cubic correlation model whose parameters are estimated
-using the maximum likelihood principle.
+A simple one-dimensional regression exercise with a cubic correlation
+model whose parameters are estimated using the maximum likelihood principle.
+
+The figure illustrates the interpolating property of the Gaussian Process
+model as well as its probabilistic nature in the form of a pointwise 95%
+confidence interval.
 """
-print __doc__
+
 # Author: Vincent Dubourg <vincent.dubourg@gmail.com
 # License: BSD style
 
 import numpy as np
-from scipy import stats
 from scikits.learn.gaussian_process import GaussianProcess, corrcubic
 from matplotlib import pyplot as pl
 
+# Print the docstring
+print __doc__
+
 # The function to predict
-f = lambda x: x*np.sin(x)
+f = lambda x: x * np.sin(x)
 
 # The design of experiments
 X = np.array([1., 3., 5., 6., 7., 8.])
@@ -28,11 +33,13 @@ X = np.array([1., 3., 5., 6., 7., 8.])
 # Observations
 Y = f(X)
 
-# Mesh the input space for evaluations of the real function, the prediction and its MSE
-x = np.linspace(0,10,1000)
+# Mesh the input space for evaluations of the real function, the prediction and
+# its MSE
+x = np.linspace(0, 10, 1000)
 
 # Instanciate a Gaussian Process model
-gp = GaussianProcess(corr=corrcubic, theta0=1e-2, thetaL=1e-4, thetaU=1e-1, random_start=100)
+gp = GaussianProcess(corr=corrcubic, theta0=1e-2, thetaL=1e-4, thetaU=1e-1, \
+                     random_start=100)
 
 # Fit to data using Maximum Likelihood Estimation of the parameters
 gp.fit(X, Y)
@@ -41,12 +48,15 @@ gp.fit(X, Y)
 y, MSE = gp.predict(x, eval_MSE=True)
 sigma = np.sqrt(MSE)
 
-# Plot the function, the prediction and the 95% confidence interval based on the MSE
+# Plot the function, the prediction and the 95% confidence interval based on
+# the MSE
 fig = pl.figure()
 pl.plot(x, f(x), 'r:', label=u'$f(x) = x\,\sin(x)$')
 pl.plot(X, Y, 'r.', markersize=10, label=u'Observations')
 pl.plot(x, y, 'b-', label=u'Prediction')
-pl.fill(np.concatenate([x, x[::-1]]), np.concatenate([y - 1.9600 * sigma, (y + 1.9600 * sigma)[::-1]]), alpha=.5, fc='b', ec='None', label='95% confidence interval')
+pl.fill(np.concatenate([x, x[::-1]]), \
+        np.concatenate([y - 1.9600 * sigma, (y + 1.9600 * sigma)[::-1]]), \
+        alpha=.5, fc='b', ec='None', label='95% confidence interval')
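+# Note: 1.9600 is the 97.5% quantile of the standard normal distribution, so
+# the band y +/- 1.9600 * sigma is a pointwise 95% confidence interval on the
+# prediction.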
 pl.xlabel('$x$')
 pl.ylabel('$f(x)$')
 pl.ylim(-10, 20)
diff --git a/scikits/learn/gaussian_process/__init__.py b/scikits/learn/gaussian_process/__init__.py
index 51cb685e185b200aaa9fcaf3af599ec2451a9b42..f2bb68562f77ed4ec5b2bf93f6dec1b357b24703 100644
--- a/scikits/learn/gaussian_process/__init__.py
+++ b/scikits/learn/gaussian_process/__init__.py
@@ -4,11 +4,12 @@
 """
     This module contains a contribution to the scikit-learn project that
     implements Gaussian Process based prediction (also known as Kriging).
-    
+
     The present implementation is based on a transliteration of the DACE
     Matlab toolbox <http://www2.imm.dtu.dk/~hbn/dace/>.
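+
+    A minimal usage sketch (the parameter values are illustrative only and
+    follow the example given in the GaussianProcess class docstring):
+
+        import numpy as np
+        from scikits.learn.gaussian_process import GaussianProcess
+
+        X = np.array([1., 3., 5., 6., 7., 8.])
+        y = X * np.sin(X)
+        gp = GaussianProcess(theta0=1e-1, thetaL=1e-3, thetaU=1e0)
+        gp.fit(X, y)
+        y_pred, MSE = gp.predict(np.linspace(0, 10, 100), eval_MSE=True)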
 """
 
 from .gaussian_process import GaussianProcess
-from .correlation import *
-from .regression import *
\ No newline at end of file
+from .correlation import corrlin, corrcubic, correxp1, \
+                         correxp2, correxpg, corriid
+from .regression import regpoly0, regpoly1, regpoly2
diff --git a/scikits/learn/gaussian_process/correlation.py b/scikits/learn/gaussian_process/correlation.py
index c1d8235f22bf695cdf8a28e6252b609f19afa525..65dcc5b825b2ed21c144b0a407253c02afdfaa18 100644
--- a/scikits/learn/gaussian_process/correlation.py
+++ b/scikits/learn/gaussian_process/correlation.py
@@ -7,33 +7,42 @@
 
 import numpy as np
 
+
 #############################
 # Defaut correlation models #
 #############################
 
-def correxp1(theta,d):
+
+def correxp1(theta, d):
     """
     Exponential autocorrelation model.
     (Ornstein-Uhlenbeck stochastic process)
+
                                                    n
     correxp1 : theta, dx --> r(theta, dx) = exp(  sum  - theta_i * |dx_i| )
                                                  i = 1
+
     Parameters
     ----------
     theta : array_like
-        An array with shape 1 (isotropic) or n (anisotropic) giving the autocorrelation parameter(s).
+        An array with shape 1 (isotropic) or n (anisotropic) giving the
+        autocorrelation parameter(s).
+
     dx : array_like
-        An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated.
-    
+        An array with shape (n_eval, n_features) giving the componentwise
+        distances between locations x and x' at which the correlation model
+        should be evaluated.
+
     Returns
     -------
     r : array_like
-        An array with shape (n_eval, ) containing the values of the autocorrelation model.
+        An array with shape (n_eval, ) containing the values of the
+        autocorrelation model.
     """
-    
+
     theta = np.asanyarray(theta, dtype=np.float)
     d = np.asanyarray(d, dtype=np.float)
-    
+
     if d.ndim > 1:
         n_features = d.shape[1]
     else:
@@ -41,36 +50,44 @@ def correxp1(theta,d):
     if theta.size == 1:
         theta = np.repeat(theta, n_features)
     elif theta.size != n_features:
-        raise ValueError, "Length of theta must be 1 or "+str(n_features)
-    
+        raise ValueError("Length of theta must be 1 or " + str(n_features))
+
     td = - theta.reshape(1, n_features) * abs(d)
-    r = np.exp(np.sum(td,1))
-    
+    r = np.exp(np.sum(td, 1))
+
     return r
 
-def correxp2(theta,d):
+
+def correxp2(theta, d):
     """
     Squared exponential correlation model (Radial Basis Function).
     (Infinitely differentiable stochastic process, very smooth)
+
                                                    n
     correxp2 : theta, dx --> r(theta, dx) = exp(  sum  - theta_i * (dx_i)^2 )
                                                  i = 1
+
     Parameters
     ----------
     theta : array_like
-        An array with shape 1 (isotropic) or n (anisotropic) giving the autocorrelation parameter(s).
+        An array with shape 1 (isotropic) or n (anisotropic) giving the
+        autocorrelation parameter(s).
+
     dx : array_like
-        An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated.
-    
+        An array with shape (n_eval, n_features) giving the componentwise
+        distances between locations x and x' at which the correlation model
+        should be evaluated.
+
     Returns
     -------
     r : array_like
-        An array with shape (n_eval, ) containing the values of the autocorrelation model.
+        An array with shape (n_eval, ) containing the values of the
+        autocorrelation model.
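+
+    Example
+    -------
+    As an illustration (single feature, isotropic case), theta = [1.] and
+    dx = [[0.], [1.]] give r = exp(- theta * dx ** 2) = [1., exp(-1.)],
+    i.e. approximately [1., 0.368].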
     """
-    
+
     theta = np.asanyarray(theta, dtype=np.float)
     d = np.asanyarray(d, dtype=np.float)
-    
+
     if d.ndim > 1:
         n_features = d.shape[1]
     else:
@@ -78,36 +95,45 @@ def correxp2(theta,d):
     if theta.size == 1:
         theta = np.repeat(theta, n_features)
     elif theta.size != n_features:
-        raise Exception, "Length of theta must be 1 or "+str(n_features)
-    
-    td = - theta.reshape(1, n_features) * d**2
-    r = np.exp(np.sum(td,1))
-    
+        raise Exception("Length of theta must be 1 or " + str(n_features))
+
+    td = - theta.reshape(1, n_features) * d ** 2
+    r = np.exp(np.sum(td, 1))
+
     return r
 
-def correxpg(theta,d):
+
+def correxpg(theta, d):
     """
     Generalized exponential correlation model.
-    (Useful when one does not know the smoothness of the function to be predicted.)
+    (Useful when one does not know the smoothness of the function to be
+    predicted.)
+
                                                    n
     correxpg : theta, dx --> r(theta, dx) = exp(  sum  - theta_i * |dx_i|^p )
                                                  i = 1
+
     Parameters
     ----------
     theta : array_like
-        An array with shape 1+1 (isotropic) or n+1 (anisotropic) giving the autocorrelation parameter(s) (theta, p).
+        An array with shape 1+1 (isotropic) or n+1 (anisotropic) giving the
+        autocorrelation parameter(s) (theta, p).
+
     dx : array_like
-        An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated.
-    
+        An array with shape (n_eval, n_features) giving the componentwise
+        distances between locations x and x' at which the correlation model
+        should be evaluated.
+
     Returns
     -------
     r : array_like
-        An array with shape (n_eval, ) with the values of the autocorrelation model.
+        An array with shape (n_eval, ) with the values of the autocorrelation
+        model.
     """
-    
+
     theta = np.asanyarray(theta, dtype=np.float)
     d = np.asanyarray(d, dtype=np.float)
-    
+
     if d.ndim > 1:
         n_features = d.shape[1]
     else:
@@ -115,125 +141,151 @@ def correxpg(theta,d):
     lth = theta.size
     if n_features > 1 and lth == 2:
         theta = np.hstack([np.repeat(theta[0], n_features), theta[1]])
-    elif lth != n_features+1:
-        raise Exception, "Length of theta must be 2 or "+str(n_features+1)
+    elif lth != n_features + 1:
+        raise Exception("Length of theta must be 2 or " + str(n_features + 1))
     else:
         theta = theta.reshape(1, lth)
-    
-    td = - theta[:,0:-1].reshape(1, n_features) * abs(d)**theta[:,-1]
-    r = np.exp(np.sum(td,1))
-    
+
+    td = - theta[:, 0:-1].reshape(1, n_features) * abs(d) ** theta[:, -1]
+    r = np.exp(np.sum(td, 1))
+
     return r
 
-def corriid(theta,d):
+
+def corriid(theta, d):
     """
     Spatial independence correlation model (pure nugget).
     (Useful when one wants to solve an ordinary least squares problem!)
+
                                                     n
-    corriid : theta, dx --> r(theta, dx) = 1 if   sum |dx_i| == 0 and 0 otherwise
+    corriid : theta, dx --> r(theta, dx) = 1 if   sum |dx_i| == 0
                                                   i = 1
+                                           0 otherwise
+
     Parameters
     ----------
     theta : array_like
         None.
+
     dx : array_like
-        An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated.
-    
+        An array with shape (n_eval, n_features) giving the componentwise
+        distances between locations x and x' at which the correlation model
+        should be evaluated.
+
     Returns
     -------
     r : array_like
-        An array with shape (n_eval, ) with the values of the autocorrelation model.
+        An array with shape (n_eval, ) with the values of the autocorrelation
+        model.
     """
-    
+
     theta = np.asanyarray(theta, dtype=np.float)
     d = np.asanyarray(d, dtype=np.float)
-    
+
     n_eval = d.shape[0]
     r = np.zeros(n_eval)
     # The ones on the diagonal of the correlation matrix are enforced within
     # the KrigingModel instanciation to allow multiple design sites in this
     # ordinary least squares context.
-    
+
     return r
 
-def corrcubic(theta,d):
+
+def corrcubic(theta, d):
     """
     Cubic correlation model.
-                                               n
-    corrcubic : theta, dx --> r(theta, dx) = prod max(0, 1 - 3(theta_j*d_ij)^2 + 2(theta_j*d_ij)^3) ,  i = 1,...,m
-                                             j = 1
+
+    corrcubic : theta, dx --> r(theta, dx) =
+          n
+        prod max(0, 1 - 3(theta_j*d_ij)^2 + 2(theta_j*d_ij)^3) ,  i = 1,...,m
+        j = 1
+
     Parameters
     ----------
     theta : array_like
-        An array with shape 1 (isotropic) or n (anisotropic) giving the autocorrelation parameter(s).
+        An array with shape 1 (isotropic) or n (anisotropic) giving the
+        autocorrelation parameter(s).
+
     dx : array_like
-        An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated.
-    
+        An array with shape (n_eval, n_features) giving the componentwise
+        distances between locations x and x' at which the correlation model
+        should be evaluated.
+
     Returns
     -------
     r : array_like
-        An array with shape (n_eval, ) with the values of the autocorrelation model.
+        An array with shape (n_eval, ) with the values of the autocorrelation
+        model.
     """
-    
+
     theta = np.asanyarray(theta, dtype=np.float)
     d = np.asanyarray(d, dtype=np.float)
-    
+
     if d.ndim > 1:
         n_features = d.shape[1]
     else:
         n_features = 1
     lth = theta.size
     if  lth == 1:
-        theta = np.repeat(theta, n_features)[np.newaxis,:]
+        theta = np.repeat(theta, n_features)[np.newaxis, :]
     elif lth != n_features:
-        raise Exception, "Length of theta must be 1 or "+str(n_features)
+        raise Exception("Length of theta must be 1 or " + str(n_features))
     else:
         theta = theta.reshape(1, n_features)
-    
+
     td = abs(d) * theta
     td[td > 1.] = 1.
-    ss = 1. - td**2. * (3. - 2.*td)
-    r = np.prod(ss,1)
-    
+    ss = 1. - td ** 2. * (3. - 2. * td)
+    r = np.prod(ss, 1)
+
     return r
 
-def corrlin(theta,d):
+
+def corrlin(theta, d):
     """
     Linear correlation model.
-                                             n
-    corrlin : theta, dx --> r(theta, dx) = prod max(0, 1 - theta_j*d_ij) ,  i = 1,...,m
-                                           j = 1
+
+    corrlin : theta, dx --> r(theta, dx) =
+          n
+        prod max(0, 1 - theta_j*d_ij) ,  i = 1,...,m
+        j = 1
+
     Parameters
     ----------
     theta : array_like
-        An array with shape 1 (isotropic) or n (anisotropic) giving the autocorrelation parameter(s).
+        An array with shape 1 (isotropic) or n (anisotropic) giving the
+        autocorrelation parameter(s).
+
     dx : array_like
-        An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated.
-    
+        An array with shape (n_eval, n_features) giving the componentwise
+        distances between locations x and x' at which the correlation model
+        should be evaluated.
+
     Returns
     -------
     r : array_like
-        An array with shape (n_eval, ) with the values of the autocorrelation model.
+        An array with shape (n_eval, ) with the values of the autocorrelation
+        model.
     """
-    
+
     theta = np.asanyarray(theta, dtype=np.float)
     d = np.asanyarray(d, dtype=np.float)
-    
+
     if d.ndim > 1:
         n_features = d.shape[1]
     else:
         n_features = 1
     lth = theta.size
     if  lth == 1:
-        theta = np.repeat(theta, n_features)[np.newaxis,:]
+        theta = np.repeat(theta, n_features)[np.newaxis, :]
     elif lth != n_features:
-        raise Exception, "Length of theta must be 1 or "+str(n_features)
+        raise Exception("Length of theta must be 1 or " + str(n_features))
     else:
         theta = theta.reshape(1, n_features)
-    
+
     td = abs(d) * theta
     td[td > 1.] = 1.
     ss = 1. - td
-    r = np.prod(ss,1)
-        
-    return r
\ No newline at end of file
+    r = np.prod(ss, 1)
+
+    return r
diff --git a/scikits/learn/gaussian_process/gaussian_process.py b/scikits/learn/gaussian_process/gaussian_process.py
index 0d07a05ceaa4bfb9638eeb687cdc6ea806d6d1a4..f06bfc50be7e08404285e223659a0ed464262ae3 100644
--- a/scikits/learn/gaussian_process/gaussian_process.py
+++ b/scikits/learn/gaussian_process/gaussian_process.py
@@ -6,43 +6,48 @@
 ################
 
 import numpy as np
-from scipy import linalg, optimize, random
+from scipy import linalg, optimize, rand
 from ..base import BaseEstimator
 from .regression import regpoly0
-from .correlation import correxp2
+from .correlation import correxp2, corriid
 machine_epsilon = np.finfo(np.double).eps
 
+
 def col_sum(x):
     """
     Performs columnwise sums of elements in x depending on its shape.
-    
+
     Parameters
     ----------
     x : array_like
-        An array of size (p, q).
-    
+        An array with shape (p, q).
+
     Returns
     -------
     s : array_like
-        An array of size (q, ) which contains the columnwise sums of the elements in x.
+        An array with shape (q, ) which contains the columnwise sums of the
+        elements in x.
     """
-    
+
     x = np.asanyarray(x, dtype=np.float)
-    
+
     if x.ndim > 1:
         s = x.sum(axis=0)
     else:
         s = x
-    
+
     return s
 
+
 ##############################
 # The Gaussian Process class #
 ##############################
 
+
 class GaussianProcess(BaseEstimator):
     """
-    A class that implements scalar Gaussian Process based prediction (also known as Kriging).
+    A class that implements scalar Gaussian Process based prediction (also
+    known as Kriging).
 
     Example
     -------
@@ -53,7 +58,8 @@ class GaussianProcess(BaseEstimator):
     f = lambda x: x*np.sin(x)
     X = np.array([1., 3., 5., 6., 7., 8.])
     Y = f(X)
-    gp = GaussianProcess(regr=regpoly0, corr=correxp2, theta0=1e-1, thetaL=1e-3, thetaU=1e0, random_start=100)
+    gp = GaussianProcess(regr=regpoly0, corr=correxp2, theta0=1e-1, \
+                         thetaL=1e-3, thetaU=1e0, random_start=100)
     gp.fit(X, Y)
 
     pl.figure(1)
@@ -63,17 +69,19 @@ class GaussianProcess(BaseEstimator):
     pl.plot(x, f(x), 'r:', label=u'$f(x) = x\,\sin(x)$')
     pl.plot(X, Y, 'r.', markersize=10, label=u'Observations')
     pl.plot(x, y, 'k-', label=u'$\widehat{f}(x) = {\\rm BLUP}(x)$')
-    pl.fill(np.concatenate([x, x[::-1]]), np.concatenate([y - 1.9600 * sigma, (y + 1.9600 * sigma)[::-1]]), alpha=.5, fc='b', ec='None', label=u'95\% confidence interval')
+    pl.fill(np.concatenate([x, x[::-1]]), \
+            np.concatenate([y - 1.9600 * sigma, (y + 1.9600 * sigma)[::-1]]), \
+            alpha=.5, fc='b', ec='None', label=u'95\% confidence interval')
     pl.xlabel('$x$')
     pl.ylabel('$f(x)$')
     pl.legend(loc='upper left')
 
     pl.figure(2)
-    theta_values = np.logspace(np.log10(gp.thetaL),np.log10(gp.thetaU),100)
-    reduced_likelihood_function_values = []
+    theta_values = np.logspace(np.log10(gp.thetaL), np.log10(gp.thetaU), 100)
+    psi_values = []
     for t in theta_values:
-            reduced_likelihood_function_values.append(gp.reduced_likelihood_function(theta=t)[0])
-    pl.plot(theta_values, reduced_likelihood_function_values)
+            psi_values.append(gp.reduced_likelihood_function(theta=t)[0])
+    pl.plot(theta_values, psi_values)
     pl.xlabel(u'$\\theta$')
     pl.ylabel(u'Score')
     pl.xscale('log')
@@ -90,43 +98,59 @@ class GaussianProcess(BaseEstimator):
 
     Todo
     ----
-    o Add the 'sparse' storage mode for which the correlation matrix is stored in its sparse eigen decomposition format instead of its full Cholesky decomposition.
+    o Add the 'sparse' storage mode for which the correlation matrix is stored
+      in its sparse eigen decomposition format instead of its full Cholesky
+      decomposition.
 
     Implementation details
     ----------------------
-    The presentation implementation is based on a translation of the DACE Matlab toolbox.
-    
+    The present implementation is based on a translation of the DACE
+    Matlab toolbox.
+
     See references:
-    [1] H.B. Nielsen, S.N. Lophaven, H. B. Nielsen and J. Sondergaard (2002). DACE - A MATLAB Kriging Toolbox.
+    [1] S.N. Lophaven, H.B. Nielsen and J. Sondergaard (2002).
+        DACE - A MATLAB Kriging Toolbox.
         http://www2.imm.dtu.dk/~hbn/dace/dace.pdf
     """
-    
-    def __init__(self, regr = regpoly0, corr = correxp2, beta0 = None, storage_mode = 'full', verbose = True, theta0 = 1e-1, thetaL = None, thetaU = None, optimizer = 'fmin_cobyla', random_start = 1, normalize = True, nugget = 10.*machine_epsilon):
+
+    def __init__(self, regr=regpoly0, corr=correxp2, beta0=None, \
+                 storage_mode='full', verbose=True, theta0=1e-1, \
+                 thetaL=None, thetaU=None, optimizer='fmin_cobyla', \
+                 random_start=1, normalize=True, \
+                 nugget=10. * machine_epsilon):
         """
         The Gaussian Process model constructor.
 
         Parameters
         ----------
         regr : lambda function, optional
-            A regression function returning an array of outputs of the linear regression functional basis.
-            (The number of observations m should be greater than the size p of this basis)
+            A regression function returning an array of outputs of the linear
+            regression functional basis. The number of observations n_samples
+            should be greater than the size p of this basis.
             Default assumes a simple constant regression trend (see regpoly0).
-            
+
         corr : lambda function, optional
-            A stationary autocorrelation function returning the autocorrelation between two points x and x'.
-            Default assumes a squared-exponential autocorrelation model (see correxp2).
-            
+            A stationary autocorrelation function returning the autocorrelation
+            between two points x and x'.
+            Default assumes a squared-exponential autocorrelation model (see
+            correxp2).
+
         beta0 : double array_like, optional
-            Default assumes Universal Kriging (UK) so that the vector beta of regression weights is estimated
-            by Maximum Likelihood. Specifying beta0 overrides estimation and performs Ordinary Kriging (OK).
-            
+            The regression weight vector to perform Ordinary Kriging (OK).
+            Default assumes Universal Kriging (UK) so that the vector beta of
+            regression weights is estimated using the maximum likelihood
+            principle.
+
         storage_mode : string, optional
-            A string specifying whether the Cholesky decomposition of the correlation matrix should be
-            stored in the class (storage_mode = 'full') or not (storage_mode = 'light').
-            Default assumes storage_mode = 'full', so that the Cholesky decomposition of the correlation matrix is stored.
-            This might be a useful parameter when one is not interested in the MSE and only plan to estimate the BLUP,
-            for which the correlation matrix is not required.
-            
+            A string specifying whether the Cholesky decomposition of the
+            correlation matrix should be stored in the class (storage_mode =
+            'full') or not (storage_mode = 'light').
+            Default assumes storage_mode = 'full', so that the
+            Cholesky decomposition of the correlation matrix is stored.
+            This might be a useful parameter when one is not interested in the
+            MSE and only plans to estimate the BLUP, for which the correlation
+            matrix is not required.
+
         verbose : boolean, optional
             A boolean specifying the verbose level.
             Default is verbose = True.
@@ -134,37 +158,49 @@ class GaussianProcess(BaseEstimator):
         theta0 : double array_like, optional
             An array with shape (n_features, ) or (1, ).
             The parameters in the autocorrelation model.
-            If thetaL and thetaU are also specified, theta0 is considered as the starting point
-            for the Maximum Likelihood Estimation of the best set of parameters.
+            If thetaL and thetaU are also specified, theta0 is considered as
+            the starting point for the maximum likelihood estimation of the
+            best set of parameters.
             Default assumes isotropic autocorrelation model with theta0 = 1e-1.
-            
+
         thetaL : double array_like, optional
             An array with shape matching theta0's.
-            Lower bound on the autocorrelation parameters for Maximum Likelihood Estimation.
-            Default is None, skips Maximum Likelihood Estimation and uses theta0.
-            
+            Lower bound on the autocorrelation parameters for maximum
+            likelihood estimation.
+            Default is None, so that maximum likelihood estimation is skipped
+            and theta0 is used.
+
         thetaU : double array_like, optional
             An array with shape matching theta0's.
-            Upper bound on the autocorrelation parameters for Maximum Likelihood Estimation.
-            Default is None and skips Maximum Likelihood Estimation and uses theta0.
-            
+            Upper bound on the autocorrelation parameters for maximum
+            likelihood estimation.
+            Default is None, so that maximum likelihood estimation is skipped
+            and theta0 is used.
+
         normalize : boolean, optional
-            Design sites X and observations y are centered and reduced wrt means and standard deviations
-            estimated from the n_samples observations provided.
-            Default is true so that data is normalized to ease MLE.
-            
+            Design sites X and observations y are centered and reduced wrt
+            means and standard deviations estimated from the n_samples
+            observations provided.
+            Default is normalize = True so that data is normalized to ease
+            maximum likelihood estimation.
+
         nugget : double, optional
-            Introduce a nugget effect to allow smooth predictions from noisy data.
-            Default assumes a nugget close to machine precision for the sake of robustness (nugget = 10.*machine_epsilon).
-            
+            Introduce a nugget effect to allow smooth predictions from noisy
+            data.
+            Default assumes a nugget close to machine precision for the sake of
+            robustness (nugget = 10.*machine_epsilon).
+
         optimizer : string, optional
-            A string specifying the optimization algorithm to be used ('fmin_cobyla' is the sole algorithm implemented yet).
+            A string specifying the optimization algorithm to be used
+            ('fmin_cobyla' is the only algorithm implemented so far).
             Default uses 'fmin_cobyla' algorithm from scipy.optimize.
-            
+
         random_start : int, optional
-            The number of times the Maximum Likelihood Estimation should be performed from a random starting point.
+            The number of times the Maximum Likelihood Estimation should be
+            performed from a random starting point.
             The first MLE always uses the specified starting point (theta0),
-            the next starting points are picked at random according to an exponential distribution (log-uniform in [thetaL, thetaU]).
+            the next starting points are picked at random according to an
+            exponential distribution (log-uniform on [thetaL, thetaU]).
             Default does not use random starting point (random_start = 1).
 
         Returns
@@ -172,52 +208,57 @@ class GaussianProcess(BaseEstimator):
         gp : self
             A Gaussian Process model object awaiting data to be fitted to.
         """
-        
+
         self.regr = regr
         self.corr = corr
         self.beta0 = beta0
-        
+
         # Check storage mode
         if storage_mode != 'full' and storage_mode != 'light':
             if storage_mode == 'sparse':
-                raise NotImplementedError, "The 'sparse' storage mode is not supported yet. Please contribute!"
+                raise NotImplementedError("The 'sparse' storage mode is not " \
+                      + "supported yet. Please contribute!")
             else:
-                raise ValueError, "Storage mode should either be 'full' or 'light'. Unknown storage mode: "+str(storage_mode)
+                raise ValueError("Storage mode should either be 'full' or " \
+                      + "'light'. Unknown storage mode: " + str(storage_mode))
         else:
             self.storage_mode = storage_mode
-        
+
         self.verbose = verbose
-        
+
         # Check correlation parameters
         self.theta0 = np.atleast_2d(np.asanyarray(theta0, dtype=np.float))
         self.thetaL, self.thetaU = thetaL, thetaU
         lth = self.theta0.size
-        
+
         if self.thetaL is not None and self.thetaU is not None:
             # Parameters optimization case
             self.thetaL = np.atleast_2d(np.asanyarray(thetaL, dtype=np.float))
             self.thetaU = np.atleast_2d(np.asanyarray(thetaU, dtype=np.float))
-            
+
             if self.thetaL.size != lth or self.thetaU.size != lth:
-                raise ValueError, "theta0, thetaL and thetaU must have the same length"
+                raise ValueError("theta0, thetaL and thetaU must have the " \
+                      + "same length")
             if np.any(self.thetaL <= 0) or np.any(self.thetaU < self.thetaL):
-                raise ValueError, "The bounds must satisfy O < thetaL <= thetaU"
-            
+                raise ValueError("The bounds must satisfy 0 < thetaL <= " \
+                      + "thetaU")
+
         elif self.thetaL is None and self.thetaU is None:
             # Given parameters case
             if np.any(self.theta0 <= 0):
-                raise ValueError, "theta0 must be strictly positive"
-            
+                raise ValueError("theta0 must be strictly positive")
+
         elif self.thetaL is None or self.thetaU is None:
             # Error
-            raise Exception, "thetaL and thetaU should either be both or neither specified"
-        
+            raise Exception("thetaL and thetaU should either be both or " \
+                  + "neither specified")
+
         # Store other parameters
         self.normalize = normalize
         self.nugget = nugget
         self.optimizer = optimizer
         self.random_start = int(random_start)
-        
+
     def fit(self, X, y):
         """
         The Gaussian Process model fitting method.
@@ -225,21 +266,24 @@ class GaussianProcess(BaseEstimator):
         Parameters
         ----------
         X : double array_like
-            An array with shape (n_samples, n_features) with the design sites at which observations were made.
-            
+            An array with shape (n_samples, n_features) with the design sites
+            at which observations were made.
+
         y : double array_like
-            An array with shape (n_features, ) with the observations of the scalar output to be predicted.
-            
+            An array with shape (n_samples, ) with the observations of the
+            scalar output to be predicted.
+
         Returns
         -------
         gp : self
-            A fitted Gaussian Process model object awaiting data to perform predictions.
+            A fitted Gaussian Process model object awaiting data to perform
+            predictions.
         """
-        
+
         # Force data to numpy.array type (from coding guidelines)
         X = np.asanyarray(X, dtype=np.float)
         y = np.asanyarray(y, dtype=np.float)
-        
+
         # Check design sites & observations
         n_samples_X = X.shape[0]
         if X.ndim > 1:
@@ -247,23 +291,25 @@ class GaussianProcess(BaseEstimator):
         else:
             n_features = 1
         X = X.reshape(n_samples_X, n_features)
-        
+
         n_samples_y = y.shape[0]
         if y.ndim > 1:
-            raise NotImplementedError, "y has more than one dimension. This is not supported yet (scalar output prediction only). Please contribute!"
+            raise NotImplementedError("y has more than one dimension. This " \
+                  + "is not supported yet (scalar output prediction only). " \
+                  + "Please contribute!")
         y = y.reshape(n_samples_y, 1)
-        
+
         if n_samples_X != n_samples_y:
-            raise Exception, "X and y must have the same number of rows!"
+            raise Exception("X and y must have the same number of rows!")
         else:
             n_samples = n_samples_X
-        
+
         # Normalize data or don't
         if self.normalize:
             mean_X = np.mean(X, axis=0)
-            std_X = np.sqrt(1./(n_samples-1.)*np.sum((X - np.mean(X,0))**2.,0)) #np.std(y, axis=0)
+            std_X = np.std(X, axis=0)
             mean_y = np.mean(y, axis=0)
-            std_y = np.sqrt(1./(n_samples-1.)*np.sum((y - np.mean(y,0))**2.,0)) #np.std(y, axis=0)
+            std_y = np.std(y, axis=0)
             std_X[std_X == 0.] = 1.
             std_y[std_y == 0.] = 1.
         else:
@@ -271,22 +317,23 @@ class GaussianProcess(BaseEstimator):
             std_X = np.array([1.])
             mean_y = np.array([0.])
             std_y = np.array([1.])
-        
+
         X_ = (X - mean_X) / std_X
         y_ = (y - mean_y) / std_y
-        
+
         # Calculate matrix of distances D between samples
-        mzmax = n_samples*(n_samples-1)/2
+        mzmax = n_samples * (n_samples - 1) / 2
         ij = np.zeros([mzmax, 2])
         D = np.zeros([mzmax, n_features])
         ll = np.array([-1])
         for k in range(n_samples-1):
-            ll = ll[-1] + 1 + range(n_samples-k-1)
-            ij[ll,:] = np.concatenate([[np.repeat(k,n_samples-k-1,0)], [np.arange(k+1,n_samples).T]]).T
-            D[ll,:] = X_[k,:] - X_[(k+1):n_samples,:]
-        if (np.min(np.sum(np.abs(D),1)) == 0.) and (self.corr != corriid):
-            raise Exception, "Multiple X are not allowed"
-        
+            ll = ll[-1] + 1 + range(n_samples - k - 1)
+            ij[ll] = np.concatenate([[np.repeat(k, n_samples - k - 1, 0)], \
+                                     [np.arange(k + 1, n_samples).T]]).T
+            D[ll] = X_[k] - X_[(k + 1):n_samples]
+        if np.min(np.sum(np.abs(D), 1)) == 0. and self.corr != corriid:
+            raise Exception("Multiple X are not allowed")
+
         # Regression matrix and parameters
         F = self.regr(X)
         n_samples_F = F.shape[0]
@@ -295,12 +342,17 @@ class GaussianProcess(BaseEstimator):
         else:
             p = 1
         if n_samples_F != n_samples:
-            raise Exception, "Number of rows in F and X do not match. Most likely something is wrong in the regression model."
+            raise Exception("Number of rows in F and X do not match. Most " \
+                          + "likely something is going wrong with the " \
+                          + "regression model.")
         if p > n_samples_F:
-            raise Exception, "Ordinary least squares problem is undetermined. n_samples=%d must be greater than the regression model size p=%d!" % (n_samples, p)
-        if self.beta0 is not None and (self.beta0.shape[0] != p or self.beta0.ndim > 1):
-            raise Exception, "Shapes of beta0 and F do not match."
-    
+            raise Exception(("Ordinary least squares problem is " \
+                            + "underdetermined. n_samples=%d must be " \
+                            + "greater than the regression model size " \
+                            + "p=%d!") % (n_samples, p))
+        if self.beta0 is not None \
+           and (self.beta0.shape[0] != p or self.beta0.ndim > 1):
+            raise Exception("Shapes of beta0 and F do not match.")
+
         # Set attributes
         self.X = X
         self.y = y
@@ -309,365 +361,453 @@ class GaussianProcess(BaseEstimator):
         self.D = D
         self.ij = ij
         self.F = F
-        self.X_sc = np.concatenate([[mean_X],[std_X]])
-        self.y_sc = np.concatenate([[mean_y],[std_y]])
-        
+        self.X_sc = np.concatenate([[mean_X], [std_X]])
+        self.y_sc = np.concatenate([[mean_y], [std_y]])
+
         # Determine Gaussian Process model parameters
         if self.thetaL is not None and self.thetaU is not None:
             # Maximum Likelihood Estimation of the parameters
             if self.verbose:
-                print "Performing Maximum Likelihood Estimation of the autocorrelation parameters..."
-            self.theta, self.reduced_likelihood_function_value, self.par  = self.arg_max_reduced_likelihood_function()
-            if np.isinf(self.reduced_likelihood_function_value) :
-                raise Exception, "Bad parameter region. Try increasing upper bound"
+                print "Performing Maximum Likelihood Estimation of the " \
+                    + "autocorrelation parameters..."
+            self.theta, self.reduced_likelihood_function_value, self.par = \
+                self.arg_max_reduced_likelihood_function()
+            if np.isinf(self.reduced_likelihood_function_value):
+                raise Exception("Bad parameter region. " \
+                              + "Try increasing upper bound")
+
         else:
             # Given parameters
             if self.verbose:
-                print "Given autocorrelation parameters. Computing Gaussian Process model parameters..."
+                print "Given autocorrelation parameters. " \
+                    + "Computing Gaussian Process model parameters..."
             self.theta = self.theta0
-            self.reduced_likelihood_function_value, self.par = self.reduced_likelihood_function()
+            self.reduced_likelihood_function_value, self.par = \
+                self.reduced_likelihood_function()
             if np.isinf(self.reduced_likelihood_function_value):
-                raise Exception, "Bad point. Try increasing theta0"
-        
+                raise Exception("Bad point. Try increasing theta0")
+
         if self.storage_mode == 'light':
             # Delete heavy data (it will be computed again if required)
             # (it is required only when MSE is wanted in self.predict)
             if self.verbose:
-                print "Light storage mode specified. Flushing autocorrelation matrix..."
+                print "Light storage mode specified. " \
+                    + "Flushing autocorrelation matrix..."
             self.D = None
             self.ij = None
             self.F = None
             self.par['C'] = None
             self.par['Ft'] = None
             self.par['G'] = None
-        
+
         return self
-        
+
     def predict(self, X, eval_MSE=False, batch_size=None):
         """
         This function evaluates the Gaussian Process model at x.
-        
+
         Parameters
         ----------
         X : array_like
-            An array with shape (n_eval, n_features) giving the point(s) at which the prediction(s) should be made.
+            An array with shape (n_eval, n_features) giving the point(s) at
+            which the prediction(s) should be made.
+
         eval_MSE : boolean, optional
-            A boolean specifying whether the Mean Squared Error should be evaluated or not.
-            Default assumes evalMSE = False and evaluates only the BLUP (mean prediction).
+            A boolean specifying whether the Mean Squared Error should be
+            evaluated or not.
+            Default assumes eval_MSE = False and evaluates only the BLUP (mean
+            prediction).
+
         verbose : boolean, optional
             A boolean specifying the verbose level.
             Default is verbose = True.
+
         batch_size : integer, optional
-            An integer giving the maximum number of points that can be evaluated simulatneously (depending on the available memory).
-            Default is None so that all given points are evaluated at the same time.
-        
+            An integer giving the maximum number of points that can be
+            evaluated simultaneously (depending on the available memory).
+            Default is None so that all given points are evaluated at the same
+            time.
+
         Returns
         -------
         y : array_like
-            An array with shape (n_eval, ) with the Best Linear Unbiased Prediction at x.
+            An array with shape (n_eval, ) with the Best Linear Unbiased
+            Prediction at x.
+
         MSE : array_like, optional (if eval_MSE == True)
             An array with shape (n_eval, ) with the Mean Squared Error at x.
         """
-        
+
         # Check
         if np.any(np.isnan(self.par['beta'])):
-            raise Exception, "Not a valid GaussianProcess. Try fitting it again with different parameters theta"
-        
+            raise Exception("Not a valid GaussianProcess. " \
+                          + "Try fitting it again with different parameters " \
+                          + "theta.")
+
         # Check design & trial sites
         X = np.asanyarray(X, dtype=np.float)
         n_samples = self.X_.shape[0]
         if self.X_.ndim > 1:
             n_features = self.X_.shape[1]
         else:
-            n = 1
+            n_features = 1
         n_eval = X.shape[0]
         if X.ndim > 1:
             n_features_X = X.shape[1]
         else:
             n_features_X = 1
         X = X.reshape(n_eval, n_features_X)
-        
+
         if n_features_X != n_features:
-            raise ValueError, "The number of features in X (X.shape[1] = %d) should match the sample size used for fit() which is %d." % (n_features_X, n_features)
-        
-        if batch_size is None: # No memory management (evaluates all given points in a single batch run)
-            
+            raise ValueError(("The number of features in X " \
+                             + "(X.shape[1] = %d) should match the number " \
+                             + "of features used for fit() which is %d.") \
+                             % (n_features_X, n_features))
+
+        if batch_size is None:
+            # No memory management
+            # (evaluates all given points in a single batch run)
+
             # Normalize trial sites
-            X_ = (X - self.X_sc[0,:]) / self.X_sc[1,:]
-            
+            X_ = (X - self.X_sc[0][:]) / self.X_sc[1][:]
+
             # Initialize output
             y = np.zeros(n_eval)
             if eval_MSE:
                 MSE = np.zeros(n_eval)
-            
+
             # Get distances to design sites
-            dx = np.zeros([n_eval*n_samples,n_features])
+            dx = np.zeros([n_eval * n_samples, n_features])
             kk = np.arange(n_samples).astype(int)
             for k in range(n_eval):
-                dx[kk,:] = X_[k,:] - self.X_
+                dx[kk] = X_[k] - self.X_
                 kk = kk + n_samples
-            
+
             # Get regression function and correlation
             f = self.regr(X_)
             r = self.corr(self.theta, dx).reshape(n_eval, n_samples).T
-            
+
             # Scaled predictor
-            y_ = np.matrix(f) * np.matrix(self.par['beta']) + (np.matrix(self.par['gamma']) * np.matrix(r)).T
-            
+            y_ = np.matrix(f) * np.matrix(self.par['beta']) \
+               + (np.matrix(self.par['gamma']) * np.matrix(r)).T
+
             # Predictor
-            y = (self.y_sc[0,:] + self.y_sc[1,:] * np.array(y_)).reshape(n_eval)
-            
+            y = (self.y_sc[0] + self.y_sc[1] * np.array(y_)).ravel()
+
             # Mean Squared Error
             if eval_MSE:
                 par = self.par
                 if par['C'] is None:
                     # Light storage mode (need to recompute C, F, Ft and G)
                     if self.verbose:
-                        print "This GaussianProcess used light storage mode at instanciation. Need to recompute autocorrelation matrix..."
-                    reduced_likelihood_function_value, par = self.reduced_likelihood_function()
-                
+                        print "This GaussianProcess used light storage mode " \
+                            + "at instanciation. Need to recompute " \
+                            + "autocorrelation matrix..."
+                    reduced_likelihood_function_value, par = \
+                        self.reduced_likelihood_function()
+
                 rt = linalg.solve(np.matrix(par['C']), np.matrix(r))
                 if self.beta0 is None:
                     # Universal Kriging
-                    u = linalg.solve(np.matrix(-self.par['G'].T), np.matrix(self.par['Ft']).T * np.matrix(rt) - np.matrix(f).T)
+                    u = - linalg.solve(np.matrix(self.par['G'].T), \
+                                       np.matrix(self.par['Ft']).T * \
+                                       np.matrix(rt) - np.matrix(f).T)
                 else:
                     # Ordinary Kriging
-                    u = 0. * y
-                
-                MSE = self.par['sigma2'] * (1. - col_sum(np.array(rt)**2.) + col_sum(np.array(u)**2.)).T
-                
-                # Mean Squared Error might be slightly negative depending on machine precision
-                # Force to zero!
+                    u = np.zeros(y.shape)
+
+                MSE = self.par['sigma2'] * (1. - col_sum(np.array(rt) ** 2.) \
+                                               + col_sum(np.array(u) ** 2.)).T
+
+                # Mean Squared Error might be slightly negative depending on
+                # machine precision: force to zero!
                 MSE[MSE < 0.] = 0.
-                
+
                 return y, MSE
-                
+
             else:
-                
+
                 return y
-            
-        else: # Memory management
-            
+
+        else:
+            # Memory management
+
             if type(batch_size) is not int or batch_size <= 0:
-                raise Exception, "batch_size must be a positive integer"
-            
+                raise Exception("batch_size must be a positive integer")
+
             if eval_MSE:
-                
-                y, MSE = np.zeros(n_eval), np.zeros(n_eval)
-                for k in range(n_eval/batch_size):
-                    y[k*batch_size:min([(k+1)*batch_size+1, n_eval+1])], MSE[k*batch_size:min([(k+1)*batch_size+1, n_eval+1])] = \
-                        self.predict(X[k*batch_size:min([(k+1)*batch_size+1, n_eval+1]),:], eval_MSE = eval_MSE, batch_size = None)
-                
+
+                y, MSE = np.array([]), np.array([])
+                for k in range((n_eval - 1) / batch_size + 1):
+                    batch_from = k * batch_size
+                    batch_to = min((k + 1) * batch_size, n_eval)
+                    X_batch = X[batch_from:batch_to]
+                    y_batch, MSE_batch = \
+                        self.predict(X_batch, eval_MSE=eval_MSE, \
+                                     batch_size=None)
+                    y = np.append(y, y_batch)
+                    MSE = np.append(MSE, MSE_batch)
+
                 return y, MSE
-                
+
             else:
-                
-                y = np.zeros(n_eval)
-                for k in range(n_eval/batch_size):
-                    y[k*batch_size:min([(k+1)*batch_size+1, n_eval+1])] = \
-                        self.__call__(x[k*batch_size:min([(k+1)*batch_size+1, n_eval+1]),:], eval_MSE = eval_MSE, batch_size = None)
-                
+
+                y = np.array([])
+                for k in range((n_eval - 1) / batch_size + 1):
+                    batch_from = k * batch_size
+                    batch_to = min((k + 1) * batch_size, n_eval)
+                    X_batch = X[batch_from:batch_to]
+                    y_batch = \
+                        self.predict(X_batch, eval_MSE=eval_MSE, \
+                                     batch_size=None)
+                    y = np.append(y, y_batch)
+
                 return y
-    
-    
-    
-    def reduced_likelihood_function(self, theta = None):
+
+    def reduced_likelihood_function(self, theta=None):
         """
-        This function determines the BLUP parameters and evaluates the reduced likelihood function for the given autocorrelation parameters theta.
-        Maximizing this function wrt the autocorrelation parameters theta is equivalent to maximizing the likelihood of the
-        assumed joint Gaussian distribution of the observations y evaluated onto the design of experiments X.
-        
+        This function determines the BLUP parameters and evaluates the reduced
+        likelihood function for the given autocorrelation parameters theta.
+
+        Maximizing this function wrt the autocorrelation parameters theta is
+        equivalent to maximizing the likelihood of the assumed joint Gaussian
+        distribution of the observations y evaluated onto the design of
+        experiments X.
+
         Parameters
         ----------
         theta : array_like, optional
-            An array containing the autocorrelation parameters at which the Gaussian Process model parameters should be determined.
-            Default uses the built-in autocorrelation parameters (ie theta = self.theta).
-        
+            An array containing the autocorrelation parameters at which the
+            Gaussian Process model parameters should be determined.
+            Default uses the built-in autocorrelation parameters
+            (ie theta = self.theta).
+
         Returns
         -------
+        reduced_likelihood_function_value : double
+            The value of the reduced likelihood function associated to the
+            given autocorrelation parameters theta.
+
         par : dict
-            A dictionary containing the requested Gaussian Process model parameters:
+            A dictionary containing the requested Gaussian Process model
+            parameters:
+
             par['sigma2'] : Gaussian Process variance.
-            par['beta'] : Generalized least-squares regression weights for Universal Kriging or given beta0 for Ordinary Kriging.
+            par['beta'] : Generalized least-squares regression weights for
+                          Universal Kriging or given beta0 for Ordinary
+                          Kriging.
             par['gamma'] : Gaussian Process weights.
             par['C'] : Cholesky decomposition of the correlation matrix [R].
             par['Ft'] : Solution of the linear equation system : [R] x Ft = F
             par['G'] : QR decomposition of the matrix Ft.
-            par['detR'] : Determinant of the correlation matrix raised at power 1/n_samples.
-        reduced_likelihood_function_value : double
-            The value of the reduced likelihood function associated to the given autocorrelation parameters theta.
+            par['detR'] : Determinant of the correlation matrix raised to the
+                          power 1/n_samples.
         """
-        
+
         if theta is None:
             # Use built-in autocorrelation parameters
             theta = self.theta
-        
-        reduced_likelihood_function_value = -np.inf
-        par = { }
-        
+
+        # Initialize output
+        reduced_likelihood_function_value = - np.inf
+        par = {}
+
         # Retrieve data
         n_samples = self.X_.shape[0]
         D = self.D
         ij = self.ij
         F = self.F
-        
+
         if D is None:
             # Light storage mode (need to recompute D, ij and F)
             if self.X_.ndim > 1:
                 n_features = self.X_.shape[1]
             else:
                 n_features = 1
-            mzmax = n_samples*(n_samples-1)/2
-            ij = zeros([mzmax, n_features])
-            D = zeros([mzmax, n_features])
-            ll = array([-1])
+            mzmax = n_samples * (n_samples - 1) / 2
+            ij = np.zeros([mzmax, 2])  # (i, j) index pairs
+            D = np.zeros([mzmax, n_features])
+            ll = np.array([-1])
             for k in range(n_samples-1):
-                ll = ll[-1]+1 + range(n_samples-k-1)
-                ij[ll,:] = np.concatenate([[repeat(k,n_samples-k-1,0)], [np.arange(k+1,n_samples).T]]).T
-                D[ll,:] = self.X_[k,:] - self.X_[(k+1):n_samples,:]
-            if (min(sum(abs(D),1)) == 0.) and (self.corr != corriid):
-                raise Exception, "Multiple X are not allowed"
+                ll = ll[-1] + 1 + range(n_samples - k - 1)
+                ij[ll] = \
+                    np.concatenate([[np.repeat(k, n_samples - k - 1, 0)], \
+                                    [np.arange(k + 1, n_samples).T]]).T
+                D[ll] = self.X_[k] - self.X_[(k + 1):n_samples]
+            if np.min(np.sum(np.abs(D), axis=1)) == 0. \
+                                        and self.corr != corriid:
+                raise Exception("Multiple X are not allowed")
             F = self.regr(self.X_)
-        
+
         # Set up R
         r = self.corr(theta, D)
         R = np.eye(n_samples) * (1. + self.nugget)
-        R[ij.astype(int)[:,0], ij.astype(int)[:,1]] = r
-        R[ij.astype(int)[:,1], ij.astype(int)[:,0]] = r
-        
+        R[ij.astype(int)[:, 0], ij.astype(int)[:, 1]] = r
+        R[ij.astype(int)[:, 1], ij.astype(int)[:, 0]] = r
+
         # Cholesky decomposition of R
         try:
             C = linalg.cholesky(R, lower=True)
         except linalg.LinAlgError:
             return reduced_likelihood_function_value, par
-        
+
         # Get generalized least squares solution
         Ft = linalg.solve(C, F)
         try:
             Q, G = linalg.qr(Ft, econ=True)
         except:
-            #/usr/lib/python2.6/dist-packages/scipy/linalg/decomp.py:1177: DeprecationWarning: qr econ argument will be removed after scipy 0.7. The economy transform will then be available through the mode='economic' argument.
+            #/usr/lib/python2.6/dist-packages/scipy/linalg/decomp.py:1177:
+            # DeprecationWarning: qr econ argument will be removed after scipy
+            # 0.7. The economy transform will then be available through the
+            # mode='economic' argument.
             Q, G = linalg.qr(Ft, mode='economic')
             pass
-        
-        rcondG = 1./(linalg.norm(G)*linalg.norm(linalg.inv(G)))
+
+        rcondG = 1. / (linalg.norm(G) * linalg.norm(linalg.inv(G)))
         if rcondG < 1e-10:
             # Check F
-            condF = linalg.norm(F)*linalg.norm(linalg.inv(F))
+            # F is generally rectangular, hence the pseudo-inverse
+            condF = linalg.norm(F) * linalg.norm(linalg.pinv(F))
             if condF > 1e15:
-                raise Exception, "F is too ill conditioned. Poor combination of regression model and observations."
+                raise Exception("F is too ill conditioned. Poor combination "
+                                "of regression model and observations.")
             else:
-                # Ft is too ill conditioned
+                # Ft is too ill conditioned, get out (try different theta)
                 return reduced_likelihood_function_value, par
-        
-        Yt = linalg.solve(C,self.y_)
+
+        Yt = linalg.solve(C, self.y_)
         if self.beta0 is None:
             # Universal Kriging
-            beta = linalg.solve(G, np.matrix(Q).T*np.matrix(Yt))
+            beta = linalg.solve(G, np.matrix(Q).T * np.matrix(Yt))
         else:
             # Ordinary Kriging
             beta = np.array(self.beta0)
-        
+
         rho = np.matrix(Yt) - np.matrix(Ft) * np.matrix(beta)
-        normalized_sigma2 = (np.array(rho)**2.).sum(axis=0)/n_samples
-        # The determinant of R is equal to the squared product of the diagonal elements of its Cholesky decomposition C
-        detR = (np.array(np.diag(C))**(2./n_samples)).prod()
+        normalized_sigma2 = (np.array(rho) ** 2.).sum(axis=0) / n_samples
+        # The determinant of R is equal to the squared product of the diagonal
+        # elements of its Cholesky decomposition C
+        detR = (np.array(np.diag(C)) ** (2. / n_samples)).prod()
+
+        # Compute/Organize output
         reduced_likelihood_function_value = - normalized_sigma2.sum() * detR
-        par = { 'sigma2':normalized_sigma2 * self.y_sc[1]**2, \
-                'beta':beta, \
-                'gamma':linalg.solve(C.T,rho).T, \
-                'C':C, \
-                'Ft':Ft, \
-                'G':G }
-        
+        par['sigma2'] = normalized_sigma2 * self.y_sc[1] ** 2
+        par['beta'] = beta
+        par['gamma'] = linalg.solve(C.T, rho).T
+        par['C'] = C
+        par['Ft'] = Ft
+        par['G'] = G
+
         return reduced_likelihood_function_value, par
 
     def arg_max_reduced_likelihood_function(self):
         """
-        This function estimates the autocorrelation parameters theta as the maximizer of the reduced likelihood function.
-        (Minimization of the opposite reduced likelihood function is used for convenience)
-        
+        This function estimates the autocorrelation parameters theta as the
+        maximizer of the reduced likelihood function.
+        (Minimization of the negative reduced likelihood function is used for
+        convenience.)
+
         Parameters
         ----------
         self : All parameters are stored in the Gaussian Process model object.
-        
+
         Returns
         -------
         optimal_theta : array_like
-            The best set of autocorrelation parameters (the sought maximizer of the reduced likelihood function).
+            The best set of autocorrelation parameters (the sought maximizer of
+            the reduced likelihood function).
+
         optimal_reduced_likelihood_function_value : double
             The optimal reduced likelihood function value.
+
         optimal_par : dict
             The BLUP parameters associated with the optimal theta.
         """
-        
+
+        # Initialize output
+        best_optimal_theta = []
+        best_optimal_rlf_value = []
+        best_optimal_par = []
+
         if self.verbose:
-            print "The chosen optimizer is: "+str(self.optimizer)
+            print "The chosen optimizer is: " + str(self.optimizer)
             if self.random_start > 1:
-                print str(self.random_start)+" random starts are required."
-        
+                print str(self.random_start) + " random starts are required."
+
         percent_completed = 0.
-        
+
         if self.optimizer == 'fmin_cobyla':
-            
-            minus_reduced_likelihood_function = lambda log10t: - self.reduced_likelihood_function(theta = 10.**log10t)[0]
-            
+
+            minus_reduced_likelihood_function = lambda log10t: \
+                - self.reduced_likelihood_function(theta=10. ** log10t)[0]
+
             constraints = []
             for i in range(self.theta0.size):
-                constraints.append(lambda log10t: log10t[i] - np.log10(self.thetaL[0,i]))
-                constraints.append(lambda log10t: np.log10(self.thetaU[0,i]) - log10t[i])
-            
+                # Bind the loop index i so each constraint keeps its own bound
+                constraints.append(lambda log10t, i=i: \
+                            log10t[i] - np.log10(self.thetaL[0, i]))
+                constraints.append(lambda log10t, i=i: \
+                            np.log10(self.thetaU[0, i]) - log10t[i])
+
             for k in range(self.random_start):
-                
+
                 if k == 0:
                     # Use specified starting point as first guess
                     theta0 = self.theta0
                 else:
-                    # Generate a random starting point log10-uniformly distributed between bounds
-                    log10theta0 = np.log10(self.thetaL) + random.rand(self.theta0.size).reshape(self.theta0.shape) * np.log10(self.thetaU/self.thetaL)
-                    theta0 = 10.**log10theta0
-                
-                log10_optimal_theta = optimize.fmin_cobyla(minus_reduced_likelihood_function, np.log10(theta0), constraints, iprint=0)
-                optimal_theta = 10.**log10_optimal_theta
-                optimal_minus_reduced_likelihood_function_value, optimal_par = self.reduced_likelihood_function(theta = optimal_theta)
-                optimal_reduced_likelihood_function_value = - optimal_minus_reduced_likelihood_function_value
-                
+                    # Generate a random starting point log10-uniformly
+                    # distributed between bounds
+                    log10theta0 = np.log10(self.thetaL) \
+                        + rand(self.theta0.size).reshape(self.theta0.shape) \
+                        * np.log10(self.thetaU / self.thetaL)
+                    theta0 = 10. ** log10theta0
+
+                # Run Cobyla
+                log10_optimal_theta = \
+                    optimize.fmin_cobyla(minus_reduced_likelihood_function, \
+                                    np.log10(theta0), constraints, iprint=0)
+
+                optimal_theta = 10. ** log10_optimal_theta
+                optimal_minus_rlf_value, optimal_par = \
+                    self.reduced_likelihood_function(theta=optimal_theta)
+                optimal_rlf_value = - optimal_minus_rlf_value
+
+                # Compare the new optimizer to the best previous one
                 if k > 0:
-                    if optimal_reduced_likelihood_function_value > best_optimal_reduced_likelihood_function_value:
-                        best_optimal_reduced_likelihood_function_value = optimal_reduced_likelihood_function_value
+                    if optimal_rlf_value > best_optimal_rlf_value:
+                        best_optimal_rlf_value = optimal_rlf_value
                         best_optimal_par = optimal_par
                         best_optimal_theta = optimal_theta
                 else:
-                    best_optimal_reduced_likelihood_function_value = optimal_reduced_likelihood_function_value
+                    best_optimal_rlf_value = optimal_rlf_value
                     best_optimal_par = optimal_par
                     best_optimal_theta = optimal_theta
                 if self.verbose and self.random_start > 1:
-                    if (20*k)/self.random_start > percent_completed:
-                        percent_completed = (20*k)/self.random_start
-                        print str(5*percent_completed)+"% completed"
-    
+                    if (20 * k) / self.random_start > percent_completed:
+                        percent_completed = (20 * k) / self.random_start
+                        print str(5 * percent_completed) + "% completed"
+
         else:
-            
-            raise NotImplementedError, "This optimizer ('%s') is not implemented yet. Please contribute!" % self.optimizer
-    
-        return best_optimal_theta, best_optimal_reduced_likelihood_function_value, best_optimal_par
-    
+
+            raise NotImplementedError(("This optimizer ('%s') is not "
+                                       "implemented yet. Please contribute!")
+                                      % self.optimizer)
+
+        return best_optimal_theta, best_optimal_rlf_value, best_optimal_par
+
     def score(self, X_test, y_test):
         """
-        This score function returns the deviations of the Gaussian Process model evaluated onto a test dataset.
-        
+        This score function returns the deviations of the Gaussian Process
+        model predictions from the targets of a test dataset.
+
         Parameters
         ----------
         X_test : array_like
             The feature test dataset with shape (n_tests, n_features).
-            
+
         y_test : array_like
             The target test dataset (n_tests, ).
-        
+
         Returns
         -------
         score_values : array_like
-            The deviations between the prediction and the targets : y_pred - y_test.
+            The deviations between the prediction and the targets:
+            y_pred - y_test.
         """
-        
-        return np.ravel(self.predict(X_test, eval_MSE=False)) - y_test
\ No newline at end of file
+
+        return np.ravel(self.predict(X_test, eval_MSE=False)) - y_test
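
(Not part of the patch: a minimal usage sketch of the estimator reworked
above. It sticks to the constructor arguments and methods that actually
appear in these hunks: theta0/thetaL/thetaU/random_start, fit, predict with
eval_MSE, and reduced_likelihood_function; anything beyond that would be an
assumption.)

    import numpy as np
    from scikits.learn.gaussian_process import GaussianProcess

    # A 1D target, mirroring test_regression_1d_x_sinx further down
    f = lambda x: x * np.sin(x)
    X = np.array([1., 3., 5., 6., 7., 8.])
    y = f(X)

    # Fit with a few random restarts of the theta optimization
    gp = GaussianProcess(theta0=1e-2, thetaL=1e-4, thetaU=1e-1,
                         random_start=5).fit(X, y)

    # BLUP prediction and its Mean Squared Error at new locations
    x_new = np.linspace(0., 10., 1000)
    y_pred, MSE = gp.predict(x_new, eval_MSE=True)

    # Reduced likelihood function value at the fitted theta, together with
    # the BLUP parameters dictionary documented above
    rlf_value, par = gp.reduced_likelihood_function()
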
diff --git a/scikits/learn/gaussian_process/regression.py b/scikits/learn/gaussian_process/regression.py
index 5bf8df7d3255b80f998560d6d27608b73ed99392..58bf14948bcd400088ef49b325f4a5e7946de172 100644
--- a/scikits/learn/gaussian_process/regression.py
+++ b/scikits/learn/gaussian_process/regression.py
@@ -7,78 +7,88 @@
 
 import numpy as np
 
+
 ############################
 # Defaut regression models #
 ############################
 
+
 def regpoly0(x):
     """
     Zero order polynomial (constant, p = 1) regression model.
-    
+
     regpoly0 : x --> f(x) = 1
-    
+
     Parameters
     ----------
     x : array_like
-        An array with shape (n_eval, n_features) giving the locations x at which the regression model should be evaluated.
-    
+        An array with shape (n_eval, n_features) giving the locations x at
+        which the regression model should be evaluated.
+
     Returns
     -------
     f : array_like
-        An array with shape (n_eval, p) with the values of the regression model.
+        An array with shape (n_eval, p) with the values of the regression
+        model.
     """
-    
+
     x = np.asanyarray(x, dtype=np.float)
     n_eval = x.shape[0]
-    f = np.ones([n_eval,1])
-    
+    f = np.ones([n_eval, 1])
+
     return f
 
+
 def regpoly1(x):
     """
     First order polynomial (hyperplane, p = n+1) regression model.
-    
+
     regpoly1 : x --> f(x) = [ 1, x_1, ..., x_n ].T
-    
+
     Parameters
     ----------
     x : array_like
-        An array with shape (n_eval, n_features) giving the locations x at which the regression model should be evaluated.
-    
+        An array with shape (n_eval, n_features) giving the locations x at
+        which the regression model should be evaluated.
+
     Returns
     -------
     f : array_like
-        An array with shape (n_eval, p) with the values of the regression model.
+        An array with shape (n_eval, p) with the values of the regression
+        model.
     """
-    
+
     x = np.asanyarray(x, dtype=np.float)
     n_eval = x.shape[0]
-    f = np.hstack([np.ones([n_eval,1]), x])
-    
+    f = np.hstack([np.ones([n_eval, 1]), x])
+
     return f
 
+
 def regpoly2(x):
     """
     Second order polynomial (hyperparaboloid, p = (n+1)*(n+2)/2) regression
     model.
-    
+
     regpoly2 : x --> f(x) = [ 1, x_1, ..., x_n,
                               x_1 * x_1, x_1 * x_2, ..., x_n * x_n ].T
-    
+
     Parameters
     ----------
     x : array_like
-        An array with shape (n_eval, n_features) giving the locations x at which the regression model should be evaluated.
-    
+        An array with shape (n_eval, n_features) giving the locations x at
+        which the regression model should be evaluated.
+
     Returns
     -------
     f : array_like
-        An array with shape (n_eval, p) with the values of the regression model.
+        An array with shape (n_eval, p) with the values of the regression
+        model.
     """
-    
+
     x = np.asanyarray(x, dtype=np.float)
     n_eval, n_features = x.shape
-    f = np.hstack([np.ones([n_eval,1]), x])
+    f = np.hstack([np.ones([n_eval, 1]), x])
     for  k in range(n_features):
-        f = np.hstack([f, x[:,k,np.newaxis] * x[:,k:]])
-    
-    return f
\ No newline at end of file
+        f = np.hstack([f, x[:, k, np.newaxis] * x[:, k:]])
+
+    return f
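
(Illustration only, not part of the patch: a quick shape check that makes the
basis sizes of the three regression models concrete. The import path is an
assumption based on the file location above.)

    import numpy as np
    from scikits.learn.gaussian_process.regression import regpoly0, regpoly1, \
                                                          regpoly2

    x = np.array([[1., 2.],
                  [3., 4.],
                  [5., 6.]])    # n_eval = 3, n_features = 2

    print regpoly0(x).shape     # (3, 1): constant term only
    print regpoly1(x).shape     # (3, 3): [1, x_1, x_2]
    print regpoly2(x).shape     # (3, 6): [1, x_1, x_2, x_1**2, x_1*x_2, x_2**2]
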
diff --git a/scikits/learn/tests/test_gaussian_process.py b/scikits/learn/tests/test_gaussian_process.py
index bb5ec0c32c5aa8478826b1e51851cb79086e071f..6b0a3a79e7ea68e17e28c497f66c60f3f25fe4d0 100644
--- a/scikits/learn/tests/test_gaussian_process.py
+++ b/scikits/learn/tests/test_gaussian_process.py
@@ -6,43 +6,52 @@ import numpy as np
 from numpy.testing import assert_array_equal, assert_array_almost_equal, \
                           assert_almost_equal, assert_raises, assert_
 
-from .. import gaussian_process, datasets, cross_val, metrics
+from .. import datasets, cross_val, metrics
+from ..gaussian_process import GaussianProcess
 
 diabetes = datasets.load_diabetes()
 
+
 def test_regression_1d_x_sinx():
     """
     MLE estimation of a Gaussian Process model with a squared exponential
     correlation model (correxp2). Check random start optimization.
-    
+
     Test the interpolating property.
     """
-    
+
     f = lambda x: x * np.sin(x)
     X = np.array([1., 3., 5., 6., 7., 8.])
     y = f(X)
-    gp = gaussian_process.GaussianProcess(theta0=1e-2, thetaL=1e-4, thetaU=1e-1, random_start=10, verbose=False).fit(X, y)
+    gp = GaussianProcess(theta0=1e-2, thetaL=1e-4, thetaU=1e-1, \
+                         random_start=10, verbose=False).fit(X, y)
     y_pred, MSE = gp.predict(X, eval_MSE=True)
-    
-    assert (np.all(np.abs((y_pred - y) / y)  < 1e-6) and np.all(MSE  < 1e-6))
 
-def test_regression_diabetes(n_jobs = 1, verbose = 0):
+    assert (np.all(np.abs((y_pred - y) / y) < 1e-6) and np.all(MSE < 1e-6))
+
+
+def test_regression_diabetes(n_jobs=1, verbose=0):
     """
-    MLE estimation of a Gaussian Process model with an anisotropic squared exponential
-    correlation model.
-    
+    MLE estimation of a Gaussian Process model with an anisotropic squared
+    exponential correlation model.
+
     Test the model using cross-validation module (quite time-consuming).
-    
+
     Poor performance: Leave-one-out estimate of explained variance is about 0.5
     at best... To be investigated!
+
     TODO: find a dataset that would prove GP performance!
     """
-    
+
     X, y = diabetes['data'], diabetes['target']
-    
-    gp = gaussian_process.GaussianProcess(corr=gaussian_process.correxp2, theta0=1e-4, nugget=1e-2, verbose=False).fit(X, y)
 
-    y_pred = cross_val.cross_val_score(gp, X, y=y, cv=cross_val.LeaveOneOut(y.size), n_jobs=n_jobs, verbose=verbose) + y
+    gp = GaussianProcess(theta0=1e-4, nugget=1e-2, verbose=False).fit(X, y)
+
+    y_pred = cross_val.cross_val_score(gp, X, y=y, \
+                                       cv=cross_val.LeaveOneOut(y.size), \
+                                       n_jobs=n_jobs, verbose=verbose) \
+           + y
+
     Q2 = metrics.explained_variance(y_pred, y)
-    
+
     assert Q2 > 0.45
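
(Side note, not part of the patch: what the Q2 assertion above measures.
Since GaussianProcess.score returns the deviations y_pred - y_test, adding y
back to the cross_val_score output recovers the leave-one-out predictions;
Q2 is then the explained variance of those predictions, sketched below in
plain numpy under the assumption that metrics.explained_variance follows the
usual definition.)

    import numpy as np

    def explained_variance_sketch(y_pred, y_true):
        # 1 - Var(residuals) / Var(targets): 1.0 for perfect leave-one-out
        # predictions, about 0.0 for a predictor no better than the mean
        return 1. - np.var(y_true - y_pred) / np.var(y_true)

    # The test above requires this estimate to exceed 0.45 on diabetes.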