diff --git a/examples/gaussian_process/plot_gp_diabetes_dataset.py b/examples/gaussian_process/plot_gp_diabetes_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0337a2150f5007a610d6f981566524a82ff37adb
--- /dev/null
+++ b/examples/gaussian_process/plot_gp_diabetes_dataset.py
@@ -0,0 +1,64 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+"""
+===============================================
+Gaussian Processes for Machine Learning example
+===============================================
+
+This example consists in fitting a Gaussian Process model onto the diabetes
+dataset.
+WARNING: This is quite time consuming (about 2 minutes overall runtime).
+
+The corelation parameters are given in order to maximize the generalization
+capacity of the model. We assumed an isotropic squared exponential correlation
+model (correxp2) with a constant regression model (regpoly0). We also used a
+nugget=1e-2 in order to account for the (strong) noise in the targets.
+
+The figure is a goodness-of-fit plot obtained using leave-one-out predictions
+of the Gaussian Process model. Based on these predictions, we compute an
+explained variance error (Q2).
+"""
+
+# Author: Vincent Dubourg <vincent.dubourg@gmail.com
+# License: BSD style
+
+from scikits.learn import datasets, cross_val, metrics
+from scikits.learn.gaussian_process import GaussianProcess
+from matplotlib import pyplot as pl
+
+# Print the docstring
+print __doc__
+
+# Load the dataset from scikits' data sets
+diabetes = datasets.load_diabetes()
+X, y = diabetes['data'], diabetes['target']
+
+# Instanciate a GP model
+gp = GaussianProcess(theta0=1e-4, nugget=1e-2, verbose=False)
+
+# Fit the GP model to the data
+gp.fit(X, y)
+
+# Estimate the leave-one-out predictions using the cross_val module
+n_jobs = 2 # the distributing capacity available on the machine
+verbose = 1 # the verbosity level of the cross_val_score function
+y_pred = cross_val.cross_val_score(gp, X, y=y, \
+                                   cv=cross_val.LeaveOneOut(y.size), \
+                                   n_jobs=n_jobs, verbose=verbose).ravel() \
+       + y
+
+# Compute the empirical explained variance
+Q2 = metrics.explained_variance(y_pred, y)
+
+# Goodness-of-fit plot
+pl.figure()
+pl.title('Goodness-of-fit plot (Q2 = %1.2e)' % Q2)
+pl.plot(y, y_pred, 'r.', label='Leave-one-out')
+pl.plot(y, gp.predict(X), 'k.', label='Whole dataset (nugget=1e-2)')
+pl.plot([y.min(), y.max()], [y.min(), y.max()], 'k--')
+pl.xlabel('Observations')
+pl.ylabel('Predictions')
+pl.legend(loc='upper left')
+pl.axis('tight')
+pl.show()
diff --git a/scikits/learn/gaussian_process/gaussian_process.py b/scikits/learn/gaussian_process/gaussian_process.py
index f06bfc50be7e08404285e223659a0ed464262ae3..9bb9fa1839963beb155ac69e5b711ad4f60ff2e8 100644
--- a/scikits/learn/gaussian_process/gaussian_process.py
+++ b/scikits/learn/gaussian_process/gaussian_process.py
@@ -660,10 +660,12 @@ class GaussianProcess(BaseEstimator):
             Q, G = linalg.qr(Ft, mode='economic')
             pass
 
-        rcondG = 1. / (linalg.norm(G) * linalg.norm(linalg.inv(G)))
+        sv = linalg.svd(G, compute_uv=False)
+        rcondG = sv[-1] / sv[0]
         if rcondG < 1e-10:
             # Check F
-            condF = linalg.norm(F) * linalg.norm(linalg.inv(F))
+            sv = linalg.svd(F, compute_uv=False)
+            condF = sv[0] / sv[-1]
             if condF > 1e15:
                 raise Exception("F is too ill conditioned. Poor combination " \
                               + "of regression model and observations.")
diff --git a/scikits/learn/tests/test_gaussian_process.py b/scikits/learn/tests/test_gaussian_process.py
index 6b0a3a79e7ea68e17e28c497f66c60f3f25fe4d0..99e81556979a7f219f34b71bf8a967164754a385 100644
--- a/scikits/learn/tests/test_gaussian_process.py
+++ b/scikits/learn/tests/test_gaussian_process.py
@@ -6,11 +6,8 @@ import numpy as np
 from numpy.testing import assert_array_equal, assert_array_almost_equal, \
                           assert_almost_equal, assert_raises, assert_
 
-from .. import datasets, cross_val, metrics
 from ..gaussian_process import GaussianProcess
 
-diabetes = datasets.load_diabetes()
-
 
 def test_regression_1d_x_sinx():
     """
@@ -28,30 +25,3 @@ def test_regression_1d_x_sinx():
     y_pred, MSE = gp.predict(X, eval_MSE=True)
 
     assert (np.all(np.abs((y_pred - y) / y) < 1e-6) and np.all(MSE < 1e-6))
-
-
-def test_regression_diabetes(n_jobs=1, verbose=0):
-    """
-    MLE estimation of a Gaussian Process model with an anisotropic squared
-    exponential correlation model.
-
-    Test the model using cross-validation module (quite time-consuming).
-
-    Poor performance: Leave-one-out estimate of explained variance is about 0.5
-    at best... To be investigated!
-
-    TODO: find a dataset that would prove GP performance!
-    """
-
-    X, y = diabetes['data'], diabetes['target']
-
-    gp = GaussianProcess(theta0=1e-4, nugget=1e-2, verbose=False).fit(X, y)
-
-    y_pred = cross_val.cross_val_score(gp, X, y=y, \
-                                       cv=cross_val.LeaveOneOut(y.size), \
-                                       n_jobs=n_jobs, verbose=verbose) \
-           + y
-
-    Q2 = metrics.explained_variance(y_pred, y)
-
-    assert Q2 > 0.45