diff --git a/examples/gaussian_process/plot_gp_diabetes_dataset.py b/examples/gaussian_process/plot_gp_diabetes_dataset.py index 21b85a0eb18b5e5cb883856ab1bfc839ab4a4d86..8416c954469c91828c7d39ca2b1b02a070bcd764 100644 --- a/examples/gaussian_process/plot_gp_diabetes_dataset.py +++ b/examples/gaussian_process/plot_gp_diabetes_dataset.py @@ -2,9 +2,9 @@ # -*- coding: utf-8 -*- """ -=============================================== -Gaussian Processes for Machine Learning example -=============================================== +============================================================= +Gaussian Processes regression example: the 'diabetes' dataset +============================================================= This example consists in fitting a Gaussian Process model onto the diabetes dataset. @@ -39,7 +39,6 @@ gp = GaussianProcess(theta0=1e-4, nugget=1e-2, verbose=False) # Fit the GP model to the data gp.fit(X, y) -gp.par['beta'] # Estimate the leave-one-out predictions using the cross_val module n_jobs = 2 # the distributing capacity available on the machine diff --git a/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py b/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py index 17e2e63919dee277e682b9e5caa7d415f5323708..b2ed888e16044f9fd3c84fcd8625c28942fb4ed2 100644 --- a/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py +++ b/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py @@ -2,9 +2,9 @@ # -*- coding: utf-8 -*- """ -=============================================== -Gaussian Processes for Machine Learning example -=============================================== +============================================================================== +Gaussian Processes classification example: exploiting the probabilistic output +============================================================================== A two-dimensional regression exercise with a post-processing allowing for probabilistic classification thanks to the Gaussian property of the prediction. 
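For readers skimming the diff, here is a minimal sketch of what the reworked diabetes example boils down to: fit a GaussianProcess with the hyperparameters used in plot_gp_diabetes_dataset.py and estimate leave-one-out predictions. The dataset loader usage and the explicit leave-one-out loop are assumptions for illustration only; the shipped example delegates that step to the cross_val module with n_jobs = 2.

    # Illustrative sketch; loader access and the hand-rolled loop are assumed.
    import numpy as np
    from scikits.learn import datasets
    from scikits.learn.gaussian_process import GaussianProcess

    diabetes = datasets.load_diabetes()
    X, y = diabetes.data, diabetes.target

    # Same settings as the example: fixed theta0, explicit nugget, quiet mode.
    gp = GaussianProcess(theta0=1e-4, nugget=1e-2, verbose=False).fit(X, y)

    # Leave-one-out predictions written out explicitly
    # (the example computes these with the cross_val module, n_jobs=2).
    y_loo = np.empty(y.shape)
    for i in range(X.shape[0]):
        mask = np.arange(X.shape[0]) != i
        gp_i = GaussianProcess(theta0=1e-4, nugget=1e-2, verbose=False)
        y_loo[i] = gp_i.fit(X[mask], y[mask]).predict(X[i:i + 1])[0]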
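The probabilistic post-processing this classification example relies on is nothing more than the Gaussian predictive distribution: given the BLUP y(x) and its MSE, the probability that the underlying function lies below zero is Phi(-y / sigma), which is exactly what the example draws over the (x1, x2) grid. A minimal sketch of that step, assuming a GaussianProcess gp already fitted to observations of the function g and an evaluation grid x_grid:

    import numpy as np
    from scipy import stats

    # BLUP and its Mean Squared Error on the evaluation grid.
    y_pred, MSE = gp.predict(x_grid, eval_MSE=True)
    sigma = np.sqrt(MSE)  # ~0 at the design sites, where the GP interpolates

    # Gaussian property of the prediction: P[G(x) <= 0] = Phi(-y_pred / sigma)
    proba_negative = stats.norm.cdf(- y_pred / sigma)
    labels = proba_negative > 0.5   # classify as "G(x) <= 0" when more likely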
@@ -34,7 +34,7 @@ PHI = lambda x: Grv.cdf(x) PHIinv = lambda x: Grv.ppf(x) # A few constants -beta0 = 8 +lim = 8 b, kappa, e = 5, .5, .1 # The function to predict (classification will then consist in predicting @@ -62,8 +62,8 @@ gp.fit(X, Y) # Evaluate real function, the prediction and its MSE on a grid res = 50 -x1, x2 = np.meshgrid(np.linspace(- beta0, beta0, res), \ - np.linspace(- beta0, beta0, res)) +x1, x2 = np.meshgrid(np.linspace(- lim, lim, res), \ + np.linspace(- lim, lim, res)) xx = np.vstack([x1.reshape(x1.size), x2.reshape(x2.size)]).T YY = g(xx) @@ -87,7 +87,7 @@ pl.xlabel('$x_1$') pl.ylabel('$x_2$') cax = pl.imshow(np.flipud(PHI(- yy / sigma)), cmap=cm.gray_r, alpha=0.8, \ - extent=(- beta0, beta0, - beta0, beta0)) + extent=(- lim, lim, - lim, lim)) norm = pl.matplotlib.colors.Normalize(vmin=0., vmax=0.9) cb = pl.colorbar(cax, ticks=[0., 0.2, 0.4, 0.6, 0.8, 1.], norm=norm) cb.set_label('${\\rm \mathbb{P}}\left[\widehat{G}(\mathbf{x}) \leq 0\\right]$') diff --git a/examples/gaussian_process/plot_gp_regression.py b/examples/gaussian_process/plot_gp_regression.py index 747e62691cc8eb2585652ee5bf96480f0aa85a4d..ededab866140102d788334b22dc61db9e41886fe 100644 --- a/examples/gaussian_process/plot_gp_regression.py +++ b/examples/gaussian_process/plot_gp_regression.py @@ -2,9 +2,9 @@ # -*- coding: utf-8 -*- """ -=============================================== -Gaussian Processes for Machine Learning example -=============================================== +================================================================= +Gaussian Processes regression example: basic introductory example +================================================================= A simple one-dimensional regression exercise with a cubic correlation model whose parameters are estimated using the maximum likelihood principle. @@ -18,7 +18,7 @@ confidence interval. # License: BSD style import numpy as np -from scikits.learn.gaussian_process import GaussianProcess, corrcubic +from scikits.learn.gaussian_process import GaussianProcess from matplotlib import pyplot as pl # Print the docstring @@ -38,7 +38,7 @@ Y = f(X).ravel() x = np.atleast_2d(np.linspace(0, 10, 1000)).T # Instanciate a Gaussian Process model -gp = GaussianProcess(corr=corrcubic, theta0=1e-2, thetaL=1e-4, thetaU=1e-1, \ +gp = GaussianProcess(corr='cubic', theta0=1e-2, thetaL=1e-4, thetaU=1e-1, \ random_start=100) # Fit to data using Maximum Likelihood Estimation of the parameters diff --git a/scikits/learn/gaussian_process/__init__.py b/scikits/learn/gaussian_process/__init__.py index f2bb68562f77ed4ec5b2bf93f6dec1b357b24703..e6ca5e2bfc809bc111b805cdd6d792cad520224f 100644 --- a/scikits/learn/gaussian_process/__init__.py +++ b/scikits/learn/gaussian_process/__init__.py @@ -1,15 +1,32 @@ #!/usr/bin/python # -*- coding: utf-8 -*- +# Author: Vincent Dubourg <vincent.dubourg@gmail.com> +# (mostly translation, see implementation details) +# License: BSD style + """ - This module contains a contribution to the scikit-learn project that - implements Gaussian Process based prediction (also known as Kriging). +A module that implements scalar Gaussian Process based prediction (also +known as Kriging). + +Contains +-------- +GaussianProcess: The main class of the module that implements the Gaussian + Process prediction theory. +regression_models: A submodule that contains the built-in regression models. +correlation_models: A submodule that contains the built-in correlation models. 
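Since plot_gp_regression.py now selects the cubic correlation model by name, the probabilistic part of that example reduces to the sketch below; the 1.96 factor for a roughly 95% pointwise band is an illustration, not something fixed by the API.

    import numpy as np
    from scikits.learn.gaussian_process import GaussianProcess

    f = lambda x: x * np.sin(x)
    X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T
    y = f(X).ravel()
    x = np.atleast_2d(np.linspace(0, 10, 1000)).T

    # Same settings as the example, with the correlation model given by name.
    gp = GaussianProcess(corr='cubic', theta0=1e-2, thetaL=1e-4, thetaU=1e-1,
                         random_start=100).fit(X, y)

    # The MSE of the BLUP yields a pointwise confidence band around it.
    y_pred, MSE = gp.predict(x, eval_MSE=True)
    sigma = np.sqrt(MSE)
    lower, upper = y_pred - 1.96 * sigma, y_pred + 1.96 * sigma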
+ +Implementation details +---------------------- +The presentation implementation is based on a translation of the DACE +Matlab toolbox. - The present implementation is based on a transliteration of the DACE - Matlab toolbox <http://www2.imm.dtu.dk/~hbn/dace/>. +See references: +[1] H.B. Nielsen, S.N. Lophaven, H. B. Nielsen and J. Sondergaard (2002). + DACE - A MATLAB Kriging Toolbox. + http://www2.imm.dtu.dk/~hbn/dace/dace.pdf """ from .gaussian_process import GaussianProcess -from .correlation import corrlin, corrcubic, correxp1, \ - correxp2, correxpg, corriid -from .regression import regpoly0, regpoly1, regpoly2 +from . import correlation_models +from . import regression_models diff --git a/scikits/learn/gaussian_process/correlation.py b/scikits/learn/gaussian_process/correlation_models.py similarity index 85% rename from scikits/learn/gaussian_process/correlation.py rename to scikits/learn/gaussian_process/correlation_models.py index 65dcc5b825b2ed21c144b0a407253c02afdfaa18..97b436ed9de6f07a2b8afdbe1b5d324e796e2866 100644 --- a/scikits/learn/gaussian_process/correlation.py +++ b/scikits/learn/gaussian_process/correlation_models.py @@ -1,6 +1,14 @@ #!/usr/bin/python # -*- coding: utf-8 -*- +# Author: Vincent Dubourg <vincent.dubourg@gmail.com> +# (mostly translation, see implementation details) +# License: BSD style + +""" +The built-in regression models submodule for the gaussian_process module. +""" + ################ # Dependencies # ################ @@ -13,14 +21,14 @@ import numpy as np ############################# -def correxp1(theta, d): +def absolute_exponential(theta, d): """ - Exponential autocorrelation model. + Absolute exponential autocorrelation model. (Ornstein-Uhlenbeck stochastic process) - n - correxp1 : theta, dx --> r(theta, dx) = exp( sum - theta_i * |dx_i| ) - i = 1 + n + theta, dx --> r(theta, dx) = exp( sum - theta_i * |dx_i| ) + i = 1 Parameters ---------- @@ -58,14 +66,14 @@ def correxp1(theta, d): return r -def correxp2(theta, d): +def squared_exponential(theta, d): """ Squared exponential correlation model (Radial Basis Function). (Infinitely differentiable stochastic process, very smooth) - n - correxp2 : theta, dx --> r(theta, dx) = exp( sum - theta_i * (dx_i)^2 ) - i = 1 + n + theta, dx --> r(theta, dx) = exp( sum - theta_i * (dx_i)^2 ) + i = 1 Parameters ---------- @@ -103,15 +111,15 @@ def correxp2(theta, d): return r -def correxpg(theta, d): +def generalized_exponential(theta, d): """ Generalized exponential correlation model. (Useful when one does not know the smoothness of the function to be predicted.) - n - correxpg : theta, dx --> r(theta, dx) = exp( sum - theta_i * |dx_i|^p ) - i = 1 + n + theta, dx --> r(theta, dx) = exp( sum - theta_i * |dx_i|^p ) + i = 1 Parameters ---------- @@ -152,15 +160,15 @@ def correxpg(theta, d): return r -def corriid(theta, d): +def pure_nugget(theta, d): """ Spatial independence correlation model (pure nugget). (Useful when one wants to solve an ordinary least squares problem!) - n - corriid : theta, dx --> r(theta, dx) = 1 if sum |dx_i| == 0 - i = 1 - 0 otherwise + n + theta, dx --> r(theta, dx) = 1 if sum |dx_i| == 0 + i = 1 + 0 otherwise Parameters ---------- @@ -191,11 +199,11 @@ def corriid(theta, d): return r -def corrcubic(theta, d): +def cubic(theta, d): """ Cubic correlation model. 
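As a quick sanity check on the renamed correlation models, note that they are plain functions of (theta, d), where d holds componentwise differences x - x'; they can therefore be evaluated directly against the documented formulas. A small sketch, assuming the calling convention documented above:

    import numpy as np
    from scikits.learn.gaussian_process import correlation_models

    theta = np.array([0.5, 2.0])
    # Componentwise differences between pairs of points, shape (n_eval, n_features).
    d = np.array([[0.0, 0.0],
                  [1.0, -1.0]])

    r = correlation_models.squared_exponential(theta, d)
    # Should match the documented formula exp( sum_i - theta_i * (dx_i)^2 ).
    expected = np.exp(-np.sum(theta * d ** 2, axis=1))
    print(np.allclose(np.ravel(r), expected))   # True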
- corrcubic : theta, dx --> r(theta, dx) = + theta, dx --> r(theta, dx) = n prod max(0, 1 - 3(theta_j*d_ij)^2 + 2(theta_j*d_ij)^3) , i = 1,...,m j = 1 @@ -241,11 +249,11 @@ def corrcubic(theta, d): return r -def corrlin(theta, d): +def linear(theta, d): """ Linear correlation model. - corrlin : theta, dx --> r(theta, dx) = + theta, dx --> r(theta, dx) = n prod max(0, 1 - theta_j*d_ij) , i = 1,...,m j = 1 diff --git a/scikits/learn/gaussian_process/gaussian_process.py b/scikits/learn/gaussian_process/gaussian_process.py index f5ba3e62fb829d0b573c956625ddc4f1a552a8bf..2622406ac3d525049fbe897ff57511ea320ff43e 100644 --- a/scikits/learn/gaussian_process/gaussian_process.py +++ b/scikits/learn/gaussian_process/gaussian_process.py @@ -1,6 +1,10 @@ #!/usr/bin/python # -*- coding: utf-8 -*- +# Author: Vincent Dubourg <vincent.dubourg@gmail.com> +# (mostly translation, see implementation details) +# License: BSD style + ################ # Dependencies # ################ @@ -8,9 +12,9 @@ import numpy as np from scipy import linalg, optimize, rand from ..base import BaseEstimator -from .regression import regpoly0 -from .correlation import correxp2, corriid -machine_epsilon = np.finfo(np.double).eps +from . import regression_models as regression +from . import correlation_models as correlation +MACHINE_EPSILON = np.finfo(np.double).eps if hasattr(linalg, 'solve_triangular'): # only in scipy since 0.9 solve_triangular = linalg.solve_triangular @@ -26,33 +30,104 @@ else: class GaussianProcess(BaseEstimator): """ - A class that implements scalar Gaussian Process based prediction (also - known as Kriging). + The Gaussian Process model class. + + Parameters + ---------- + regr : string or callable, optional + A regression function returning an array of outputs of the linear + regression functional basis. The number of observations n_samples + should be greater than the size p of this basis. + Default assumes a simple constant regression trend. + Here is the list of built-in regression models: + 'constant', 'linear', 'quadratic' + + corr : string or callable, optional + A stationary autocorrelation function returning the autocorrelation + between two points x and x'. + Default assumes a squared-exponential autocorrelation model. + Here is the list of built-in correlation models: + 'absolute_exponential', 'squared_exponential', + 'generalized_exponential', 'cubic', 'linear' + + beta0 : double array_like, optional + The regression weight vector to perform Ordinary Kriging (OK). + Default assumes Universal Kriging (UK) so that the vector beta of + regression weights is estimated using the maximum likelihood + principle. + + storage_mode : string, optional + A string specifying whether the Cholesky decomposition of the + correlation matrix should be stored in the class (storage_mode = + 'full') or not (storage_mode = 'light'). + Default assumes storage_mode = 'full', so that the + Cholesky decomposition of the correlation matrix is stored. + This might be a useful parameter when one is not interested in the + MSE and only plan to estimate the BLUP, for which the correlation + matrix is not required. + + verbose : boolean, optional + A boolean specifying the verbose level. + Default is verbose = False. + + theta0 : double array_like, optional + An array with shape (n_features, ) or (1, ). + The parameters in the autocorrelation model. + If thetaL and thetaU are also specified, theta0 is considered as + the starting point for the maximum likelihood rstimation of the + best set of parameters. 
+ Default assumes isotropic autocorrelation model with theta0 = 1e-1. + + thetaL : double array_like, optional + An array with shape matching theta0's. + Lower bound on the autocorrelation parameters for maximum + likelihood estimation. + Default is None, so that it skips maximum likelihood estimation and + it uses theta0. + + thetaU : double array_like, optional + An array with shape matching theta0's. + Upper bound on the autocorrelation parameters for maximum + likelihood estimation. + Default is None, so that it skips maximum likelihood estimation and + it uses theta0. + + normalize : boolean, optional + Design sites X and observations y are centered and reduced wrt + means and standard deviations estimated from the n_samples + observations provided. + Default is normalize = True so that data is normalized to ease + maximum likelihood estimation. + + nugget : double, optional + Introduce a nugget effect to allow smooth predictions from noisy + data. + Default assumes a nugget close to machine precision for the sake of + robustness (nugget = 10. * MACHINE_EPSILON). + + optimizer : string, optional + A string specifying the optimization algorithm to be used + ('fmin_cobyla' is the sole algorithm implemented yet). + Default uses 'fmin_cobyla' algorithm from scipy.optimize. + + random_start : int, optional + The number of times the Maximum Likelihood Estimation should be + performed from a random starting point. + The first MLE always uses the specified starting point (theta0), + the next starting points are picked at random according to an + exponential distribution (log-uniform on [thetaL, thetaU]). + Default does not use random starting point (random_start = 1). Example ------- - import numpy as np - from scikits.learn.gaussian_process import GaussianProcess - import pylab as pl - - f = lambda x: x * np.sin(x) - X = np.array([1., 3., 5., 6., 7., 8.]) - Y = f(X) - gp = GaussianProcess(theta0=1e-1, thetaL=1e-3, thetaU=1e0, \ - random_start=100) - gp.fit(X, Y) - x = np.linspace(0, 10, 1000) - y_pred, MSE = gp.predict(x, eval_MSE=True) - - pl.show() - - Methods - ------- - fit(X, y) : self - Fit the model. - - predict(X) : array - Predict using the model. + >>> import numpy as np + >>> from scikits.learn.gaussian_process import GaussianProcess + >>> f = lambda x: x * np.sin(x) + >>> X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T + >>> Y = f(X).ravel() + >>> gp = GaussianProcess(theta0=1e-1, thetaL=1e-3, thetaU=1e0).fit(X, Y) + >>> x = np.atleast_2d(np.linspace(0, 10, 1000)).T + >>> y_pred, MSE = gp.predict(x, eval_MSE=True) Implementation details ---------------------- @@ -65,151 +140,42 @@ class GaussianProcess(BaseEstimator): http://www2.imm.dtu.dk/~hbn/dace/dace.pdf """ - def __init__(self, regr=regpoly0, corr=correxp2, beta0=None, \ - storage_mode='full', verbose=True, theta0=1e-1, \ - thetaL=None, thetaU=None, optimizer='fmin_cobyla', \ - random_start=1, normalize=True, \ - nugget=10. * machine_epsilon): - """ - The Gaussian Process model constructor. + _regression_types = { + 'constant': regression.constant, + 'linear': regression.linear, + 'quadratic': regression.quadratic} - Parameters - ---------- - regr : function, optional - A regression function returning an array of outputs of the linear - regression functional basis. The number of observations n_samples - should be greater than the size p of this basis. - Default assumes a simple constant regression trend (see regpoly0). 
- - corr : function, optional - A stationary autocorrelation function returning the autocorrelation - between two points x and x'. - Default assumes a squared-exponential autocorrelation model (see - correxp2). - - beta0 : double array_like, optional - The regression weight vector to perform Ordinary Kriging (OK). - Default assumes Universal Kriging (UK) so that the vector beta of - regression weights is estimated using the maximum likelihood - principle. - - storage_mode : string, optional - A string specifying whether the Cholesky decomposition of the - correlation matrix should be stored in the class (storage_mode = - 'full') or not (storage_mode = 'light'). - Default assumes storage_mode = 'full', so that the - Cholesky decomposition of the correlation matrix is stored. - This might be a useful parameter when one is not interested in the - MSE and only plan to estimate the BLUP, for which the correlation - matrix is not required. - - verbose : boolean, optional - A boolean specifying the verbose level. - Default is verbose = True. - - theta0 : double array_like, optional - An array with shape (n_features, ) or (1, ). - The parameters in the autocorrelation model. - If thetaL and thetaU are also specified, theta0 is considered as - the starting point for the maximum likelihood rstimation of the - best set of parameters. - Default assumes isotropic autocorrelation model with theta0 = 1e-1. - - thetaL : double array_like, optional - An array with shape matching theta0's. - Lower bound on the autocorrelation parameters for maximum - likelihood estimation. - Default is None, so that it skips maximum likelihood estimation and - it uses theta0. - - thetaU : double array_like, optional - An array with shape matching theta0's. - Upper bound on the autocorrelation parameters for maximum - likelihood estimation. - Default is None, so that it skips maximum likelihood estimation and - it uses theta0. - - normalize : boolean, optional - Design sites X and observations y are centered and reduced wrt - means and standard deviations estimated from the n_samples - observations provided. - Default is normalize = True so that data is normalized to ease - maximum likelihood estimation. - - nugget : double, optional - Introduce a nugget effect to allow smooth predictions from noisy - data. - Default assumes a nugget close to machine precision for the sake of - robustness (nugget = 10.*machine_epsilon). - - optimizer : string, optional - A string specifying the optimization algorithm to be used - ('fmin_cobyla' is the sole algorithm implemented yet). - Default uses 'fmin_cobyla' algorithm from scipy.optimize. - - random_start : int, optional - The number of times the Maximum Likelihood Estimation should be - performed from a random starting point. - The first MLE always uses the specified starting point (theta0), - the next starting points are picked at random according to an - exponential distribution (log-uniform on [thetaL, thetaU]). - Default does not use random starting point (random_start = 1). + _correlation_types = { + 'absolute_exponential': correlation.absolute_exponential, + 'squared_exponential': correlation.squared_exponential, + 'generalized_exponential': correlation.generalized_exponential, + 'cubic': correlation.cubic, + 'linear': correlation.linear} - Returns - ------- - gp : self - A Gaussian Process model object awaiting data to be fitted to. 
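With the lookup tables above, regr and corr can now be given either as the documented strings or as callables from the regression_models and correlation_models submodules; both spellings configure the same estimator. A short sketch of the two equivalent spellings:

    from scikits.learn.gaussian_process import (GaussianProcess,
                                                correlation_models,
                                                regression_models)

    # Strings are resolved through the class-level lookup tables in _check_params.
    gp_by_name = GaussianProcess(regr='constant', corr='cubic', theta0=1e-1)

    # Passing the callables directly yields the same model.
    gp_by_func = GaussianProcess(regr=regression_models.constant,
                                 corr=correlation_models.cubic, theta0=1e-1)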
- """ + _optimizer_types = [ + 'fmin_cobyla'] + + def __init__(self, regr='constant', corr='squared_exponential', beta0=None, + storage_mode='full', verbose=False, theta0=1e-1, + thetaL=None, thetaU=None, optimizer='fmin_cobyla', + random_start=1, normalize=True, + nugget=10. * MACHINE_EPSILON): self.regr = regr self.corr = corr self.beta0 = beta0 - - # Check storage mode - if storage_mode != 'full' and storage_mode != 'light': - if storage_mode == 'sparse': - raise NotImplementedError("The 'sparse' storage mode is not " \ - + "supported yet. Please contribute!") - else: - raise ValueError("Storage mode should either be 'full' or " \ - + "'light'. Unknown storage mode: " + str(storage_mode)) - else: - self.storage_mode = storage_mode - + self.storage_mode = storage_mode self.verbose = verbose - - # Check correlation parameters - self.theta0 = np.atleast_2d(theta0) - self.thetaL, self.thetaU = thetaL, thetaU - lth = self.theta0.size - - if self.thetaL is not None and self.thetaU is not None: - # Parameters optimization case - self.thetaL = np.atleast_2d(thetaL) - self.thetaU = np.atleast_2d(thetaU) - - if self.thetaL.size != lth or self.thetaU.size != lth: - raise ValueError("theta0, thetaL and thetaU must have the " \ - + "same length") - if np.any(self.thetaL <= 0) or np.any(self.thetaU < self.thetaL): - raise ValueError("The bounds must satisfy O < thetaL <= " \ - + "thetaU") - - elif self.thetaL is None and self.thetaU is None: - # Given parameters case - if np.any(self.theta0 <= 0): - raise ValueError("theta0 must be strictly positive") - - elif self.thetaL is None or self.thetaU is None: - # Error - raise Exception("thetaL and thetaU should either be both or " \ - + "neither specified") - - # Store other parameters + self.theta0 = theta0 + self.thetaL = thetaL + self.thetaU = thetaU self.normalize = normalize self.nugget = nugget self.optimizer = optimizer - self.random_start = int(random_start) + self.random_start = random_start + + # Run input checks + self._check_params() def fit(self, X, y): """ @@ -232,6 +198,9 @@ class GaussianProcess(BaseEstimator): predictions. """ + # Run input checks + self._check_params() + # Force data to 2D numpy.array X = np.atleast_2d(X) y = np.asanyarray(y)[:, np.newaxis] @@ -241,7 +210,7 @@ class GaussianProcess(BaseEstimator): n_samples_y = y.shape[0] if n_samples_X != n_samples_y: - raise Exception("X and y must have the same number of rows!") + raise Exception("X and y must have the same number of rows.") else: n_samples = n_samples_X @@ -269,10 +238,11 @@ class GaussianProcess(BaseEstimator): ll = np.array([-1]) for k in range(n_samples-1): ll = ll[-1] + 1 + range(n_samples - k - 1) - ij[ll] = np.concatenate([[np.repeat(k, n_samples - k - 1, 0)], \ + ij[ll] = np.concatenate([[np.repeat(k, n_samples - k - 1, 0)], [np.arange(k + 1, n_samples).T]]).T D[ll] = X[k] - X[(k + 1):n_samples] - if np.min(np.sum(np.abs(D), 1)) == 0. and self.corr != corriid: + if np.min(np.sum(np.abs(D), 1)) == 0. \ + and self.corr != correlation.pure_nugget: raise Exception("Multiple X are not allowed") # Regression matrix and parameters @@ -283,16 +253,18 @@ class GaussianProcess(BaseEstimator): else: p = 1 if n_samples_F != n_samples: - raise Exception("Number of rows in F and X do not match. Most " \ - + "likely something is going wrong with the " \ + raise Exception("Number of rows in F and X do not match. 
Most " + + "likely something is going wrong with the " + "regression model.") if p > n_samples_F: - raise Exception("Ordinary least squares problem is undetermined " \ - + "n_samples=%d must be greater than the " \ - + "regression model size p=%d!" % (n_samples, p)) - if self.beta0 is not None \ - and (self.beta0.shape[0] != p or self.beta0.ndim > 1): - raise Exception("Shapes of beta0 and F do not match.") + raise Exception(("Ordinary least squares problem is undetermined " + + "n_samples=%d must be greater than the " + + "regression model size p=%d.") % (n_samples, p)) + if self.beta0 is not None: + if self.beta0.shape[0] != p: + import pdb + pdb.set_trace() + raise Exception("Shapes of beta0 and F do not match.") # Set attributes self.X = X @@ -307,37 +279,44 @@ class GaussianProcess(BaseEstimator): if self.thetaL is not None and self.thetaU is not None: # Maximum Likelihood Estimation of the parameters if self.verbose: - print "Performing Maximum Likelihood Estimation of the " \ - + "autocorrelation parameters..." - self.theta, self.reduced_likelihood_function_value, self.par = \ + print("Performing Maximum Likelihood Estimation of the " + + "autocorrelation parameters...") + self.theta, self.reduced_likelihood_function_value, par = \ self.arg_max_reduced_likelihood_function() if np.isinf(self.reduced_likelihood_function_value): - raise Exception("Bad parameter region. " \ + raise Exception("Bad parameter region. " + "Try increasing upper bound") else: # Given parameters if self.verbose: - print "Given autocorrelation parameters. " \ - + "Computing Gaussian Process model parameters..." + print("Given autocorrelation parameters. " + + "Computing Gaussian Process model parameters...") self.theta = self.theta0 - self.reduced_likelihood_function_value, self.par = \ + self.reduced_likelihood_function_value, par = \ self.reduced_likelihood_function() if np.isinf(self.reduced_likelihood_function_value): - raise Exception("Bad point. Try increasing theta0") + raise Exception("Bad point. Try increasing theta0.") + + self.beta = par['beta'] + self.gamma = par['gamma'] + self.sigma2 = par['sigma2'] + self.C = par['C'] + self.Ft = par['Ft'] + self.G = par['G'] if self.storage_mode == 'light': # Delete heavy data (it will be computed again if required) # (it is required only when MSE is wanted in self.predict) if self.verbose: - print "Light storage mode specified. " \ - + "Flushing autocorrelation matrix..." + print("Light storage mode specified. " + + "Flushing autocorrelation matrix...") self.D = None self.ij = None self.F = None - self.par['C'] = None - self.par['Ft'] = None - self.par['G'] = None + self.C = None + self.Ft = None + self.G = None return self @@ -357,10 +336,6 @@ class GaussianProcess(BaseEstimator): Default assumes evalMSE = False and evaluates only the BLUP (mean prediction). - verbose : boolean, optional - A boolean specifying the verbose level. - Default is verbose = True. - batch_size : integer, optional An integer giving the maximum number of points that can be evaluated simulatneously (depending on the available memory). @@ -377,11 +352,8 @@ class GaussianProcess(BaseEstimator): An array with shape (n_eval, ) with the Mean Squared Error at x. """ - # Check itself - if np.any(np.isnan(self.par['beta'])): - raise Exception("Not a valid GaussianProcess. 
" \ - + "Try fitting it again with different parameters " \ - + "theta.") + # Run input checks + self._check_params() # Check design & trial sites X = np.atleast_2d(X) @@ -389,9 +361,9 @@ class GaussianProcess(BaseEstimator): n_samples, n_features = self.X.shape if n_features_X != n_features: - raise ValueError("The number of features in X (X.shape[1] = %d) " \ - % n_features_X + "should match the sample size " \ - + "used for fit() which is %d." % n_features) + raise ValueError(("The number of features in X (X.shape[1] = %d) " + + "should match the sample size used for fit() " + + "which is %d.") % (n_features_X, n_features)) if batch_size is None: # No memory management @@ -417,35 +389,38 @@ class GaussianProcess(BaseEstimator): r = self.corr(self.theta, dx).reshape(n_eval, n_samples) # Scaled predictor - y_ = np.dot(f, self.par['beta']) \ - + np.dot(r, self.par['gamma']) + y_ = np.dot(f, self.beta) \ + + np.dot(r, self.gamma) # Predictor y = (self.y_sc[0] + self.y_sc[1] * y_).ravel() # Mean Squared Error if eval_MSE: - par = self.par - if par['C'] is None: + C = self.C + if C is None: # Light storage mode (need to recompute C, F, Ft and G) if self.verbose: - print "This GaussianProcess used light storage mode " \ - + "at instanciation. Need to recompute " \ - + "autocorrelation matrix..." + print("This GaussianProcess used 'light' storage mode " + + "at instanciation. Need to recompute " + + "autocorrelation matrix...") reduced_likelihood_function_value, par = \ self.reduced_likelihood_function() + self.C = par['C'] + self.Ft = par['Ft'] + self.G = par['G'] - rt = solve_triangular(par['C'], r.T) + rt = solve_triangular(C, r.T) if self.beta0 is None: # Universal Kriging - u = solve_triangular(self.par['G'].T, \ - np.dot(self.par['Ft'].T, rt) - f.T) + u = solve_triangular(self.G.T, + np.dot(self.Ft.T, rt) - f.T) else: # Ordinary Kriging u = np.zeros(y.shape) - MSE = self.par['sigma2'] * (1. - (rt ** 2.).sum(axis=0) \ - + (u ** 2.).sum(axis=0)) + MSE = self.sigma2 * (1. - (rt ** 2.).sum(axis=0) + + (u ** 2.).sum(axis=0)) # Mean Squared Error might be slightly negative depending on # machine precision: force to zero! @@ -465,29 +440,25 @@ class GaussianProcess(BaseEstimator): if eval_MSE: - y, MSE = np.array([]), np.array([]) + y, MSE = np.zeros(n_eval), np.zeros(n_eval) for k in range(n_eval / batch_size): batch_from = k * batch_size batch_to = min([(k + 1) * batch_size + 1, n_eval + 1]) - X_batch = X[batch_from:batch_to][:] - y_batch, MSE_batch = \ - self.predict(X_batch, eval_MSE=eval_MSE, \ - batch_size=None) - y.append(y_batch) - MSE.append(MSE_batch) + y[batch_from:batch_to], MSE[batch_from:batch_to] = \ + self.predict(X[batch_from:batch_to][:], + eval_MSE=eval_MSE, batch_size=None) return y, MSE else: - y = np.array([]) + y = np.zeros(n_eval) for k in range(n_eval / batch_size): batch_from = k * batch_size batch_to = min([(k + 1) * batch_size + 1, n_eval + 1]) - X_batch = X[batch_from:batch_to][:] - y_batch = \ - self.predict(X_batch, eval_MSE=eval_MSE, \ - batch_size=None) + y[batch_from:batch_to] = \ + self.predict(X[batch_from:batch_to][:], + eval_MSE=eval_MSE, batch_size=None) return y @@ -527,8 +498,6 @@ class GaussianProcess(BaseEstimator): par['C'] : Cholesky decomposition of the correlation matrix [R]. par['Ft'] : Solution of the linear equation system : [R] x Ft = F par['G'] : QR decomposition of the matrix Ft. - par['detR'] : Determinant of the correlation matrix raised at power - 1/n_samples. 
""" if theta is None: @@ -558,10 +527,11 @@ class GaussianProcess(BaseEstimator): for k in range(n_samples-1): ll = ll[-1] + 1 + range(n_samples - k - 1) ij[ll] = \ - np.concatenate([[np.repeat(k, n_samples - k - 1, 0)], \ + np.concatenate([[np.repeat(k, n_samples - k - 1, 0)], [np.arange(k + 1, n_samples).T]]).T D[ll] = self.X[k] - self.X[(k + 1):n_samples] - if min(sum(abs(D), 1)) == 0. and self.corr != corriid: + if min(sum(abs(D), 1)) == 0. \ + and self.corr != correlation.pure_nugget: raise Exception("Multiple X are not allowed") F = self.regr(self.X) @@ -596,7 +566,7 @@ class GaussianProcess(BaseEstimator): sv = linalg.svd(F, compute_uv=False) condF = sv[0] / sv[-1] if condF > 1e15: - raise Exception("F is too ill conditioned. Poor combination " \ + raise Exception("F is too ill conditioned. Poor combination " + "of regression model and observations.") else: # Ft is too ill conditioned, get out (try different theta) @@ -690,7 +660,7 @@ class GaussianProcess(BaseEstimator): # Run Cobyla log10_optimal_theta = \ - optimize.fmin_cobyla(minus_reduced_likelihood_function, \ + optimize.fmin_cobyla(minus_reduced_likelihood_function, np.log10(theta0), constraints, iprint=0) optimal_theta = 10. ** log10_optimal_theta @@ -715,15 +685,15 @@ class GaussianProcess(BaseEstimator): else: - raise NotImplementedError("This optimizer ('%s') is not " \ - + "implemented yet. Please contribute!" \ + raise NotImplementedError("This optimizer ('%s') is not " + + "implemented yet. Please contribute!" % self.optimizer) return best_optimal_theta, best_optimal_rlf_value, best_optimal_par def score(self, X_test, y_test): """ - This score function returns the deviations of the Gaussian Process + This score function returns the mean deviation of the Gaussian Process model evaluated onto a test dataset. Parameters @@ -736,9 +706,81 @@ class GaussianProcess(BaseEstimator): Returns ------- - score_values : array_like - The deviations between the prediction and the targets: - y_pred - y_test. + score_value : array_like + The mean of the deviations between the prediction and the targets: + mean(y_pred - y_test). """ - return np.ravel(self.predict(X_test, eval_MSE=False)) - y_test + return (self.predict(X_test, eval_MSE=False) - y_test).mean() + + def _check_params(self): + + # Check regression model + if not callable(self.regr): + if self.regr in self._regression_types: + self.regr = self._regression_types[self.regr] + else: + raise ValueError(("regr should be one of %s or callable, " + + "%s was given.") + % (self._regression_types.keys(), self.regr)) + + # Check regression weights if given (Ordinary Kriging) + if self.beta0 is not None: + self.beta0 = np.atleast_2d(self.beta0) + if self.beta0.shape[1] != 1: + # Force to column vector + self.beta0 = self.beta0.T + + # Check correlation model + if not callable(self.corr): + if self.corr in self._correlation_types: + self.corr = self._correlation_types[self.corr] + else: + raise ValueError(("corr should be one of %s or callable, " + + "%s was given.") + % (self._correlation_types.keys(), self.corr)) + + # Check storage mode + if self.storage_mode != 'full' and self.storage_mode != 'light': + raise ValueError("Storage mode should either be 'full' or " + + "'light', %s was given." 
% self.storage_mode) + + # Check correlation parameters + self.theta0 = np.atleast_2d(self.theta0) + lth = self.theta0.size + + if self.thetaL is not None and self.thetaU is not None: + self.thetaL = np.atleast_2d(self.thetaL) + self.thetaU = np.atleast_2d(self.thetaU) + if self.thetaL.size != lth or self.thetaU.size != lth: + raise ValueError("theta0, thetaL and thetaU must have the " + + "same length.") + if np.any(self.thetaL <= 0) or np.any(self.thetaU < self.thetaL): + raise ValueError("The bounds must satisfy O < thetaL <= " + + "thetaU.") + + elif self.thetaL is None and self.thetaU is None: + if np.any(self.theta0 <= 0): + raise ValueError("theta0 must be strictly positive.") + + elif self.thetaL is None or self.thetaU is None: + raise ValueError("thetaL and thetaU should either be both or " + + "neither specified.") + + # Force verbose type to bool + self.verbose = bool(self.verbose) + + # Force normalize type to bool + self.normalize = bool(self.normalize) + + # Check nugget value + if self.nugget < 0.: + raise ValueError("nugget must be positive or zero.") + + # Check optimizer + if not self.optimizer in self._optimizer_types: + raise ValueError("optimizer should be one of %s" + % self._optimizer_types) + + # Force random_start type to int + self.random_start = int(self.random_start) diff --git a/scikits/learn/gaussian_process/regression.py b/scikits/learn/gaussian_process/regression_models.py similarity index 73% rename from scikits/learn/gaussian_process/regression.py rename to scikits/learn/gaussian_process/regression_models.py index 58bf14948bcd400088ef49b325f4a5e7946de172..68c5a196f77e92a88fcc5e96eac30b289b8cf9d0 100644 --- a/scikits/learn/gaussian_process/regression.py +++ b/scikits/learn/gaussian_process/regression_models.py @@ -1,6 +1,14 @@ #!/usr/bin/python # -*- coding: utf-8 -*- +# Author: Vincent Dubourg <vincent.dubourg@gmail.com> +# (mostly translation, see implementation details) +# License: BSD style + +""" +The built-in regression models submodule for the gaussian_process module. +""" + ################ # Dependencies # ################ @@ -13,11 +21,11 @@ import numpy as np ############################ -def regpoly0(x): +def constant(x): """ Zero order polynomial (constant, p = 1) regression model. - regpoly0 : x --> f(x) = 1 + x --> f(x) = 1 Parameters ---------- @@ -39,11 +47,11 @@ def regpoly0(x): return f -def regpoly1(x): +def linear(x): """ - First order polynomial (hyperplane, p = n) regression model. + First order polynomial (linear, p = n+1) regression model. - regpoly1 : x --> f(x) = [ x_1, ..., x_n ].T + x --> f(x) = [ 1, x_1, ..., x_n ].T Parameters ---------- @@ -65,12 +73,12 @@ def regpoly1(x): return f -def regpoly2(x): +def quadratic(x): """ - Second order polynomial (hyperparaboloid, p = n*(n-1)/2) regression model. + Second order polynomial (quadratic, p = n*(n-1)/2+n+1) regression model. 
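To make the basis sizes p above concrete, here is a minimal sketch of what the renamed regression models return for a single two-dimensional site; the exact ordering of the quadratic cross terms follows the DACE convention and is an assumption here.

    import numpy as np
    from scikits.learn.gaussian_process import regression_models

    x = np.atleast_2d([1., 2.])          # one sample, n = 2 features

    regression_models.constant(x)        # expected [[1.]]          -> p = 1
    regression_models.linear(x)          # expected [[1., 1., 2.]]  -> p = n + 1 = 3
    regression_models.quadratic(x)       # expected [[1., 1., 2., 1., 2., 4.]]
                                         # -> p = n*(n-1)/2 + n + 1 = 6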
- regpoly2 : x --> f(x) = [ x_i*x_j, (i,j) = 1,...,n ].T - i > j + x --> f(x) = [ 1, { x_i, i = 1,...,n }, { x_i * x_j, (i,j) = 1,...,n } ].T + i > j Parameters ---------- diff --git a/scikits/learn/gaussian_process/tests/__init__.py b/scikits/learn/gaussian_process/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scikits/learn/gaussian_process/tests/test_gaussian_process.py b/scikits/learn/gaussian_process/tests/test_gaussian_process.py new file mode 100644 index 0000000000000000000000000000000000000000..2f71070fa999dbd11f61ee9a5fcf012265b58d10 --- /dev/null +++ b/scikits/learn/gaussian_process/tests/test_gaussian_process.py @@ -0,0 +1,87 @@ +""" +Testing for Gaussian Process module (scikits.learn.gaussian_process) +""" + +# Author: Vincent Dubourg <vincent.dubourg@gmail.com> +# License: BSD style + +import numpy as np + +from .. import GaussianProcess +from .. import regression_models as regression +from .. import correlation_models as correlation + + +def test_1d(regr=regression.constant, corr=correlation.squared_exponential, + random_start=10, beta0=None): + """ + MLE estimation of a one-dimensional Gaussian Process model. + Check random start optimization. + + Test the interpolating property. + """ + + f = lambda x: x * np.sin(x) + X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T + y = f(X).ravel() + gp = GaussianProcess(regr=regr, corr=corr, beta0=beta0, + theta0=1e-2, thetaL=1e-4, thetaU=1e-1, + random_start=random_start, verbose=False).fit(X, y) + y_pred, MSE = gp.predict(X, eval_MSE=True) + + assert np.allclose(y_pred, y) and np.allclose(MSE, 0.) + + +def test_2d(regr=regression.constant, corr=correlation.squared_exponential, + random_start=10, beta0=None): + """ + MLE estimation of a two-dimensional Gaussian Process model accounting for + anisotropy. Check random start optimization. + + Test the interpolating property. + """ + + b, kappa, e = 5., .5, .1 + g = lambda x: b - x[:, 1] - kappa * (x[:, 0] - e) ** 2. + X = np.array([[-4.61611719, -6.00099547], + [4.10469096, 5.32782448], + [0.00000000, -0.50000000], + [-6.17289014, -4.6984743], + [1.3109306, -6.93271427], + [-5.03823144, 3.10584743], + [-2.87600388, 6.74310541], + [5.21301203, 4.26386883]]) + y = g(X).ravel() + gp = GaussianProcess(regr=regr, corr=corr, beta0=beta0, + theta0=[1e-2] * 2, thetaL=[1e-4] * 2, + thetaU=[1e-1] * 2, + random_start=random_start, verbose=False) + gp.fit(X, y) + y_pred, MSE = gp.predict(X, eval_MSE=True) + + assert np.allclose(y_pred, y) and np.allclose(MSE, 0.) + + +def test_more_builtin_correlation_models(random_start=1): + """ + Repeat test_1d and test_2d for several built-in correlation + models specified as strings. + """ + all_corr = ['absolute_exponential', 'squared_exponential', 'cubic', + 'linear'] + + for corr in all_corr: + test_1d(regr='constant', corr=corr, random_start=random_start) + test_2d(regr='constant', corr=corr, random_start=random_start) + + +def test_ordinary_kriging(): + """ + Repeat test_1d and test_2d with given regression weights (beta0) for + different regression models (Ordinary Kriging). 
+ """ + + test_1d(regr='linear', beta0=[0., 0.5]) + test_1d(regr='quadratic', beta0=[0., 0.5, 0.5]) + test_2d(regr='linear', beta0=[0., 0.5, 0.5]) + test_2d(regr='quadratic', beta0=[0., 0.5, 0.5, 0.5, 0.5, 0.5]) diff --git a/scikits/learn/tests/test_gaussian_process.py b/scikits/learn/tests/test_gaussian_process.py deleted file mode 100644 index 99e81556979a7f219f34b71bf8a967164754a385..0000000000000000000000000000000000000000 --- a/scikits/learn/tests/test_gaussian_process.py +++ /dev/null @@ -1,27 +0,0 @@ -""" -Testing for Gaussian Process module (scikits.learn.gaussian_process) -""" - -import numpy as np -from numpy.testing import assert_array_equal, assert_array_almost_equal, \ - assert_almost_equal, assert_raises, assert_ - -from ..gaussian_process import GaussianProcess - - -def test_regression_1d_x_sinx(): - """ - MLE estimation of a Gaussian Process model with a squared exponential - correlation model (correxp2). Check random start optimization. - - Test the interpolating property. - """ - - f = lambda x: x * np.sin(x) - X = np.array([1., 3., 5., 6., 7., 8.]) - y = f(X) - gp = GaussianProcess(theta0=1e-2, thetaL=1e-4, thetaU=1e-1, \ - random_start=10, verbose=False).fit(X, y) - y_pred, MSE = gp.predict(X, eval_MSE=True) - - assert (np.all(np.abs((y_pred - y) / y) < 1e-6) and np.all(MSE < 1e-6))
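Finally, a hedged end-to-end sketch of the Ordinary Kriging path exercised by test_ordinary_kriging: when beta0 is given, the regression weights are fixed and only the autocorrelation parameters are estimated by maximum likelihood. The values below mirror the 1D test case.

    import numpy as np
    from scikits.learn.gaussian_process import GaussianProcess

    f = lambda x: x * np.sin(x)
    X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T
    y = f(X).ravel()

    # Ordinary Kriging: the weights of the 'linear' trend (p = 2 in 1D) are
    # fixed through beta0 instead of being estimated from the data.
    gp = GaussianProcess(regr='linear', beta0=[0., 0.5],
                         theta0=1e-2, thetaL=1e-4, thetaU=1e-1,
                         random_start=10).fit(X, y)

    x = np.atleast_2d(np.linspace(0, 10, 1000)).T
    y_pred, MSE = gp.predict(x, eval_MSE=True)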