diff --git a/doc/modules/gpml.rst b/doc/modules/gpml.rst
new file mode 100644
index 0000000000000000000000000000000000000000..23e99f9f85ec30958fb01311858f255af360ca93
--- /dev/null
+++ b/doc/modules/gpml.rst
@@ -0,0 +1,213 @@
+=======================================
+Gaussian Processes for Machine Learning
+=======================================
+
+.. currentmodule:: scikits.learn.gpml
+
+**Gaussian Processes for Machine Learning (GPML)** is a supervised learning
+method used for *regression*. It can also be used for *probabilistic
+classification*, but this is only a post-processing of the *regression*
+exercise.
+
+The advantages of Gaussian Processes for Machine Learning are:
+
+    - The prediction interpolates the observations (at least for regular
+      correlation models).
+
+    - The prediction is probabilistic (Gaussian), so that one can compute
+      empirical confidence intervals and exceedance probabilities that might
+      be used to refit (online fitting, adaptive fitting) the prediction in
+      some region of interest.
+
+    - Versatile: different :ref:`linear regression models <regression_models>`
+      and :ref:`correlation models <correlation_models>` can be specified.
+      Common models are provided, but it is also possible to specify custom
+      models.
+
+The disadvantages of Gaussian Processes for Machine Learning include:
+
+    - It is not sparse: it uses the whole sample/feature information to
+      perform the prediction.
+
+    - It loses efficiency in high-dimensional spaces -- namely when the
+      number of features exceeds a few dozen -- where it may give poor
+      performance and become computationally inefficient.
+
+    - Classification is only a post-processing, meaning that one first needs
+      to solve a regression problem by providing the complete scalar
+      (float-precision) output :math:`y` of the computer experiment one
+      attempts to model.
+
+Thanks to the Gaussian property of the prediction, it has found a variety of
+applications, e.g. global optimization and probabilistic classification.
+
+
+Mathematical formulation
+========================
+
+The initial assumption
+~~~~~~~~~~~~~~~~~~~~~~
+
+Suppose one wants to model the output of a computer experiment, say a
+mathematical function:
+
+.. math::
+
+        g: & \mathbb{R}^{n_{\rm features}} \rightarrow \mathbb{R} \\
+           & X \mapsto y = g(X)
+
+GPML starts with the assumption that this function is a conditional sample
+path of a Gaussian process :math:`G`, which is additionally assumed to read
+as follows:
+
+.. math::
+
+        G(X) = f(X)^T \beta + Z(X)
+
+where :math:`f(X)^T \beta` is a linear regression model and :math:`Z(X)` is a
+zero-mean Gaussian process with a fully stationary covariance function:
+
+.. math::
+
+        C(X, X') = \sigma^2 R(|X - X'|)
+
+:math:`\sigma^2` being its variance and :math:`R` being the correlation
+function, which solely depends on the absolute relative distance between
+samples -- possibly featurewise.
+
+From this basic formulation, note that GPML is nothing but an extension of a
+basic least squares linear regression problem:
+
+.. math::
+
+        g(X) \approx f(X)^T \beta
+
+except that we additionally assume some spatial coherence (correlation)
+between the samples, dictated by the correlation function. Indeed, ordinary
+least squares assumes that the correlation model :math:`R(|X - X'|)` is one
+when :math:`X = X'` and zero otherwise: a *dirac* correlation model --
+sometimes referred to as a *nugget* correlation model in the kriging
+literature.
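+
+As an illustration of this connection, the snippet below -- a minimal sketch,
+assuming the module is importable as ``scikits.learn.gpml`` -- fits a
+GaussianProcessModel with the spatial independence correlation model
+``corriid`` and a first-order regression basis, so that the resulting
+prediction reduces to the ordinary least squares trend::
+
+    import numpy as np
+    from scikits.learn.gpml import GaussianProcessModel, regpoly1, corriid
+
+    X = np.array([1., 3., 5., 6., 7., 8.])  # design sites
+    y = X * np.sin(X)                       # observations
+
+    # With the dirac (pure nugget) correlation model, the BLUP reduces to
+    # the underlying regression trend f(x)^T beta, i.e. a least squares fit.
+    ols_like = GaussianProcessModel(regr=regpoly1, corr=corriid, theta0=1e-1)
+    ols_like.fit(X, y)
+    y_trend = ols_like.predict(np.linspace(0., 10., 50))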
+
+
+The best linear unbiased prediction (BLUP)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+We now derive the *best linear unbiased prediction* of the sample path
+:math:`g` conditioned on the observations:
+
+.. math::
+
+        \hat{G}(X) \sim G(X | y_1 = g(X_1), ..., y_{n_{\rm samples}} = g(X_{n_{\rm samples}}))
+
+It is derived from its *required properties*:
+
+- It is linear (a linear combination of the observations)
+
+.. math::
+
+        \hat{G}(X) \equiv a(X)^T y
+
+- It is unbiased
+
+.. math::
+
+        \mathbb{E}[G(X) - \hat{G}(X)] = 0
+
+- It is the best (in the Mean Squared Error sense)
+
+.. math::
+
+        \hat{G}(X)^* = \arg \min\limits_{\hat{G}(X)} \;
+        \mathbb{E}[(G(X) - \hat{G}(X))^2]
+
+The optimal weight vector :math:`a(X)` is therefore the solution of the
+following equality-constrained optimization problem:
+
+.. math::
+
+        a(X)^* = \arg \min\limits_{a(X)} & \; \mathbb{E}[(G(X) - a(X)^T y)^2] \\
+                         {\rm s. t.} & \; \mathbb{E}[G(X) - a(X)^T y] = 0
+
+Rewriting this constrained optimization problem as a Lagrangian and working
+out the first-order optimality conditions, one ends up with a closed-form
+expression for the sought predictor -- see the references for the complete
+proof.
+
+In the end, the BLUP is shown to be a Gaussian random variate whose mean and
+variance expressions are given in the references.
+
+
+
+The empirical best linear unbiased predictor (EBLUP)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Until now, both the autocorrelation and regression models were assumed given.
+In practice, however, they are never known in advance, so that one has to
+make (motivated) empirical choices for these models (see
+:ref:`correlation_models` and :ref:`regression_models`).
+
+Provided these choices are made, one should estimate the remaining unknown
+parameters involved in the BLUP. To do so, one uses the set of provided
+observations in conjunction with some inference technique. The present
+implementation, which is based on the DACE Matlab toolbox, uses the
+*maximum likelihood estimation* technique.
+
+For a more comprehensive description of the theoretical aspects of Gaussian
+Processes for Machine Learning, please refer to the references below:
+
+.. topic:: References:
+
+    * *"DACE, A Matlab Kriging Toolbox"*
+      S Lophaven, HB Nielsen, J Sondergaard
+      2002,
+      <http://www2.imm.dtu.dk/~hbn/dace/>
+
+    * *"Gaussian Processes for Machine Learning"*
+      CE Rasmussen, CKI Williams
+      MIT Press, 2006 (Ed. T Dietterich)
+      <http://www.gaussianprocess.org/gpml/chapters/RW.pdf>
+
+    * *"The design and analysis of computer experiments"*
+      TJ Santner, BJ Williams, W Notz
+      Springer, 2003
+      <http://www.stat.osu.edu/~comp_exp/book.html>
+
+.. _correlation_models:
+
+Correlation Models
+==================
+
+Common correlation models match some well-known SVM kernels because they are
+mostly built on equivalent assumptions. They must fulfill Mercer's
+conditions. Note, however, that the choice of the correlation model should be
+made in agreement with the known properties of the original experiment from
+which the observations come:
+
+* If the original experiment is known to be infinitely differentiable
+  (smooth), then one should use the *squared-exponential correlation model*.
+* If it is not, then one should rather use the *exponential correlation
+  model*.
+* Note also that there exists a correlation model that takes the degree of
+  differentiability as input -- the Matern correlation model -- but it is not
+  implemented here (the chosen correlation model is simply passed to the
+  constructor, as shown below).
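+
+A minimal sketch of selecting the correlation model, assuming the module is
+importable as ``scikits.learn.gpml``::
+
+    from scikits.learn.gpml import GaussianProcessModel, correxp1, correxp2
+
+    # Smooth (infinitely differentiable) prior on the underlying function
+    smooth_model = GaussianProcessModel(corr=correxp2, theta0=1e-1)
+
+    # Rougher, merely continuous prior (Ornstein-Uhlenbeck-like process)
+    rough_model = GaussianProcessModel(corr=correxp1, theta0=1e-1)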
+
+For a more detailed discussion on the selection of the appropriate
+correlation model, see the book by Rasmussen & Williams in the references.
+
+.. _regression_models:
+
+Regression Models
+=================
+
+Common linear regression models involve zero- (constant), first- and
+second-order polynomials. One may also specify one's own in the form of a
+Python function that takes the features X as input and returns a vector
+containing the values of the functional set. The only constraint is that the
+number of functions must not exceed the number of available observations, so
+that the underlying regression problem is not *under-determined*.
+
+
+An introductory example
+=======================
+
+Say we want to build a surrogate of the function :math:`g(x) = x \sin(x)`. To
+do so, the function is evaluated on a design of experiments. Then we define a
+GaussianProcessModel, whose regression and correlation models may be
+specified using additional kwargs, and ask for the model to be fitted to the
+data. Depending on the number of parameters provided at instantiation, the
+fitting procedure may resort to maximum likelihood estimation of the
+parameters, or alternatively use the given parameters. The full script is
+provided in examples/gpml/plot_gpml_regression.py.
+
+
+
+Implementation details
+======================
+
+The present implementation is based on a transliteration of the DACE Matlab
+toolbox.
+
+.. topic:: References:
+
+    * *"DACE, A Matlab Kriging Toolbox"*
+      S Lophaven, HB Nielsen, J Sondergaard
+      2002,
+      <http://www2.imm.dtu.dk/~hbn/dace/>
diff --git a/examples/gpml/plot_gpml_probabilistic_classification_after_regression.py b/examples/gpml/plot_gpml_probabilistic_classification_after_regression.py
new file mode 100644
index 0000000000000000000000000000000000000000..553ca21f97851e528730d7285954daf18a6ed9f9
--- /dev/null
+++ b/examples/gpml/plot_gpml_probabilistic_classification_after_regression.py
@@ -0,0 +1,113 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+"""
+===============================================
+Gaussian Processes for Machine Learning example
+===============================================
+
+A two-dimensional regression exercise with a post-processing
+allowing for probabilistic classification thanks to the
+Gaussian property of the prediction.
+"""
+# Author: Vincent Dubourg <vincent.dubourg@gmail.com>
+# License: BSD style
+
+import numpy as np
+from scipy import stats
+from scikits.learn.gpml import GaussianProcessModel
+from matplotlib import pyplot as pl
+from matplotlib import cm
+from mpl_toolkits.mplot3d import Axes3D
+
+
+# Helper that returns a fixed label string whatever the contour level
+# (used to place custom labels with pl.clabel)
+class FormatFaker(object):
+    def __init__(self, str): self.str = str
+    def __mod__(self, stuff): return self.str
+
+# Standard normal distribution functions
+Grv = stats.distributions.norm()
+phi = lambda x: Grv.pdf(x)
+PHI = lambda x: Grv.cdf(x)
+PHIinv = lambda x: Grv.ppf(x)
+
+# A few constants
+beta0 = 8
+b, kappa, e = 5, .5, .1
+
+# The function to predict (classification will then consist in predicting
+# whether g(x) <= 0 or not)
+g = lambda x: b - x[:, 1] - kappa * (x[:, 0] - e) ** 2.
+
+# Design of experiments
+X = np.array([[-4.61611719, -6.00099547],
+              [ 4.10469096,  5.32782448],
+              [ 0.00000000, -0.50000000],
+              [-6.17289014, -4.6984743 ],
+              [ 1.3109306 , -6.93271427],
+              [-5.03823144,  3.10584743],
+              [-2.87600388,  6.74310541],
+              [ 5.21301203,  4.26386883]])
+
+# Observations
+Y = g(X)
+
+# Instantiate and fit a Gaussian Process Model
+aGaussianProcessModel = GaussianProcessModel(theta0=5e-1)
+
+# Don't perform MLE or you'll get a perfect prediction for this simple example!
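+# Since only theta0 is given (no thetaL/thetaU bounds), the autocorrelation
+# parameters are kept fixed at theta0 and fit() skips Maximum Likelihood
+# Estimation altogether.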
+aGaussianProcessModel.fit(X, Y) + +# Evaluate real function, the prediction and its MSE on a grid +res = 50 +x1, x2 = np.meshgrid(np.linspace(-beta0, beta0, res), np.linspace(-beta0, beta0, res)) +xx = np.vstack([x1.reshape(x1.size), x2.reshape(x2.size)]).T + +YY = g(xx) +yy, MSE = aGaussianProcessModel.predict(xx, eval_MSE=True) +sigma = np.sqrt(MSE) +yy = yy.reshape((res,res)) +YY = YY.reshape((res,res)) +sigma = sigma.reshape((res,res)) +k = PHIinv(.975) + +# Plot the probabilistic classification iso-values using the Gaussian property of the prediction +fig = pl.figure(1) +ax = fig.add_subplot(111) +ax.axes.set_aspect('equal') +cax = pl.imshow(np.flipud(PHI(-yy/sigma)), cmap=cm.gray_r, alpha=.8, extent=(-beta0,beta0,-beta0,beta0)) +norm = pl.matplotlib.colors.Normalize(vmin=0., vmax=0.9) +cb = pl.colorbar(cax, ticks=[0.,0.2,0.4,0.6,0.8,1.], norm=norm) +cb.set_label('${\\rm \mathbb{P}}\left[\widehat{G}(\mathbf{x}) \leq 0\\right]$') +pl.plot(X[Y <= 0, 0], X[Y <= 0, 1], 'r.', markersize=12) +pl.plot(X[Y > 0, 0], X[Y > 0, 1], 'b.', markersize=12) +pl.xticks([]) +pl.yticks([]) +ax.set_xticklabels([]) +ax.set_yticklabels([]) +pl.xlabel('$x_1$') +pl.ylabel('$x_2$') +print u'Click to place label: $g(\mathbf{x})=0$' +cs = pl.contour(x1, x2, YY, [0.], colors='k', linestyles='dashdot') +pl.clabel(cs,fmt=FormatFaker(u'$g(\mathbf{x})=0$'),fontsize=11,manual=True) +print u'Click to place label: ${\\rm \mathbb{P}}\left[{\widehat{G}}(\mathbf{x}) \leq 0\\right]= 2.5\%$' +cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.025], colors='b', linestyles='solid') +pl.clabel(cs,fmt=FormatFaker(u'${\\rm \mathbb{P}}\left[{\widehat{G}}(\mathbf{x}) \leq 0\\right]= 2.5\%$'),fontsize=11,manual=True) +print u'Click to place label: $\mu_{\widehat{G}}(\mathbf{x})=0$' +cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.5], colors='k', linestyles='dashed') +pl.clabel(cs,fmt=FormatFaker(u'$\mu_{\widehat{G}}(\mathbf{x})=0$'),fontsize=11,manual=True) +print u'Click to place label: ${\\rm \mathbb{P}}\left[{\widehat{G}}(\mathbf{x}) \leq 0\\right]= 97.5\%$' +cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.975], colors='r', linestyles='solid') +pl.clabel(cs,fmt=FormatFaker(u'${\\rm \mathbb{P}}\left[{\widehat{G}}(\mathbf{x}) \leq 0\\right]= 97.5\%$'),fontsize=11,manual=True) + +# Plot the prediction and the bounds of the 95% confidence interval +fig = pl.figure(2) +ax = Axes3D(fig) +ax.axes.set_aspect('equal') +ax.plot_surface(x1, x2, yy, linewidth = 0.5, rstride = 1, cstride = 1, color = 'k', alpha = .8) +ax.plot_surface(x1, x2, yy - k*sigma, linewidth = 0.5, rstride = 1, cstride = 1, color = 'b', alpha = .8) +ax.plot_surface(x1, x2, yy + k*sigma, linewidth = 0.5, rstride = 1, cstride = 1, color = 'r', alpha = .8) +ax.scatter3D(X[Y <= 0, 0], X[Y <= 0, 1], Y[Y <= 0], 'r.', s = 20) +ax.scatter3D(X[Y > 0, 0], X[Y > 0, 1], Y[Y > 0], 'b.', s = 20) +ax.set_xlabel(u'$x_1$') +ax.set_ylabel(u'$x_2$') +ax.set_zlabel(u'$\widehat{G}(x_1,\,x_2)$') + +pl.show() \ No newline at end of file diff --git a/examples/gpml/plot_gpml_regression.py b/examples/gpml/plot_gpml_regression.py new file mode 100644 index 0000000000000000000000000000000000000000..a67f35a41fe6677555f2a39fdf1523356bad8bec --- /dev/null +++ b/examples/gpml/plot_gpml_regression.py @@ -0,0 +1,78 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +""" +=============================================== +Gaussian Processes for Machine Learning example +=============================================== + +A simple one-dimensional regression exercise with +different correlation models and maximum likelihood 
+estimation of the Gaussian Process Model parameters. +""" +# Author: Vincent Dubourg <vincent.dubourg@gmail.com +# License: BSD style + +import numpy as np +from scipy import stats +from scikits.learn.gpml import GaussianProcessModel, correxp1, correxp2, corrlin, corrcubic +from matplotlib import pyplot as pl + +# The function to predict +f = lambda x: x*np.sin(x) + +# The design of experiments +X = np.array([1., 3., 5., 6., 7., 8.]) + +# Observations +Y = f(X) + +# Mesh the input space for evaluations of the real function, the prediction and its MSE +x = np.linspace(0,10,1000) + +# Loop correlation models +corrs = (correxp1, correxp2, corrlin, corrcubic) +colors = ('b', 'g', 'y', 'm') +for k in range(len(corrs)): + + # Instanciate a Gaussian Process Model with the k-th correlation model + if corrs[k] == corrlin or corrs[k] == corrcubic: + aGaussianProcessModel = GaussianProcessModel(corr=corrs[k], theta0=1e-2, thetaL=1e-4, thetaU=1e-1, random_start=100) + else: + aGaussianProcessModel = GaussianProcessModel(corr=corrs[k], theta0=1e-2, thetaL=1e-4, thetaU=1e+1, random_start=100) + + # Fit to data using Maximum Likelihood Estimation of the parameters + aGaussianProcessModel.fit(X, Y) + + # Make the prediction on the meshed x-axis (ask for MSE as well) + y, MSE = aGaussianProcessModel.predict(x, eval_MSE=True) + sigma = np.sqrt(MSE) + + # Compute the score function on a grid of the autocorrelation parameter space + theta_values = np.logspace(np.log10(aGaussianProcessModel.thetaL), np.log10(aGaussianProcessModel.thetaU), 100) + score_values = [] + for t in theta_values: + score_values.append(aGaussianProcessModel.score(theta=t)[0]) + + fig = pl.figure() + + # Plot the function, the prediction and the 95% confidence interval based on the MSE + ax = fig.add_subplot(211) + pl.plot(x, f(x), 'r:', label=u'$f(x) = x\,\sin(x)$') + pl.plot(X, Y, 'r.', markersize=10, label=u'Observations') + pl.plot(x, y, colors[k]+'-', label=u'Prediction (%s)' % corrs[k].__name__) + pl.fill(np.concatenate([x, x[::-1]]), np.concatenate([y - 1.9600 * sigma, (y + 1.9600 * sigma)[::-1]]), alpha=.5, fc=colors[k], ec='None', label='95% confidence interval') + pl.xlabel('$x$') + pl.ylabel('$f(x)$') + pl.ylim(-10, 20) + pl.legend(loc='upper left') + + # Plot the score function + ax = fig.add_subplot(212) + pl.plot(theta_values, score_values, colors[k]+'-') + pl.xlabel(u'$\\theta$') + pl.ylabel(u'Score') + pl.xscale('log') + pl.xlim(aGaussianProcessModel.thetaL[0], aGaussianProcessModel.thetaU[0]) + +pl.show() diff --git a/scikits/learn/gpml/__init__.py b/scikits/learn/gpml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bcba115a3744cd3478e15a22bfdd14285266bfe9 --- /dev/null +++ b/scikits/learn/gpml/__init__.py @@ -0,0 +1,14 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +""" + This module contains a contribution to the scikit-learn project that + implements Gaussian Process based prediction (also known as Kriging). + + The present implementation is based on a transliteration of the DACE + Matlab toolbox <http://www2.imm.dtu.dk/~hbn/dace/>. 
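+
+    The main entry point is the GaussianProcessModel class. The built-in
+    correlation models (correxp1, correxp2, correxpg, corriid, corrcubic,
+    corrlin) and regression models (regpoly0, regpoly1, regpoly2) are
+    re-exported at the package level as well.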
+""" + +from .gaussian_process_model import GaussianProcessModel +from .correlation_models import * +from .regression_models import * \ No newline at end of file diff --git a/scikits/learn/gpml/correlation_models.py b/scikits/learn/gpml/correlation_models.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d8235f22bf695cdf8a28e6252b609f19afa525 --- /dev/null +++ b/scikits/learn/gpml/correlation_models.py @@ -0,0 +1,239 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +################ +# Dependencies # +################ + +import numpy as np + +############################# +# Defaut correlation models # +############################# + +def correxp1(theta,d): + """ + Exponential autocorrelation model. + (Ornstein-Uhlenbeck stochastic process) + n + correxp1 : theta, dx --> r(theta, dx) = exp( sum - theta_i * |dx_i| ) + i = 1 + Parameters + ---------- + theta : array_like + An array with shape 1 (isotropic) or n (anisotropic) giving the autocorrelation parameter(s). + dx : array_like + An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. + + Returns + ------- + r : array_like + An array with shape (n_eval, ) containing the values of the autocorrelation model. + """ + + theta = np.asanyarray(theta, dtype=np.float) + d = np.asanyarray(d, dtype=np.float) + + if d.ndim > 1: + n_features = d.shape[1] + else: + n_features = 1 + if theta.size == 1: + theta = np.repeat(theta, n_features) + elif theta.size != n_features: + raise ValueError, "Length of theta must be 1 or "+str(n_features) + + td = - theta.reshape(1, n_features) * abs(d) + r = np.exp(np.sum(td,1)) + + return r + +def correxp2(theta,d): + """ + Squared exponential correlation model (Radial Basis Function). + (Infinitely differentiable stochastic process, very smooth) + n + correxp2 : theta, dx --> r(theta, dx) = exp( sum - theta_i * (dx_i)^2 ) + i = 1 + Parameters + ---------- + theta : array_like + An array with shape 1 (isotropic) or n (anisotropic) giving the autocorrelation parameter(s). + dx : array_like + An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. + + Returns + ------- + r : array_like + An array with shape (n_eval, ) containing the values of the autocorrelation model. + """ + + theta = np.asanyarray(theta, dtype=np.float) + d = np.asanyarray(d, dtype=np.float) + + if d.ndim > 1: + n_features = d.shape[1] + else: + n_features = 1 + if theta.size == 1: + theta = np.repeat(theta, n_features) + elif theta.size != n_features: + raise Exception, "Length of theta must be 1 or "+str(n_features) + + td = - theta.reshape(1, n_features) * d**2 + r = np.exp(np.sum(td,1)) + + return r + +def correxpg(theta,d): + """ + Generalized exponential correlation model. + (Useful when one does not know the smoothness of the function to be predicted.) + n + correxpg : theta, dx --> r(theta, dx) = exp( sum - theta_i * |dx_i|^p ) + i = 1 + Parameters + ---------- + theta : array_like + An array with shape 1+1 (isotropic) or n+1 (anisotropic) giving the autocorrelation parameter(s) (theta, p). + dx : array_like + An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. + + Returns + ------- + r : array_like + An array with shape (n_eval, ) with the values of the autocorrelation model. 
+ """ + + theta = np.asanyarray(theta, dtype=np.float) + d = np.asanyarray(d, dtype=np.float) + + if d.ndim > 1: + n_features = d.shape[1] + else: + n_features = 1 + lth = theta.size + if n_features > 1 and lth == 2: + theta = np.hstack([np.repeat(theta[0], n_features), theta[1]]) + elif lth != n_features+1: + raise Exception, "Length of theta must be 2 or "+str(n_features+1) + else: + theta = theta.reshape(1, lth) + + td = - theta[:,0:-1].reshape(1, n_features) * abs(d)**theta[:,-1] + r = np.exp(np.sum(td,1)) + + return r + +def corriid(theta,d): + """ + Spatial independence correlation model (pure nugget). + (Useful when one wants to solve an ordinary least squares problem!) + n + corriid : theta, dx --> r(theta, dx) = 1 if sum |dx_i| == 0 and 0 otherwise + i = 1 + Parameters + ---------- + theta : array_like + None. + dx : array_like + An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. + + Returns + ------- + r : array_like + An array with shape (n_eval, ) with the values of the autocorrelation model. + """ + + theta = np.asanyarray(theta, dtype=np.float) + d = np.asanyarray(d, dtype=np.float) + + n_eval = d.shape[0] + r = np.zeros(n_eval) + # The ones on the diagonal of the correlation matrix are enforced within + # the KrigingModel instanciation to allow multiple design sites in this + # ordinary least squares context. + + return r + +def corrcubic(theta,d): + """ + Cubic correlation model. + n + corrcubic : theta, dx --> r(theta, dx) = prod max(0, 1 - 3(theta_j*d_ij)^2 + 2(theta_j*d_ij)^3) , i = 1,...,m + j = 1 + Parameters + ---------- + theta : array_like + An array with shape 1 (isotropic) or n (anisotropic) giving the autocorrelation parameter(s). + dx : array_like + An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. + + Returns + ------- + r : array_like + An array with shape (n_eval, ) with the values of the autocorrelation model. + """ + + theta = np.asanyarray(theta, dtype=np.float) + d = np.asanyarray(d, dtype=np.float) + + if d.ndim > 1: + n_features = d.shape[1] + else: + n_features = 1 + lth = theta.size + if lth == 1: + theta = np.repeat(theta, n_features)[np.newaxis,:] + elif lth != n_features: + raise Exception, "Length of theta must be 1 or "+str(n_features) + else: + theta = theta.reshape(1, n_features) + + td = abs(d) * theta + td[td > 1.] = 1. + ss = 1. - td**2. * (3. - 2.*td) + r = np.prod(ss,1) + + return r + +def corrlin(theta,d): + """ + Linear correlation model. + n + corrlin : theta, dx --> r(theta, dx) = prod max(0, 1 - theta_j*d_ij) , i = 1,...,m + j = 1 + Parameters + ---------- + theta : array_like + An array with shape 1 (isotropic) or n (anisotropic) giving the autocorrelation parameter(s). + dx : array_like + An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. + + Returns + ------- + r : array_like + An array with shape (n_eval, ) with the values of the autocorrelation model. 
+ """ + + theta = np.asanyarray(theta, dtype=np.float) + d = np.asanyarray(d, dtype=np.float) + + if d.ndim > 1: + n_features = d.shape[1] + else: + n_features = 1 + lth = theta.size + if lth == 1: + theta = np.repeat(theta, n_features)[np.newaxis,:] + elif lth != n_features: + raise Exception, "Length of theta must be 1 or "+str(n_features) + else: + theta = theta.reshape(1, n_features) + + td = abs(d) * theta + td[td > 1.] = 1. + ss = 1. - td + r = np.prod(ss,1) + + return r \ No newline at end of file diff --git a/scikits/learn/gpml/gaussian_process_model.py b/scikits/learn/gpml/gaussian_process_model.py new file mode 100644 index 0000000000000000000000000000000000000000..368dbbb40d0d8fdc9a3d3fec084eba02c1b6fe69 --- /dev/null +++ b/scikits/learn/gpml/gaussian_process_model.py @@ -0,0 +1,653 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +################ +# Dependencies # +################ + +import numpy as np +from scipy import linalg, optimize, random +from ..base import BaseEstimator +from .regression_models import regpoly0 +from .correlation_models import correxp2 +machine_epsilon = np.finfo(np.double).eps + +def col_sum(x): + """ + Performs columnwise sums of elements in x depending on its shape. + + Parameters + ---------- + x : array_like + An array of size (p, q). + + Returns + ------- + s : array_like + An array of size (q, ) which contains the columnwise sums of the elements in x. + """ + + x = np.asanyarray(x, dtype=np.float) + + if x.ndim > 1: + s = x.sum(axis=0) + else: + s = x + + return s + +#################################### +# The Gaussian Process Model class # +#################################### + +class GaussianProcessModel(BaseEstimator): + """ + A class that implements scalar Gaussian Process based prediction (also known as Kriging). + + Example + ------- + import numpy as np + from scikits.learn.gpml import GaussianProcessModel + import pylab as pl + + f = lambda x: x*np.sin(x) + X = np.array([1., 3., 5., 6., 7., 8.]) + Y = f(X) + aGaussianProcessModel = GaussianProcessModel(regr=regpoly0, corr=correxp2, theta0=1e-1, thetaL=1e-3, thetaU=1e0, random_start=100) + aGaussianProcessModel.fit(X, Y) + + pl.figure(1) + x = np.linspace(0,10,1000) + y, MSE = aGaussianProcessModel.predict(x, eval_MSE=True) + sigma = np.sqrt(MSE) + pl.plot(x, f(x), 'r:', label=u'$f(x) = x\,\sin(x)$') + pl.plot(X, Y, 'r.', markersize=10, label=u'Observations') + pl.plot(x, y, 'k-', label=u'$\widehat{f}(x) = {\\rm BLUP}(x)$') + pl.fill(np.concatenate([x, x[::-1]]), np.concatenate([y - 1.9600 * sigma, (y + 1.9600 * sigma)[::-1]]), alpha=.5, fc='b', ec='None', label=u'95\% confidence interval') + pl.xlabel('$x$') + pl.ylabel('$f(x)$') + pl.legend(loc='upper left') + + pl.figure(2) + theta_values = np.logspace(np.log10(aGaussianProcessModel.thetaL),np.log10(aGaussianProcessModel.thetaU),100) + score_values = [] + for t in theta_values: + score_values.append(aGaussianProcessModel.score(theta=t)[0]) + pl.plot(theta_values, score_values) + pl.xlabel(u'$\\theta$') + pl.ylabel(u'Score') + pl.xscale('log') + + pl.show() + + Methods + ------- + fit(X, y) : self + Fit the model. + + predict(X) : array + Predict using the model. + + Todo + ---- + o Add the 'sparse' storage mode for which the correlation matrix is stored in its sparse eigen decomposition format instead of its full Cholesky decomposition. + + Implementation details + ---------------------- + The presentation implementation is based on a translation of the DACE Matlab toolbox. + + See references: + [1] H.B. Nielsen, S.N. Lophaven, H. B. 
Nielsen and J. Sondergaard (2002). DACE - A MATLAB Kriging Toolbox. + http://www2.imm.dtu.dk/~hbn/dace/dace.pdf + """ + + def __init__(self, regr = regpoly0, corr = correxp2, beta0 = None, storage_mode = 'full', verbose = True, theta0 = 1e-1, thetaL = None, thetaU = None, optimizer = 'fmin_cobyla', random_start = 1, normalize = True, nugget = 10.*machine_epsilon): + """ + The Gaussian Process Model constructor. + + Parameters + ---------- + regr : lambda function, optional + A regression function returning an array of outputs of the linear regression functional basis. + (The number of observations m should be greater than the size p of this basis) + Default assumes a simple constant regression trend (see regpoly0). + + corr : lambda function, optional + A stationary autocorrelation function returning the autocorrelation between two points x and x'. + Default assumes a squared-exponential autocorrelation model (see correxp2). + + beta0 : double array_like, optional + Default assumes Universal Kriging (UK) so that the vector beta of regression weights is estimated + by Maximum Likelihood. Specifying beta0 overrides estimation and performs Ordinary Kriging (OK). + + storage_mode : string, optional + A string specifying whether the Cholesky decomposition of the correlation matrix should be + stored in the class (storage_mode = 'full') or not (storage_mode = 'light'). + Default assumes storage_mode = 'full', so that the Cholesky decomposition of the correlation matrix is stored. + This might be a useful parameter when one is not interested in the MSE and only plan to estimate the BLUP, + for which the correlation matrix is not required. + + verbose : boolean, optional + A boolean specifying the verbose level. + Default is verbose = True. + + theta0 : double array_like, optional + An array with shape (n_features, ) or (1, ). + The parameters in the autocorrelation model. + If thetaL and thetaU are also specified, theta0 is considered as the starting point + for the Maximum Likelihood Estimation of the best set of parameters. + Default assumes isotropic autocorrelation model with theta0 = 1e-1. + + thetaL : double array_like, optional + An array with shape matching theta0's. + Lower bound on the autocorrelation parameters for Maximum Likelihood Estimation. + Default is None, skips Maximum Likelihood Estimation and uses theta0. + + thetaU : double array_like, optional + An array with shape matching theta0's. + Upper bound on the autocorrelation parameters for Maximum Likelihood Estimation. + Default is None and skips Maximum Likelihood Estimation and uses theta0. + + normalize : boolean, optional + Design sites X and observations y are centered and reduced wrt means and standard deviations + estimated from the n_samples observations provided. + Default is true so that data is normalized to ease MLE. + + nugget : double, optional + Introduce a nugget effect to allow smooth predictions from noisy data. + Default assumes a nugget close to machine precision for the sake of robustness (nugget = 10.*machine_epsilon). + + optimizer : string, optional + A string specifying the optimization algorithm to be used ('fmin_cobyla' is the sole algorithm implemented yet). + Default uses 'fmin_cobyla' algorithm from scipy.optimize. + + random_start : int, optional + The number of times the Maximum Likelihood Estimation should be performed from a random starting point. 
+ The first MLE always uses the specified starting point (theta0), + the next starting points are picked at random according to an exponential distribution (log-uniform in [thetaL, thetaU]). + Default does not use random starting point (random_start = 1). + + Returns + ------- + aGaussianProcessModel : self + A Gaussian Process Model object awaiting data to be fitted to. + """ + + self.regr = regr + self.corr = corr + self.beta0 = beta0 + + # Check storage mode + if storage_mode != 'full' and storage_mode != 'light': + if storage_mode == 'sparse': + raise NotImplementedError, "The 'sparse' storage mode is not supported yet. Please contribute!" + else: + raise ValueError, "Storage mode should either be 'full' or 'light'. Unknown storage mode: "+str(storage_mode) + else: + self.storage_mode = storage_mode + + self.verbose = verbose + + # Check correlation parameters + self.theta0 = np.atleast_2d(np.asanyarray(theta0, dtype=np.float)) + self.thetaL, self.thetaU = thetaL, thetaU + lth = self.theta0.size + + if self.thetaL is not None and self.thetaU is not None: + # Parameters optimization case + self.thetaL = np.atleast_2d(np.asanyarray(thetaL, dtype=np.float)) + self.thetaU = np.atleast_2d(np.asanyarray(thetaU, dtype=np.float)) + + if self.thetaL.size != lth or self.thetaU.size != lth: + raise ValueError, "theta0, thetaL and thetaU must have the same length" + if np.any(self.thetaL <= 0) or np.any(self.thetaU < self.thetaL): + raise ValueError, "The bounds must satisfy O < thetaL <= thetaU" + + elif self.thetaL is None and self.thetaU is None: + # Given parameters case + if np.any(self.theta0 <= 0): + raise ValueError, "theta0 must be strictly positive" + + elif self.thetaL is None or self.thetaU is None: + # Error + raise Exception, "thetaL and thetaU should either be both or neither specified" + + # Store other parameters + self.normalize = normalize + self.nugget = nugget + self.optimizer = optimizer + self.random_start = int(random_start) + + def fit(self, X, y): + """ + The Gaussian Process Model fitting method. + + Parameters + ---------- + X : double array_like + An array with shape (n_samples, n_features) with the design sites at which observations were made. + + y : double array_like + An array with shape (n_features, ) with the observations of the scalar output to be predicted. + + Returns + ------- + aGaussianProcessModel : self + A fitted Gaussian Process Model object awaiting data to perform predictions. + """ + + # Force data to numpy.array type (from coding guidelines) + X = np.asanyarray(X, dtype=np.float) + y = np.asanyarray(y, dtype=np.float) + + # Check design sites & observations + n_samples_X = X.shape[0] + if X.ndim > 1: + n_features = X.shape[1] + else: + n_features = 1 + X = X.reshape(n_samples_X, n_features) + + n_samples_y = y.shape[0] + if y.ndim > 1: + raise NotImplementedError, "y has more than one dimension. This is not supported yet (scalar output prediction only). Please contribute!" + y = y.reshape(n_samples_y, 1) + + if n_samples_X != n_samples_y: + raise Exception, "X and y must have the same number of rows!" + else: + n_samples = n_samples_X + + # Normalize data or don't + if self.normalize: + mean_X = np.mean(X, axis=0) + std_X = np.sqrt(1./(n_samples-1.)*np.sum((X - np.mean(X,0))**2.,0)) #np.std(y, axis=0) + mean_y = np.mean(y, axis=0) + std_y = np.sqrt(1./(n_samples-1.)*np.sum((y - np.mean(y,0))**2.,0)) #np.std(y, axis=0) + std_X[std_X == 0.] = 1. + std_y[std_y == 0.] = 1. 
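+        # With normalization disabled, identity scaling factors are used so
+        # that the same centering/reduction code path applies below.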
+ else: + mean_X = np.array([0.]) + std_X = np.array([1.]) + mean_y = np.array([0.]) + std_y = np.array([1.]) + + X_ = (X - mean_X) / std_X + y_ = (y - mean_y) / std_y + + # Calculate matrix of distances D between samples + mzmax = n_samples*(n_samples-1)/2 + ij = np.zeros([mzmax, 2]) + D = np.zeros([mzmax, n_features]) + ll = np.array([-1]) + for k in range(n_samples-1): + ll = ll[-1] + 1 + range(n_samples-k-1) + ij[ll,:] = np.concatenate([[np.repeat(k,n_samples-k-1,0)], [np.arange(k+1,n_samples).T]]).T + D[ll,:] = X_[k,:] - X_[(k+1):n_samples,:] + if (np.min(np.sum(np.abs(D),1)) == 0.) and (self.corr != corriid): + raise Exception, "Multiple X are not allowed" + + # Regression matrix and parameters + F = self.regr(X) + n_samples_F = F.shape[0] + if F.ndim > 1: + p = F.shape[1] + else: + p = 1 + if n_samples_F != n_samples: + raise Exception, "Number of rows in F and X do not match. Most likely something is wrong in the regression model." + if p > n_samples_F: + raise Exception, "Ordinary least squares problem is undetermined. n_samples=%d must be greater than the regression model size p=%d!" % (n_samples, p) + if self.beta0 is not None and (self.beta0.shape[0] != p or self.beta0.ndim > 1): + raise Exception, "Shapes of beta0 and F do not match." + + # Set attributes + self.X = X + self.y = y + self.X_ = X_ + self.y_ = y_ + self.D = D + self.ij = ij + self.F = F + self.X_sc = np.concatenate([[mean_X],[std_X]]) + self.y_sc = np.concatenate([[mean_y],[std_y]]) + + # Determine Gaussian Process Model parameters + if self.thetaL is not None and self.thetaU is not None: + # Maximum Likelihood Estimation of the parameters + if self.verbose: + print "Performing Maximum Likelihood Estimation of the autocorrelation parameters..." + self.theta, self.score_value, self.par = self.__arg_max_score__() + if np.isinf(self.score_value) : + raise Exception, "Bad parameter region. Try increasing upper bound" + else: + # Given parameters + if self.verbose: + print "Given autocorrelation parameters. Computing Gaussian Process Model parameters..." + self.theta = self.theta0 + self.score_value, self.par = self.score() + if np.isinf(self.score_value): + raise Exception, "Bad point. Try increasing theta0" + + if self.storage_mode == 'light': + # Delete heavy data (it will be computed again if required) + # (it is required only when MSE is wanted in self.predict) + if self.verbose: + print "Light storage mode specified. Flushing autocorrelation matrix..." + self.D = None + self.ij = None + self.F = None + self.par['C'] = None + self.par['Ft'] = None + self.par['G'] = None + + return self + + def predict(self, X, eval_MSE=False, batch_size=None): + """ + This function evaluates the Gaussian Process Model at x. + + Parameters + ---------- + X : array_like + An array with shape (n_eval, n_features) giving the point(s) at which the prediction(s) should be made. + eval_MSE : boolean, optional + A boolean specifying whether the Mean Squared Error should be evaluated or not. + Default assumes evalMSE = False and evaluates only the BLUP (mean prediction). + verbose : boolean, optional + A boolean specifying the verbose level. + Default is verbose = True. + batch_size : integer, optional + An integer giving the maximum number of points that can be evaluated simulatneously (depending on the available memory). + Default is None so that all given points are evaluated at the same time. + + Returns + ------- + y : array_like + An array with shape (n_eval, ) with the Best Linear Unbiased Prediction at x. 
+ MSE : array_like, optional (if eval_MSE == True) + An array with shape (n_eval, ) with the Mean Squared Error at x. + """ + + # Check + if np.any(np.isnan(self.par['beta'])): + raise Exception, "Not a valid GaussianProcessModel. Try fitting it again with different parameters theta" + + # Check design & trial sites + X = np.asanyarray(X, dtype=np.float) + n_samples = self.X_.shape[0] + if self.X_.ndim > 1: + n_features = self.X_.shape[1] + else: + n = 1 + n_eval = X.shape[0] + if X.ndim > 1: + n_features_X = X.shape[1] + else: + n_features_X = 1 + X = X.reshape(n_eval, n_features_X) + + if n_features_X != n_features: + raise ValueError, "The number of features in X (X.shape[1] = %d) should match the sample size used for fit() which is %d." % (n_features_X, n_features) + + if batch_size is None: # No memory management (evaluates all given points in a single batch run) + + # Normalize trial sites + X_ = (X - self.X_sc[0,:]) / self.X_sc[1,:] + + # Initialize output + y = np.zeros(n_eval) + if eval_MSE: + MSE = np.zeros(n_eval) + + # Get distances to design sites + dx = np.zeros([n_eval*n_samples,n_features]) + kk = np.arange(n_samples).astype(int) + for k in range(n_eval): + dx[kk,:] = X_[k,:] - self.X_ + kk = kk + n_samples + + # Get regression function and correlation + f = self.regr(X_) + r = self.corr(self.theta, dx).reshape(n_eval, n_samples).T + + # Scaled predictor + y_ = np.matrix(f) * np.matrix(self.par['beta']) + (np.matrix(self.par['gamma']) * np.matrix(r)).T + + # Predictor + y = (self.y_sc[0,:] + self.y_sc[1,:] * np.array(y_)).reshape(n_eval) + + # Mean Squared Error + if eval_MSE: + par = self.par + if par['C'] is None: + # Light storage mode (need to recompute C, F, Ft and G) + if self.verbose: + print "This GaussianProcessModel used light storage mode at instanciation. Need to recompute autocorrelation matrix..." + score_value, par = self.score() + + rt = linalg.solve(np.matrix(par['C']), np.matrix(r)) + if self.beta0 is None: + # Universal Kriging + u = linalg.solve(np.matrix(-self.par['G'].T), np.matrix(self.par['Ft']).T * np.matrix(rt) - np.matrix(f).T) + else: + # Ordinary Kriging + u = 0. * y + + MSE = self.par['sigma2'] * (1. - col_sum(np.array(rt)**2.) + col_sum(np.array(u)**2.)).T + + # Mean Squared Error might be slightly negative depending on machine precision + # Force to zero! + MSE[MSE < 0.] = 0. + + return y, MSE + + else: + + return y + + else: # Memory management + + if type(batch_size) is not int or batch_size <= 0: + raise Exception, "batch_size must be a positive integer" + + if eval_MSE: + + y, MSE = np.zeros(n_eval), np.zeros(n_eval) + for k in range(n_eval/batch_size): + y[k*batch_size:min([(k+1)*batch_size+1, n_eval+1])], MSE[k*batch_size:min([(k+1)*batch_size+1, n_eval+1])] = \ + self.predict(X[k*batch_size:min([(k+1)*batch_size+1, n_eval+1]),:], eval_MSE = eval_MSE, batch_size = None) + + return y, MSE + + else: + + y = np.zeros(n_eval) + for k in range(n_eval/batch_size): + y[k*batch_size:min([(k+1)*batch_size+1, n_eval+1])] = \ + self.__call__(x[k*batch_size:min([(k+1)*batch_size+1, n_eval+1]),:], eval_MSE = eval_MSE, batch_size = None) + + return y + + + + def score(self, theta = None): + """ + This function determines the BLUP parameters and evaluates the score function for the given autocorrelation parameters theta. + Maximizing this score function wrt the autocorrelation parameters theta is equivalent to maximizing the likelihood of the + assumed joint Gaussian distribution of the observations y evaluated onto the design of experiments X. 
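+
+        Concretely, the returned score is - sigma2 * det(R)^(1/n_samples)
+        (with sigma2 the normalized process variance), so that maximizing it
+        amounts to minimizing the determinant-penalized process variance.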
+ + Parameters + ---------- + theta : array_like, optional + An array containing the autocorrelation parameters at which the Gaussian Process Model parameters should be determined. + Default uses the built-in autocorrelation parameters (ie theta = self.theta). + + Returns + ------- + par : dict + A dictionary containing the requested Gaussian Process Model parameters: + par['sigma2'] : Gaussian Process variance. + par['beta'] : Generalized least-squares regression weights for Universal Kriging or given beta0 for Ordinary Kriging. + par['gamma'] : Gaussian Process weights. + par['C'] : Cholesky decomposition of the correlation matrix [R]. + par['Ft'] : Solution of the linear equation system : [R] x Ft = F + par['G'] : QR decomposition of the matrix Ft. + par['detR'] : Determinant of the correlation matrix raised at power 1/n_samples. + score_value : double + The value of the score function associated to the given autocorrelation parameters theta. + """ + + if theta is None: + # Use built-in autocorrelation parameters + theta = self.theta + + score_value = -np.inf + par = { } + + # Retrieve data + n_samples = self.X_.shape[0] + D = self.D + ij = self.ij + F = self.F + + if D is None: + # Light storage mode (need to recompute D, ij and F) + if self.X_.ndim > 1: + n_features = self.X_.shape[1] + else: + n_features = 1 + mzmax = n_samples*(n_samples-1)/2 + ij = zeros([mzmax, n_features]) + D = zeros([mzmax, n_features]) + ll = array([-1]) + for k in range(n_samples-1): + ll = ll[-1]+1 + range(n_samples-k-1) + ij[ll,:] = np.concatenate([[repeat(k,n_samples-k-1,0)], [np.arange(k+1,n_samples).T]]).T + D[ll,:] = self.X_[k,:] - self.X_[(k+1):n_samples,:] + if (min(sum(abs(D),1)) == 0.) and (self.corr != corriid): + raise Exception, "Multiple X are not allowed" + F = self.regr(self.X_) + + # Set up R + r = self.corr(theta, D) + R = np.eye(n_samples) * (1. + self.nugget) + R[ij.astype(int)[:,0], ij.astype(int)[:,1]] = r + R[ij.astype(int)[:,1], ij.astype(int)[:,0]] = r + + # Cholesky decomposition of R + try: + C = linalg.cholesky(R, lower=True) + except linalg.LinAlgError: + return score_value, par + + # Get generalized least squares solution + Ft = linalg.solve(C, F) + try: + Q, G = linalg.qr(Ft, econ=True) + except: + #/usr/lib/python2.6/dist-packages/scipy/linalg/decomp.py:1177: DeprecationWarning: qr econ argument will be removed after scipy 0.7. The economy transform will then be available through the mode='economic' argument. + Q, G = linalg.qr(Ft, mode='economic') + pass + + rcondG = 1./(linalg.norm(G)*linalg.norm(linalg.inv(G))) + if rcondG < 1e-10: + # Check F + condF = linalg.norm(F)*linalg.norm(linalg.inv(F)) + if condF > 1e15: + raise Exception, "F is too ill conditioned. Poor combination of regression model and observations." 
+ else: + # Ft is too ill conditioned + return score_value, par + + Yt = linalg.solve(C,self.y_) + if self.beta0 is None: + # Universal Kriging + beta = linalg.solve(G, np.matrix(Q).T*np.matrix(Yt)) + else: + # Ordinary Kriging + beta = np.array(self.beta0) + + rho = np.matrix(Yt) - np.matrix(Ft) * np.matrix(beta) + normalized_sigma2 = (np.array(rho)**2.).sum(axis=0)/n_samples + # The determinant of R is equal to the squared product of the diagonal elements of its Cholesky decomposition C + detR = (np.array(np.diag(C))**(2./n_samples)).prod() + score_value = - normalized_sigma2.sum() * detR + par = { 'sigma2':normalized_sigma2 * self.y_sc[1]**2, \ + 'beta':beta, \ + 'gamma':linalg.solve(C.T,rho).T, \ + 'C':C, \ + 'Ft':Ft, \ + 'G':G } + + return score_value, par + + def __arg_max_score__(self): + """ + This function estimates the autocorrelation parameters theta as the maximizer of the score function. + (Minimization of the opposite score function is used for convenience) + + Parameters + ---------- + self: All parameters are stored in the Gaussian Process Model object. + + Returns + ------- + optimal_theta : array_like + The best set of autocorrelation parameters (the maximizer of the score function). + optimal_score_value : double + The optimal score function value. + optimal_par : dict + The BLUP parameters associated to thetaOpt. + """ + + if self.verbose: + print "The chosen optimizer is: "+str(self.optimizer) + if self.random_start > 1: + print str(self.random_start)+" random starts are required." + + percent_completed = 0. + + if self.optimizer == 'fmin_cobyla': + + minus_score = lambda log10t: - self.score(theta = 10.**log10t)[0] + + constraints = [] + for i in range(self.theta0.size): + constraints.append(lambda log10t: log10t[i] - np.log10(self.thetaL[0,i])) + constraints.append(lambda log10t: np.log10(self.thetaU[0,i]) - log10t[i]) + + for k in range(self.random_start): + + if k == 0: + # Use specified starting point as first guess + theta0 = self.theta0 + else: + # Generate a random starting point log10-uniformly distributed between bounds + log10theta0 = np.log10(self.thetaL) + random.rand(self.theta0.size).reshape(self.theta0.shape) * np.log10(self.thetaU/self.thetaL) + theta0 = 10.**log10theta0 + + log10_optimal_theta = optimize.fmin_cobyla(minus_score, np.log10(theta0), constraints, iprint=0) + optimal_theta = 10.**log10_optimal_theta + optimal_minus_score_value, optimal_par = self.score(theta = optimal_theta) + optimal_score_value = - optimal_minus_score_value + + if k > 0: + if optimal_score_value > best_optimal_score_value: + best_optimal_score_value = optimal_score_value + best_optimal_par = optimal_par + best_optimal_theta = optimal_theta + else: + best_optimal_score_value = optimal_score_value + best_optimal_par = optimal_par + best_optimal_theta = optimal_theta + if self.verbose and self.random_start > 1: + if (20*k)/self.random_start > percent_completed: + percent_completed = (20*k)/self.random_start + print str(5*percent_completed)+"% completed" + + else: + + raise NotImplementedError, "This optimizer ('%s') is not implemented yet. Please contribute!" 
% self.optimizer + + return best_optimal_theta, best_optimal_score_value, best_optimal_par \ No newline at end of file diff --git a/scikits/learn/gpml/regression_models.py b/scikits/learn/gpml/regression_models.py new file mode 100644 index 0000000000000000000000000000000000000000..5bf8df7d3255b80f998560d6d27608b73ed99392 --- /dev/null +++ b/scikits/learn/gpml/regression_models.py @@ -0,0 +1,84 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +################ +# Dependencies # +################ + +import numpy as np + +############################ +# Defaut regression models # +############################ + +def regpoly0(x): + """ + Zero order polynomial (constant, p = 1) regression model. + + regpoly0 : x --> f(x) = 1 + + Parameters + ---------- + x : array_like + An array with shape (n_eval, n_features) giving the locations x at which the regression model should be evaluated. + + Returns + ------- + f : array_like + An array with shape (n_eval, p) with the values of the regression model. + """ + + x = np.asanyarray(x, dtype=np.float) + n_eval = x.shape[0] + f = np.ones([n_eval,1]) + + return f + +def regpoly1(x): + """ + First order polynomial (hyperplane, p = n) regression model. + + regpoly1 : x --> f(x) = [ x_1, ..., x_n ].T + + Parameters + ---------- + x : array_like + An array with shape (n_eval, n_features) giving the locations x at which the regression model should be evaluated. + + Returns + ------- + f : array_like + An array with shape (n_eval, p) with the values of the regression model. + """ + + x = np.asanyarray(x, dtype=np.float) + n_eval = x.shape[0] + f = np.hstack([np.ones([n_eval,1]), x]) + + return f + +def regpoly2(x): + """ + Second order polynomial (hyperparaboloid, p = n*(n-1)/2) regression model. + + regpoly2 : x --> f(x) = [ x_i*x_j, (i,j) = 1,...,n ].T + i > j + + Parameters + ---------- + x : array_like + An array with shape (n_eval, n_features) giving the locations x at which the regression model should be evaluated. + + Returns + ------- + f : array_like + An array with shape (n_eval, p) with the values of the regression model. + """ + + x = np.asanyarray(x, dtype=np.float) + n_eval, n_features = x.shape + f = np.hstack([np.ones([n_eval,1]), x]) + for k in range(n_features): + f = np.hstack([f, x[:,k,np.newaxis] * x[:,k:]]) + + return f \ No newline at end of file diff --git a/scikits/learn/tests/test_gpml.py b/scikits/learn/tests/test_gpml.py new file mode 100644 index 0000000000000000000000000000000000000000..15453dd8eee7a641eca47a70e1b0b650911959b2 --- /dev/null +++ b/scikits/learn/tests/test_gpml.py @@ -0,0 +1,49 @@ +""" +Testing for Gaussian Process for Machine Learning module (scikits.learn.gpml) +""" + +import numpy as np +from numpy.testing import assert_array_equal, assert_array_almost_equal, \ + assert_almost_equal, assert_raises, assert_ + +from .. import gpml, datasets, cross_val, metrics + +diabetes = datasets.load_diabetes() + +def test_regression_1d_x_sinx(): + """ + MLE estimation of a Gaussian Process model with a squared exponential + correlation model (correxp2). Check random start optimization. + + Test the interpolating property. 
+    """
+
+    f = lambda x: x * np.sin(x)
+    X = np.array([1., 3., 5., 6., 7., 8.])
+    y = f(X)
+    gpm = gpml.GaussianProcessModel(theta0=1e-2, thetaL=1e-4, thetaU=1e-1,
+                                    random_start=10, verbose=False).fit(X, y)
+    y_pred, MSE = gpm.predict(X, eval_MSE=True)
+
+    assert np.all(np.abs((y_pred - y) / y) < 1e-6) and np.all(MSE < 1e-6)
+
+
+def test_regression_diabetes():
+    """
+    MLE estimation of a Gaussian Process model with an anisotropic squared
+    exponential correlation model.
+
+    Test the model using the cross-validation module (quite time-consuming).
+
+    Poor performance so far: the leave-one-out estimate of explained variance
+    is only about 0.5 at best... To be investigated!
+    TODO: find a dataset that would prove GPML performance!
+    """
+
+    X, y = diabetes['data'], diabetes['target']
+
+    gpm = gpml.GaussianProcessModel(corr=gpml.correxp2, theta0=1e-4,
+                                    nugget=1e-2, verbose=False).fit(X, y)
+
+    # Setting the bounds to None makes the cross-validation fits below skip
+    # MLE and reuse theta0
+    gpm.thetaL, gpm.thetaU = None, None
+    score_func = lambda self, X_test, y_test: self.predict(X_test)[0]
+    y_pred = cross_val.cross_val_score(gpm, X, y=y, score_func=score_func,
+                                       cv=cross_val.LeaveOneOut(y.size),
+                                       n_jobs=1, verbose=0)
+
+    # Leave-one-out estimate of the explained variance (Q2)
+    Q2 = 1. - np.sum((y_pred - y) ** 2.) / np.sum((y - np.mean(y)) ** 2.)
+
+    assert Q2 > 0.
\ No newline at end of file