From 5077393d5fe51baa962f2fd2507cadb47ebfb1ec Mon Sep 17 00:00:00 2001 From: dubourg <dubourg@PTlami14.(none)> Date: Sun, 14 Nov 2010 18:44:34 +0100 Subject: [PATCH] Ran the pep8 and pyflakes utilities and corrected the gaussian_process module and related files. --- ...ilistic_classification_after_regression.py | 94 ++- .../gaussian_process/plot_gp_regression.py | 32 +- scikits/learn/gaussian_process/__init__.py | 7 +- scikits/learn/gaussian_process/correlation.py | 206 ++++-- .../gaussian_process/gaussian_process.py | 690 +++++++++++------- scikits/learn/gaussian_process/regression.py | 62 +- scikits/learn/tests/test_gaussian_process.py | 41 +- 7 files changed, 685 insertions(+), 447 deletions(-) diff --git a/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py b/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py index b8943a3413..99ea5ede72 100644 --- a/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py +++ b/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py @@ -6,11 +6,15 @@ Gaussian Processes for Machine Learning example =============================================== -A two-dimensional regression exercise with a post-processing -allowing for probabilistic classification thanks to the -Gaussian property of the prediction. +A two-dimensional regression exercise with a post-processing allowing for +probabilistic classification thanks to the Gaussian property of the prediction. + +The figure illustrates the probability that the prediction is negative with +respect to the remaining uncertainty in the prediction. The red and blue lines +correspond to the 95% confidence interval on the prediction of the zero level +set. """ -print __doc__ + # Author: Vincent Dubourg <vincent.dubourg@gmail.com # License: BSD style @@ -19,10 +23,9 @@ from scipy import stats from scikits.learn.gaussian_process import GaussianProcess from matplotlib import pyplot as pl from matplotlib import cm -from mpl_toolkits.mplot3d import Axes3D -class FormatFaker(object): - def __init__(self, str): self.str = str - def __mod__(self, stuff): return self.str + +# Print the docstring +print __doc__ # Standard normal distribution functions Grv = stats.distributions.norm() @@ -34,18 +37,19 @@ PHIinv = lambda x: Grv.ppf(x) beta0 = 8 b, kappa, e = 5, .5, .1 -# The function to predict (classification will then consist in predicting wheter g(x) <= 0 or not) -g = lambda x: b - x[:,1] - kappa * ( x[:,0] - e )**2. +# The function to predict (classification will then consist in predicting +# whether g(x) <= 0 or not) +g = lambda x: b - x[:, 1] - kappa * (x[:, 0] - e) ** 2. 
# Design of experiments X = np.array([[-4.61611719, -6.00099547], - [ 4.10469096, 5.32782448], - [ 0.00000000, -0.50000000], - [-6.17289014, -4.6984743 ], - [ 1.3109306 , -6.93271427], - [-5.03823144, 3.10584743], - [-2.87600388, 6.74310541], - [ 5.21301203, 4.26386883]]) + [4.10469096, 5.32782448], + [0.00000000, -0.50000000], + [-6.17289014, -4.6984743], + [1.3109306, -6.93271427], + [-5.03823144, 3.10584743], + [-2.87600388, 6.74310541], + [5.21301203, 4.26386883]]) # Observations Y = g(X) @@ -58,41 +62,53 @@ gp.fit(X, Y) # Evaluate real function, the prediction and its MSE on a grid res = 50 -x1, x2 = np.meshgrid(np.linspace(-beta0, beta0, res), np.linspace(-beta0, beta0, res)) +x1, x2 = np.meshgrid(np.linspace(- beta0, beta0, res), \ + np.linspace(- beta0, beta0, res)) xx = np.vstack([x1.reshape(x1.size), x2.reshape(x2.size)]).T YY = g(xx) yy, MSE = gp.predict(xx, eval_MSE=True) sigma = np.sqrt(MSE) -yy = yy.reshape((res,res)) -YY = YY.reshape((res,res)) -sigma = sigma.reshape((res,res)) +yy = yy.reshape((res, res)) +YY = YY.reshape((res, res)) +sigma = sigma.reshape((res, res)) k = PHIinv(.975) -# Plot the probabilistic classification iso-values using the Gaussian property of the prediction +# Plot the probabilistic classification iso-values using the Gaussian property +# of the prediction fig = pl.figure(1) - ax = fig.add_subplot(111) ax.axes.set_aspect('equal') -cax = pl.imshow(np.flipud(PHI(-yy/sigma)), cmap=cm.gray_r, alpha=.8, extent=(-beta0,beta0,-beta0,beta0)) -norm = pl.matplotlib.colors.Normalize(vmin=0., vmax=0.9) -cb = pl.colorbar(cax, ticks=[0.,0.2,0.4,0.6,0.8,1.], norm=norm) -cb.set_label('${\\rm \mathbb{P}}\left[\widehat{G}(\mathbf{x}) \leq 0\\right]$') -pl.plot(X[Y <= 0, 0], X[Y <= 0, 1], 'r.', markersize=12) -pl.plot(X[Y > 0, 0], X[Y > 0, 1], 'b.', markersize=12) pl.xticks([]) pl.yticks([]) ax.set_xticklabels([]) ax.set_yticklabels([]) pl.xlabel('$x_1$') pl.ylabel('$x_2$') -cs = pl.contour(x1, x2, YY, [0.], colors='k', linestyles='dashdot') -pl.clabel(cs,fmt=FormatFaker(u'$g(\mathbf{x})=0$'),fontsize=11) -cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.025], colors='b', linestyles='solid') -pl.clabel(cs,fmt=FormatFaker(u'${\\rm \mathbb{P}}\left[{\widehat{G}}(\mathbf{x}) \leq 0\\right]= 2.5\%$'),fontsize=11) -cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.5], colors='k', linestyles='dashed') -pl.clabel(cs,fmt=FormatFaker(u'$\mu_{\widehat{G}}(\mathbf{x})=0$'),fontsize=11) -cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.975], colors='r', linestyles='solid') -pl.clabel(cs,fmt=FormatFaker(u'${\\rm \mathbb{P}}\left[{\widehat{G}}(\mathbf{x}) \leq 0\\right]= 97.5\%$'),fontsize=11) - -pl.show() \ No newline at end of file + +cax = pl.imshow(np.flipud(PHI(- yy / sigma)), cmap=cm.gray_r, alpha=0.8, \ + extent=(- beta0, beta0, - beta0, beta0)) +norm = pl.matplotlib.colors.Normalize(vmin=0., vmax=0.9) +cb = pl.colorbar(cax, ticks=[0., 0.2, 0.4, 0.6, 0.8, 1.], norm=norm) +cb.set_label('${\\rm \mathbb{P}}\left[\widehat{G}(\mathbf{x}) \leq 0\\right]$') + +pl.plot(X[Y <= 0, 0], X[Y <= 0, 1], 'r.', markersize=12) + +pl.plot(X[Y > 0, 0], X[Y > 0, 1], 'b.', markersize=12) + +cs = pl.contour(x1, x2, YY, [0.], colors='k', \ + linestyles='dashdot') + +cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.025], colors='b', \ + linestyles='solid') +pl.clabel(cs, fontsize=11) + +cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.5], colors='k', \ + linestyles='dashed') +pl.clabel(cs, fontsize=11) + +cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.975], colors='r', \ + linestyles='solid') +pl.clabel(cs, fontsize=11) + +pl.show() 
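For reference, the probability map produced by this example follows directly from the Gaussian property of the GP prediction: with predictive mean mu(x) and standard deviation sigma(x), P[g(x) <= 0] = PHI(-mu(x) / sigma(x)). A minimal sketch of that post-processing step, reusing the gp, xx and PHI objects defined in the example above (no API beyond what the example already uses is assumed):

# Probabilistic classification from the fitted GP (sketch; reuses the example's objects)
y_mean, y_mse = gp.predict(xx, eval_MSE=True)   # predictive mean and MSE on the grid
y_std = np.sqrt(y_mse)                          # predictive standard deviation
proba_negative = PHI(- y_mean / y_std)          # P[g(x) <= 0] under the Gaussian prediction
# The 0.5 level of proba_negative is the predicted zero level set, and the
# 0.025 / 0.975 levels bound it with 95% confidence (the blue and red contours).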
diff --git a/examples/gaussian_process/plot_gp_regression.py b/examples/gaussian_process/plot_gp_regression.py index e62207499c..80b5377231 100644 --- a/examples/gaussian_process/plot_gp_regression.py +++ b/examples/gaussian_process/plot_gp_regression.py @@ -6,21 +6,26 @@ Gaussian Processes for Machine Learning example =============================================== -A simple one-dimensional regression exercise with a -cubic correlation model whose parameters are estimated -using the maximum likelihood principle. +A simple one-dimensional regression exercise with a cubic correlation +model whose parameters are estimated using the maximum likelihood principle. + +The figure illustrates the interpolating property of the Gaussian Process +model as well as its probabilistic nature in the form of a pointwise 95% +confidence interval. """ -print __doc__ + # Author: Vincent Dubourg <vincent.dubourg@gmail.com # License: BSD style import numpy as np -from scipy import stats from scikits.learn.gaussian_process import GaussianProcess, corrcubic from matplotlib import pyplot as pl +# Print the docstring +print __doc__ + # The function to predict -f = lambda x: x*np.sin(x) +f = lambda x: x * np.sin(x) # The design of experiments X = np.array([1., 3., 5., 6., 7., 8.]) @@ -28,11 +33,13 @@ X = np.array([1., 3., 5., 6., 7., 8.]) # Observations Y = f(X) -# Mesh the input space for evaluations of the real function, the prediction and its MSE -x = np.linspace(0,10,1000) +# Mesh the input space for evaluations of the real function, the prediction and +# its MSE +x = np.linspace(0, 10, 1000) # Instanciate a Gaussian Process model -gp = GaussianProcess(corr=corrcubic, theta0=1e-2, thetaL=1e-4, thetaU=1e-1, random_start=100) +gp = GaussianProcess(corr=corrcubic, theta0=1e-2, thetaL=1e-4, thetaU=1e-1, \ + random_start=100) # Fit to data using Maximum Likelihood Estimation of the parameters gp.fit(X, Y) @@ -41,12 +48,15 @@ gp.fit(X, Y) y, MSE = gp.predict(x, eval_MSE=True) sigma = np.sqrt(MSE) -# Plot the function, the prediction and the 95% confidence interval based on the MSE +# Plot the function, the prediction and the 95% confidence interval based on +# the MSE fig = pl.figure() pl.plot(x, f(x), 'r:', label=u'$f(x) = x\,\sin(x)$') pl.plot(X, Y, 'r.', markersize=10, label=u'Observations') pl.plot(x, y, 'b-', label=u'Prediction') -pl.fill(np.concatenate([x, x[::-1]]), np.concatenate([y - 1.9600 * sigma, (y + 1.9600 * sigma)[::-1]]), alpha=.5, fc='b', ec='None', label='95% confidence interval') +pl.fill(np.concatenate([x, x[::-1]]), \ + np.concatenate([y - 1.9600 * sigma, (y + 1.9600 * sigma)[::-1]]), \ + alpha=.5, fc='b', ec='None', label='95% confidence interval') pl.xlabel('$x$') pl.ylabel('$f(x)$') pl.ylim(-10, 20) diff --git a/scikits/learn/gaussian_process/__init__.py b/scikits/learn/gaussian_process/__init__.py index 51cb685e18..f2bb68562f 100644 --- a/scikits/learn/gaussian_process/__init__.py +++ b/scikits/learn/gaussian_process/__init__.py @@ -4,11 +4,12 @@ """ This module contains a contribution to the scikit-learn project that implements Gaussian Process based prediction (also known as Kriging). - + The present implementation is based on a transliteration of the DACE Matlab toolbox <http://www2.imm.dtu.dk/~hbn/dace/>. 
""" from .gaussian_process import GaussianProcess -from .correlation import * -from .regression import * \ No newline at end of file +from .correlation import corrlin, corrcubic, correxp1, \ + correxp2, correxpg, corriid +from .regression import regpoly0, regpoly1, regpoly2 diff --git a/scikits/learn/gaussian_process/correlation.py b/scikits/learn/gaussian_process/correlation.py index c1d8235f22..65dcc5b825 100644 --- a/scikits/learn/gaussian_process/correlation.py +++ b/scikits/learn/gaussian_process/correlation.py @@ -7,33 +7,42 @@ import numpy as np + ############################# # Defaut correlation models # ############################# -def correxp1(theta,d): + +def correxp1(theta, d): """ Exponential autocorrelation model. (Ornstein-Uhlenbeck stochastic process) + n correxp1 : theta, dx --> r(theta, dx) = exp( sum - theta_i * |dx_i| ) i = 1 + Parameters ---------- theta : array_like - An array with shape 1 (isotropic) or n (anisotropic) giving the autocorrelation parameter(s). + An array with shape 1 (isotropic) or n (anisotropic) giving the + autocorrelation parameter(s). + dx : array_like - An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. - + An array with shape (n_eval, n_features) giving the componentwise + distances between locations x and x' at which the correlation model + should be evaluated. + Returns ------- r : array_like - An array with shape (n_eval, ) containing the values of the autocorrelation model. + An array with shape (n_eval, ) containing the values of the + autocorrelation model. """ - + theta = np.asanyarray(theta, dtype=np.float) d = np.asanyarray(d, dtype=np.float) - + if d.ndim > 1: n_features = d.shape[1] else: @@ -41,36 +50,44 @@ def correxp1(theta,d): if theta.size == 1: theta = np.repeat(theta, n_features) elif theta.size != n_features: - raise ValueError, "Length of theta must be 1 or "+str(n_features) - + raise ValueError("Length of theta must be 1 or " + str(n_features)) + td = - theta.reshape(1, n_features) * abs(d) - r = np.exp(np.sum(td,1)) - + r = np.exp(np.sum(td, 1)) + return r -def correxp2(theta,d): + +def correxp2(theta, d): """ Squared exponential correlation model (Radial Basis Function). (Infinitely differentiable stochastic process, very smooth) + n correxp2 : theta, dx --> r(theta, dx) = exp( sum - theta_i * (dx_i)^2 ) i = 1 + Parameters ---------- theta : array_like - An array with shape 1 (isotropic) or n (anisotropic) giving the autocorrelation parameter(s). + An array with shape 1 (isotropic) or n (anisotropic) giving the + autocorrelation parameter(s). + dx : array_like - An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. - + An array with shape (n_eval, n_features) giving the componentwise + distances between locations x and x' at which the correlation model + should be evaluated. + Returns ------- r : array_like - An array with shape (n_eval, ) containing the values of the autocorrelation model. + An array with shape (n_eval, ) containing the values of the + autocorrelation model. 
""" - + theta = np.asanyarray(theta, dtype=np.float) d = np.asanyarray(d, dtype=np.float) - + if d.ndim > 1: n_features = d.shape[1] else: @@ -78,36 +95,45 @@ def correxp2(theta,d): if theta.size == 1: theta = np.repeat(theta, n_features) elif theta.size != n_features: - raise Exception, "Length of theta must be 1 or "+str(n_features) - - td = - theta.reshape(1, n_features) * d**2 - r = np.exp(np.sum(td,1)) - + raise Exception("Length of theta must be 1 or " + str(n_features)) + + td = - theta.reshape(1, n_features) * d ** 2 + r = np.exp(np.sum(td, 1)) + return r -def correxpg(theta,d): + +def correxpg(theta, d): """ Generalized exponential correlation model. - (Useful when one does not know the smoothness of the function to be predicted.) + (Useful when one does not know the smoothness of the function to be + predicted.) + n correxpg : theta, dx --> r(theta, dx) = exp( sum - theta_i * |dx_i|^p ) i = 1 + Parameters ---------- theta : array_like - An array with shape 1+1 (isotropic) or n+1 (anisotropic) giving the autocorrelation parameter(s) (theta, p). + An array with shape 1+1 (isotropic) or n+1 (anisotropic) giving the + autocorrelation parameter(s) (theta, p). + dx : array_like - An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. - + An array with shape (n_eval, n_features) giving the componentwise + distances between locations x and x' at which the correlation model + should be evaluated. + Returns ------- r : array_like - An array with shape (n_eval, ) with the values of the autocorrelation model. + An array with shape (n_eval, ) with the values of the autocorrelation + model. """ - + theta = np.asanyarray(theta, dtype=np.float) d = np.asanyarray(d, dtype=np.float) - + if d.ndim > 1: n_features = d.shape[1] else: @@ -115,125 +141,151 @@ def correxpg(theta,d): lth = theta.size if n_features > 1 and lth == 2: theta = np.hstack([np.repeat(theta[0], n_features), theta[1]]) - elif lth != n_features+1: - raise Exception, "Length of theta must be 2 or "+str(n_features+1) + elif lth != n_features + 1: + raise Exception("Length of theta must be 2 or " + str(n_features + 1)) else: theta = theta.reshape(1, lth) - - td = - theta[:,0:-1].reshape(1, n_features) * abs(d)**theta[:,-1] - r = np.exp(np.sum(td,1)) - + + td = - theta[:, 0:-1].reshape(1, n_features) * abs(d) ** theta[:, -1] + r = np.exp(np.sum(td, 1)) + return r -def corriid(theta,d): + +def corriid(theta, d): """ Spatial independence correlation model (pure nugget). (Useful when one wants to solve an ordinary least squares problem!) + n - corriid : theta, dx --> r(theta, dx) = 1 if sum |dx_i| == 0 and 0 otherwise + corriid : theta, dx --> r(theta, dx) = 1 if sum |dx_i| == 0 i = 1 + 0 otherwise + Parameters ---------- theta : array_like None. + dx : array_like - An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. - + An array with shape (n_eval, n_features) giving the componentwise + distances between locations x and x' at which the correlation model + should be evaluated. + Returns ------- r : array_like - An array with shape (n_eval, ) with the values of the autocorrelation model. + An array with shape (n_eval, ) with the values of the autocorrelation + model. 
""" - + theta = np.asanyarray(theta, dtype=np.float) d = np.asanyarray(d, dtype=np.float) - + n_eval = d.shape[0] r = np.zeros(n_eval) # The ones on the diagonal of the correlation matrix are enforced within # the KrigingModel instanciation to allow multiple design sites in this # ordinary least squares context. - + return r -def corrcubic(theta,d): + +def corrcubic(theta, d): """ Cubic correlation model. - n - corrcubic : theta, dx --> r(theta, dx) = prod max(0, 1 - 3(theta_j*d_ij)^2 + 2(theta_j*d_ij)^3) , i = 1,...,m - j = 1 + + corrcubic : theta, dx --> r(theta, dx) = + n + prod max(0, 1 - 3(theta_j*d_ij)^2 + 2(theta_j*d_ij)^3) , i = 1,...,m + j = 1 + Parameters ---------- theta : array_like - An array with shape 1 (isotropic) or n (anisotropic) giving the autocorrelation parameter(s). + An array with shape 1 (isotropic) or n (anisotropic) giving the + autocorrelation parameter(s). + dx : array_like - An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. - + An array with shape (n_eval, n_features) giving the componentwise + distances between locations x and x' at which the correlation model + should be evaluated. + Returns ------- r : array_like - An array with shape (n_eval, ) with the values of the autocorrelation model. + An array with shape (n_eval, ) with the values of the autocorrelation + model. """ - + theta = np.asanyarray(theta, dtype=np.float) d = np.asanyarray(d, dtype=np.float) - + if d.ndim > 1: n_features = d.shape[1] else: n_features = 1 lth = theta.size if lth == 1: - theta = np.repeat(theta, n_features)[np.newaxis,:] + theta = np.repeat(theta, n_features)[np.newaxis][:] elif lth != n_features: - raise Exception, "Length of theta must be 1 or "+str(n_features) + raise Exception("Length of theta must be 1 or " + str(n_features)) else: theta = theta.reshape(1, n_features) - + td = abs(d) * theta td[td > 1.] = 1. - ss = 1. - td**2. * (3. - 2.*td) - r = np.prod(ss,1) - + ss = 1. - td ** 2. * (3. - 2. * td) + r = np.prod(ss, 1) + return r -def corrlin(theta,d): + +def corrlin(theta, d): """ Linear correlation model. - n - corrlin : theta, dx --> r(theta, dx) = prod max(0, 1 - theta_j*d_ij) , i = 1,...,m - j = 1 + + corrlin : theta, dx --> r(theta, dx) = + n + prod max(0, 1 - theta_j*d_ij) , i = 1,...,m + j = 1 + Parameters ---------- theta : array_like - An array with shape 1 (isotropic) or n (anisotropic) giving the autocorrelation parameter(s). + An array with shape 1 (isotropic) or n (anisotropic) giving the + autocorrelation parameter(s). + dx : array_like - An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. - + An array with shape (n_eval, n_features) giving the componentwise + distances between locations x and x' at which the correlation model + should be evaluated. + Returns ------- r : array_like - An array with shape (n_eval, ) with the values of the autocorrelation model. + An array with shape (n_eval, ) with the values of the autocorrelation + model. 
""" - + theta = np.asanyarray(theta, dtype=np.float) d = np.asanyarray(d, dtype=np.float) - + if d.ndim > 1: n_features = d.shape[1] else: n_features = 1 lth = theta.size if lth == 1: - theta = np.repeat(theta, n_features)[np.newaxis,:] + theta = np.repeat(theta, n_features)[np.newaxis][:] elif lth != n_features: - raise Exception, "Length of theta must be 1 or "+str(n_features) + raise Exception("Length of theta must be 1 or " + str(n_features)) else: theta = theta.reshape(1, n_features) - + td = abs(d) * theta td[td > 1.] = 1. ss = 1. - td - r = np.prod(ss,1) - - return r \ No newline at end of file + r = np.prod(ss, 1) + + return r diff --git a/scikits/learn/gaussian_process/gaussian_process.py b/scikits/learn/gaussian_process/gaussian_process.py index 0d07a05cea..f06bfc50be 100644 --- a/scikits/learn/gaussian_process/gaussian_process.py +++ b/scikits/learn/gaussian_process/gaussian_process.py @@ -6,43 +6,48 @@ ################ import numpy as np -from scipy import linalg, optimize, random +from scipy import linalg, optimize, rand from ..base import BaseEstimator from .regression import regpoly0 -from .correlation import correxp2 +from .correlation import correxp2, corriid machine_epsilon = np.finfo(np.double).eps + def col_sum(x): """ Performs columnwise sums of elements in x depending on its shape. - + Parameters ---------- x : array_like - An array of size (p, q). - + An array with shape size (p, q). + Returns ------- s : array_like - An array of size (q, ) which contains the columnwise sums of the elements in x. + An array whit shape (q, ) which contains the columnwise sums of the + elements in x. """ - + x = np.asanyarray(x, dtype=np.float) - + if x.ndim > 1: s = x.sum(axis=0) else: s = x - + return s + ############################## # The Gaussian Process class # ############################## + class GaussianProcess(BaseEstimator): """ - A class that implements scalar Gaussian Process based prediction (also known as Kriging). + A class that implements scalar Gaussian Process based prediction (also + known as Kriging). 
Example ------- @@ -53,7 +58,8 @@ class GaussianProcess(BaseEstimator): f = lambda x: x*np.sin(x) X = np.array([1., 3., 5., 6., 7., 8.]) Y = f(X) - gp = GaussianProcess(regr=regpoly0, corr=correxp2, theta0=1e-1, thetaL=1e-3, thetaU=1e0, random_start=100) + gp = GaussianProcess(regr=regpoly0, corr=correxp2, theta0=1e-1, \ + thetaL=1e-3, thetaU=1e0, random_start=100) gp.fit(X, Y) pl.figure(1) @@ -63,17 +69,19 @@ class GaussianProcess(BaseEstimator): pl.plot(x, f(x), 'r:', label=u'$f(x) = x\,\sin(x)$') pl.plot(X, Y, 'r.', markersize=10, label=u'Observations') pl.plot(x, y, 'k-', label=u'$\widehat{f}(x) = {\\rm BLUP}(x)$') - pl.fill(np.concatenate([x, x[::-1]]), np.concatenate([y - 1.9600 * sigma, (y + 1.9600 * sigma)[::-1]]), alpha=.5, fc='b', ec='None', label=u'95\% confidence interval') + pl.fill(np.concatenate([x, x[::-1]]), \ + np.concatenate([y - 1.9600 * sigma, (y + 1.9600 * sigma)[::-1]]), \ + alpha=.5, fc='b', ec='None', label=u'95\% confidence interval') pl.xlabel('$x$') pl.ylabel('$f(x)$') pl.legend(loc='upper left') pl.figure(2) - theta_values = np.logspace(np.log10(gp.thetaL),np.log10(gp.thetaU),100) - reduced_likelihood_function_values = [] + theta_values = np.logspace(np.log10(gp.thetaL), np.log10(gp.thetaU),100) + psi_values = [] for t in theta_values: - reduced_likelihood_function_values.append(gp.reduced_likelihood_function(theta=t)[0]) - pl.plot(theta_values, reduced_likelihood_function_values) + psi_values.append(gp.reduced_likelihood_function(theta=t)[0]) + pl.plot(theta_values, psi_values) pl.xlabel(u'$\\theta$') pl.ylabel(u'Score') pl.xscale('log') @@ -90,43 +98,59 @@ class GaussianProcess(BaseEstimator): Todo ---- - o Add the 'sparse' storage mode for which the correlation matrix is stored in its sparse eigen decomposition format instead of its full Cholesky decomposition. + o Add the 'sparse' storage mode for which the correlation matrix is stored + in its sparse eigen decomposition format instead of its full Cholesky + decomposition. Implementation details ---------------------- - The presentation implementation is based on a translation of the DACE Matlab toolbox. - + The presentation implementation is based on a translation of the DACE + Matlab toolbox. + See references: - [1] H.B. Nielsen, S.N. Lophaven, H. B. Nielsen and J. Sondergaard (2002). DACE - A MATLAB Kriging Toolbox. + [1] H.B. Nielsen, S.N. Lophaven, H. B. Nielsen and J. Sondergaard (2002). + DACE - A MATLAB Kriging Toolbox. http://www2.imm.dtu.dk/~hbn/dace/dace.pdf """ - - def __init__(self, regr = regpoly0, corr = correxp2, beta0 = None, storage_mode = 'full', verbose = True, theta0 = 1e-1, thetaL = None, thetaU = None, optimizer = 'fmin_cobyla', random_start = 1, normalize = True, nugget = 10.*machine_epsilon): + + def __init__(self, regr=regpoly0, corr=correxp2, beta0=None, \ + storage_mode='full', verbose=True, theta0=1e-1, \ + thetaL=None, thetaU=None, optimizer='fmin_cobyla', \ + random_start=1, normalize=True, \ + nugget=10. * machine_epsilon): """ The Gaussian Process model constructor. Parameters ---------- regr : lambda function, optional - A regression function returning an array of outputs of the linear regression functional basis. - (The number of observations m should be greater than the size p of this basis) + A regression function returning an array of outputs of the linear + regression functional basis. The number of observations n_samples + should be greater than the size p of this basis. Default assumes a simple constant regression trend (see regpoly0). 
- + corr : lambda function, optional - A stationary autocorrelation function returning the autocorrelation between two points x and x'. - Default assumes a squared-exponential autocorrelation model (see correxp2). - + A stationary autocorrelation function returning the autocorrelation + between two points x and x'. + Default assumes a squared-exponential autocorrelation model (see + correxp2). + beta0 : double array_like, optional - Default assumes Universal Kriging (UK) so that the vector beta of regression weights is estimated - by Maximum Likelihood. Specifying beta0 overrides estimation and performs Ordinary Kriging (OK). - + The regression weight vector to perform Ordinary Kriging (OK). + Default assumes Universal Kriging (UK) so that the vector beta of + regression weights is estimated using the maximum likelihood + principle. + storage_mode : string, optional - A string specifying whether the Cholesky decomposition of the correlation matrix should be - stored in the class (storage_mode = 'full') or not (storage_mode = 'light'). - Default assumes storage_mode = 'full', so that the Cholesky decomposition of the correlation matrix is stored. - This might be a useful parameter when one is not interested in the MSE and only plan to estimate the BLUP, - for which the correlation matrix is not required. - + A string specifying whether the Cholesky decomposition of the + correlation matrix should be stored in the class (storage_mode = + 'full') or not (storage_mode = 'light'). + Default assumes storage_mode = 'full', so that the + Cholesky decomposition of the correlation matrix is stored. + This might be a useful parameter when one is not interested in the + MSE and only plan to estimate the BLUP, for which the correlation + matrix is not required. + verbose : boolean, optional A boolean specifying the verbose level. Default is verbose = True. @@ -134,37 +158,49 @@ class GaussianProcess(BaseEstimator): theta0 : double array_like, optional An array with shape (n_features, ) or (1, ). The parameters in the autocorrelation model. - If thetaL and thetaU are also specified, theta0 is considered as the starting point - for the Maximum Likelihood Estimation of the best set of parameters. + If thetaL and thetaU are also specified, theta0 is considered as + the starting point for the maximum likelihood rstimation of the + best set of parameters. Default assumes isotropic autocorrelation model with theta0 = 1e-1. - + thetaL : double array_like, optional An array with shape matching theta0's. - Lower bound on the autocorrelation parameters for Maximum Likelihood Estimation. - Default is None, skips Maximum Likelihood Estimation and uses theta0. - + Lower bound on the autocorrelation parameters for maximum + likelihood estimation. + Default is None, so that it skips maximum likelihood estimation and + it uses theta0. + thetaU : double array_like, optional An array with shape matching theta0's. - Upper bound on the autocorrelation parameters for Maximum Likelihood Estimation. - Default is None and skips Maximum Likelihood Estimation and uses theta0. - + Upper bound on the autocorrelation parameters for maximum + likelihood estimation. + Default is None, so that it skips maximum likelihood estimation and + it uses theta0. + normalize : boolean, optional - Design sites X and observations y are centered and reduced wrt means and standard deviations - estimated from the n_samples observations provided. - Default is true so that data is normalized to ease MLE. 
- + Design sites X and observations y are centered and reduced wrt + means and standard deviations estimated from the n_samples + observations provided. + Default is normalize = True so that data is normalized to ease + maximum likelihood estimation. + nugget : double, optional - Introduce a nugget effect to allow smooth predictions from noisy data. - Default assumes a nugget close to machine precision for the sake of robustness (nugget = 10.*machine_epsilon). - + Introduce a nugget effect to allow smooth predictions from noisy + data. + Default assumes a nugget close to machine precision for the sake of + robustness (nugget = 10.*machine_epsilon). + optimizer : string, optional - A string specifying the optimization algorithm to be used ('fmin_cobyla' is the sole algorithm implemented yet). + A string specifying the optimization algorithm to be used + ('fmin_cobyla' is the sole algorithm implemented yet). Default uses 'fmin_cobyla' algorithm from scipy.optimize. - + random_start : int, optional - The number of times the Maximum Likelihood Estimation should be performed from a random starting point. + The number of times the Maximum Likelihood Estimation should be + performed from a random starting point. The first MLE always uses the specified starting point (theta0), - the next starting points are picked at random according to an exponential distribution (log-uniform in [thetaL, thetaU]). + the next starting points are picked at random according to an + exponential distribution (log-uniform on [thetaL, thetaU]). Default does not use random starting point (random_start = 1). Returns @@ -172,52 +208,57 @@ class GaussianProcess(BaseEstimator): gp : self A Gaussian Process model object awaiting data to be fitted to. """ - + self.regr = regr self.corr = corr self.beta0 = beta0 - + # Check storage mode if storage_mode != 'full' and storage_mode != 'light': if storage_mode == 'sparse': - raise NotImplementedError, "The 'sparse' storage mode is not supported yet. Please contribute!" + raise NotImplementedError("The 'sparse' storage mode is not " \ + + "supported yet. Please contribute!") else: - raise ValueError, "Storage mode should either be 'full' or 'light'. Unknown storage mode: "+str(storage_mode) + raise ValueError("Storage mode should either be 'full' or " \ + + "'light'. 
Unknown storage mode: " + str(storage_mode)) else: self.storage_mode = storage_mode - + self.verbose = verbose - + # Check correlation parameters self.theta0 = np.atleast_2d(np.asanyarray(theta0, dtype=np.float)) self.thetaL, self.thetaU = thetaL, thetaU lth = self.theta0.size - + if self.thetaL is not None and self.thetaU is not None: # Parameters optimization case self.thetaL = np.atleast_2d(np.asanyarray(thetaL, dtype=np.float)) self.thetaU = np.atleast_2d(np.asanyarray(thetaU, dtype=np.float)) - + if self.thetaL.size != lth or self.thetaU.size != lth: - raise ValueError, "theta0, thetaL and thetaU must have the same length" + raise ValueError("theta0, thetaL and thetaU must have the " \ + + "same length") if np.any(self.thetaL <= 0) or np.any(self.thetaU < self.thetaL): - raise ValueError, "The bounds must satisfy O < thetaL <= thetaU" - + raise ValueError("The bounds must satisfy 0 < thetaL <= " \ + + "thetaU") + elif self.thetaL is None and self.thetaU is None: # Given parameters case if np.any(self.theta0 <= 0): - raise ValueError, "theta0 must be strictly positive" - + raise ValueError("theta0 must be strictly positive") + elif self.thetaL is None or self.thetaU is None: # Error - raise Exception, "thetaL and thetaU should either be both or neither specified" - + raise Exception("thetaL and thetaU should either be both or " \ + + "neither specified") + # Store other parameters self.normalize = normalize self.nugget = nugget self.optimizer = optimizer self.random_start = int(random_start) - + def fit(self, X, y): """ The Gaussian Process model fitting method. Parameters ---------- X : double array_like - An array with shape (n_samples, n_features) with the design sites at which observations were made. - + An array with shape (n_samples, n_features) with the design sites + at which observations were made. + y : double array_like - An array with shape (n_features, ) with the observations of the scalar output to be predicted. - + An array with shape (n_samples, ) with the observations of the + scalar output to be predicted. + Returns ------- gp : self - A fitted Gaussian Process model object awaiting data to perform predictions. + A fitted Gaussian Process model object awaiting data to perform + predictions. """ - + # Force data to numpy.array type (from coding guidelines) X = np.asanyarray(X, dtype=np.float) y = np.asanyarray(y, dtype=np.float) - + # Check design sites & observations n_samples_X = X.shape[0] if X.ndim > 1: n_features = X.shape[1] else: n_features = 1 X = X.reshape(n_samples_X, n_features) - + n_samples_y = y.shape[0] if y.ndim > 1: - raise NotImplementedError, "y has more than one dimension. This is not supported yet (scalar output prediction only). Please contribute!" + raise NotImplementedError("y has more than one dimension. This " \ + + "is not supported yet (scalar output prediction only). " \ + + "Please contribute!") y = y.reshape(n_samples_y, 1) - + if n_samples_X != n_samples_y: - raise Exception, "X and y must have the same number of rows!" 
+ raise Exception("X and y must have the same number of rows!") else: n_samples = n_samples_X - + # Normalize data or don't if self.normalize: mean_X = np.mean(X, axis=0) - std_X = np.sqrt(1./(n_samples-1.)*np.sum((X - np.mean(X,0))**2.,0)) #np.std(y, axis=0) + std_X = np.std(X, axis=0) mean_y = np.mean(y, axis=0) - std_y = np.sqrt(1./(n_samples-1.)*np.sum((y - np.mean(y,0))**2.,0)) #np.std(y, axis=0) + std_y = np.std(y, axis=0) std_X[std_X == 0.] = 1. std_y[std_y == 0.] = 1. else: @@ -271,22 +317,23 @@ class GaussianProcess(BaseEstimator): std_X = np.array([1.]) mean_y = np.array([0.]) std_y = np.array([1.]) - + X_ = (X - mean_X) / std_X y_ = (y - mean_y) / std_y - + # Calculate matrix of distances D between samples - mzmax = n_samples*(n_samples-1)/2 + mzmax = n_samples * (n_samples - 1) / 2 ij = np.zeros([mzmax, 2]) D = np.zeros([mzmax, n_features]) ll = np.array([-1]) for k in range(n_samples-1): - ll = ll[-1] + 1 + range(n_samples-k-1) - ij[ll,:] = np.concatenate([[np.repeat(k,n_samples-k-1,0)], [np.arange(k+1,n_samples).T]]).T - D[ll,:] = X_[k,:] - X_[(k+1):n_samples,:] - if (np.min(np.sum(np.abs(D),1)) == 0.) and (self.corr != corriid): - raise Exception, "Multiple X are not allowed" - + ll = ll[-1] + 1 + range(n_samples - k - 1) + ij[ll] = np.concatenate([[np.repeat(k, n_samples - k - 1, 0)], \ + [np.arange(k + 1, n_samples).T]]).T + D[ll] = X_[k] - X_[(k + 1):n_samples] + if np.min(np.sum(np.abs(D), 1)) == 0. and self.corr != corriid: + raise Exception("Multiple X are not allowed") + # Regression matrix and parameters F = self.regr(X) n_samples_F = F.shape[0] @@ -295,12 +342,17 @@ class GaussianProcess(BaseEstimator): else: p = 1 if n_samples_F != n_samples: - raise Exception, "Number of rows in F and X do not match. Most likely something is wrong in the regression model." + raise Exception("Number of rows in F and X do not match. Most " \ + + "likely something is going wrong with the " \ + + "regression model.") if p > n_samples_F: - raise Exception, "Ordinary least squares problem is undetermined. n_samples=%d must be greater than the regression model size p=%d!" % (n_samples, p) - if self.beta0 is not None and (self.beta0.shape[0] != p or self.beta0.ndim > 1): - raise Exception, "Shapes of beta0 and F do not match." - + raise Exception("Ordinary least squares problem is undetermined " \ + + "n_samples=%d must be greater than the " \ + + "regression model size p=%d!" % (n_samples, p)) + if self.beta0 is not None \ + and (self.beta0.shape[0] != p or self.beta0.ndim > 1): + raise Exception("Shapes of beta0 and F do not match.") + # Set attributes self.X = X self.y = y @@ -309,365 +361,453 @@ class GaussianProcess(BaseEstimator): self.D = D self.ij = ij self.F = F - self.X_sc = np.concatenate([[mean_X],[std_X]]) - self.y_sc = np.concatenate([[mean_y],[std_y]]) - + self.X_sc = np.concatenate([[mean_X], [std_X]]) + self.y_sc = np.concatenate([[mean_y], [std_y]]) + # Determine Gaussian Process model parameters if self.thetaL is not None and self.thetaU is not None: # Maximum Likelihood Estimation of the parameters if self.verbose: - print "Performing Maximum Likelihood Estimation of the autocorrelation parameters..." - self.theta, self.reduced_likelihood_function_value, self.par = self.arg_max_reduced_likelihood_function() - if np.isinf(self.reduced_likelihood_function_value) : - raise Exception, "Bad parameter region. Try increasing upper bound" + print "Performing Maximum Likelihood Estimation of the " \ + + "autocorrelation parameters..." 
 + self.theta, self.reduced_likelihood_function_value, self.par = \ + self.arg_max_reduced_likelihood_function() + if np.isinf(self.reduced_likelihood_function_value): + raise Exception("Bad parameter region. " \ + + "Try increasing upper bound") + else: # Given parameters if self.verbose: - print "Given autocorrelation parameters. Computing Gaussian Process model parameters..." + print "Given autocorrelation parameters. " \ + + "Computing Gaussian Process model parameters..." self.theta = self.theta0 - self.reduced_likelihood_function_value, self.par = self.reduced_likelihood_function() + self.reduced_likelihood_function_value, self.par = \ + self.reduced_likelihood_function() if np.isinf(self.reduced_likelihood_function_value): - raise Exception, "Bad point. Try increasing theta0" - + raise Exception("Bad point. Try increasing theta0") + if self.storage_mode == 'light': # Delete heavy data (it will be computed again if required) # (it is required only when MSE is wanted in self.predict) if self.verbose: - print "Light storage mode specified. Flushing autocorrelation matrix..." + print "Light storage mode specified. " \ + + "Flushing autocorrelation matrix..." self.D = None self.ij = None self.F = None self.par['C'] = None self.par['Ft'] = None self.par['G'] = None - + return self - + def predict(self, X, eval_MSE=False, batch_size=None): """ This function evaluates the Gaussian Process model at x. - + Parameters ---------- X : array_like - An array with shape (n_eval, n_features) giving the point(s) at which the prediction(s) should be made. + An array with shape (n_eval, n_features) giving the point(s) at + which the prediction(s) should be made. + eval_MSE : boolean, optional - A boolean specifying whether the Mean Squared Error should be evaluated or not. - Default assumes evalMSE = False and evaluates only the BLUP (mean prediction). + A boolean specifying whether the Mean Squared Error should be + evaluated or not. + Default assumes eval_MSE = False and evaluates only the BLUP (mean + prediction). + verbose : boolean, optional A boolean specifying the verbose level. Default is verbose = True. + batch_size : integer, optional - An integer giving the maximum number of points that can be evaluated simulatneously (depending on the available memory). - Default is None so that all given points are evaluated at the same time. - + An integer giving the maximum number of points that can be + evaluated simultaneously (depending on the available memory). + Default is None so that all given points are evaluated at the same + time. + Returns ------- y : array_like - An array with shape (n_eval, ) with the Best Linear Unbiased Prediction at x. + An array with shape (n_eval, ) with the Best Linear Unbiased + Prediction at x. + MSE : array_like, optional (if eval_MSE == True) An array with shape (n_eval, ) with the Mean Squared Error at x. """ - + # Check if np.any(np.isnan(self.par['beta'])): - raise Exception, "Not a valid GaussianProcess. Try fitting it again with different parameters theta" - + raise Exception("Not a valid GaussianProcess. 
" \ + + "Try fitting it again with different parameters " \ + + "theta.") + # Check design & trial sites X = np.asanyarray(X, dtype=np.float) n_samples = self.X_.shape[0] if self.X_.ndim > 1: n_features = self.X_.shape[1] else: - n = 1 + n_features = 1 n_eval = X.shape[0] if X.ndim > 1: n_features_X = X.shape[1] else: n_features_X = 1 X = X.reshape(n_eval, n_features_X) - + if n_features_X != n_features: - raise ValueError, "The number of features in X (X.shape[1] = %d) should match the sample size used for fit() which is %d." % (n_features_X, n_features) - - if batch_size is None: # No memory management (evaluates all given points in a single batch run) - + raise ValueError("The number of features in X (X.shape[1] = %d) " \ + + "should match the sample size used for fit() " \ + + "which is %d." % (n_features_X, n_features)) + + if batch_size is None: + # No memory management + # (evaluates all given points in a single batch run) + # Normalize trial sites - X_ = (X - self.X_sc[0,:]) / self.X_sc[1,:] - + X_ = (X - self.X_sc[0][:]) / self.X_sc[1][:] + # Initialize output y = np.zeros(n_eval) if eval_MSE: MSE = np.zeros(n_eval) - + # Get distances to design sites - dx = np.zeros([n_eval*n_samples,n_features]) + dx = np.zeros([n_eval * n_samples, n_features]) kk = np.arange(n_samples).astype(int) for k in range(n_eval): - dx[kk,:] = X_[k,:] - self.X_ + dx[kk] = X_[k] - self.X_ kk = kk + n_samples - + # Get regression function and correlation f = self.regr(X_) r = self.corr(self.theta, dx).reshape(n_eval, n_samples).T - + # Scaled predictor - y_ = np.matrix(f) * np.matrix(self.par['beta']) + (np.matrix(self.par['gamma']) * np.matrix(r)).T - + y_ = np.matrix(f) * np.matrix(self.par['beta']) \ + + (np.matrix(self.par['gamma']) * np.matrix(r)).T + # Predictor - y = (self.y_sc[0,:] + self.y_sc[1,:] * np.array(y_)).reshape(n_eval) - + y = (self.y_sc[0] + self.y_sc[1] * np.array(y_)).ravel() + # Mean Squared Error if eval_MSE: par = self.par if par['C'] is None: # Light storage mode (need to recompute C, F, Ft and G) if self.verbose: - print "This GaussianProcess used light storage mode at instanciation. Need to recompute autocorrelation matrix..." - reduced_likelihood_function_value, par = self.reduced_likelihood_function() - + print "This GaussianProcess used light storage mode " \ + + "at instanciation. Need to recompute " \ + + "autocorrelation matrix..." + reduced_likelihood_function_value, par = \ + self.reduced_likelihood_function() + rt = linalg.solve(np.matrix(par['C']), np.matrix(r)) if self.beta0 is None: # Universal Kriging - u = linalg.solve(np.matrix(-self.par['G'].T), np.matrix(self.par['Ft']).T * np.matrix(rt) - np.matrix(f).T) + u = - linalg.solve(np.matrix(self.par['G'].T), \ + np.matrix(self.par['Ft']).T * \ + np.matrix(rt) - np.matrix(f).T) else: # Ordinary Kriging - u = 0. * y - - MSE = self.par['sigma2'] * (1. - col_sum(np.array(rt)**2.) + col_sum(np.array(u)**2.)).T - - # Mean Squared Error might be slightly negative depending on machine precision - # Force to zero! + u = np.zeros(y.shape) + + MSE = self.par['sigma2'] * (1. - col_sum(np.array(rt) ** 2.) \ + + col_sum(np.array(u) ** 2.)).T + + # Mean Squared Error might be slightly negative depending on + # machine precision: force to zero! MSE[MSE < 0.] = 0. 
- + return y, MSE - + else: - + return y - - else: # Memory management - + + else: + # Memory management + if type(batch_size) is not int or batch_size <= 0: - raise Exception, "batch_size must be a positive integer" - + raise Exception("batch_size must be a positive integer") + if eval_MSE: - - y, MSE = np.zeros(n_eval), np.zeros(n_eval) - for k in range(n_eval/batch_size): - y[k*batch_size:min([(k+1)*batch_size+1, n_eval+1])], MSE[k*batch_size:min([(k+1)*batch_size+1, n_eval+1])] = \ - self.predict(X[k*batch_size:min([(k+1)*batch_size+1, n_eval+1]),:], eval_MSE = eval_MSE, batch_size = None) - + + y, MSE = np.array([]), np.array([]) + for k in range(n_eval / batch_size): + batch_from = k * batch_size + batch_to = min([(k + 1) * batch_size + 1, n_eval + 1]) + X_batch = X[batch_from:batch_to][:] + y_batch, MSE_batch = \ + self.predict(X_batch, eval_MSE=eval_MSE, \ + batch_size=None) + y.append(y_batch) + MSE.append(MSE_batch) + return y, MSE - + else: - - y = np.zeros(n_eval) - for k in range(n_eval/batch_size): - y[k*batch_size:min([(k+1)*batch_size+1, n_eval+1])] = \ - self.__call__(x[k*batch_size:min([(k+1)*batch_size+1, n_eval+1]),:], eval_MSE = eval_MSE, batch_size = None) - + + y = np.array([]) + for k in range(n_eval / batch_size): + batch_from = k * batch_size + batch_to = min([(k + 1) * batch_size + 1, n_eval + 1]) + X_batch = X[batch_from:batch_to][:] + y_batch = \ + self.predict(X_batch, eval_MSE=eval_MSE, \ + batch_size=None) + return y - - - - def reduced_likelihood_function(self, theta = None): + + def reduced_likelihood_function(self, theta=None): """ - This function determines the BLUP parameters and evaluates the reduced likelihood function for the given autocorrelation parameters theta. - Maximizing this function wrt the autocorrelation parameters theta is equivalent to maximizing the likelihood of the - assumed joint Gaussian distribution of the observations y evaluated onto the design of experiments X. - + This function determines the BLUP parameters and evaluates the reduced + likelihood function for the given autocorrelation parameters theta. + + Maximizing this function wrt the autocorrelation parameters theta is + equivalent to maximizing the likelihood of the assumed joint Gaussian + distribution of the observations y evaluated onto the design of + experiments X. + Parameters ---------- theta : array_like, optional - An array containing the autocorrelation parameters at which the Gaussian Process model parameters should be determined. - Default uses the built-in autocorrelation parameters (ie theta = self.theta). - + An array containing the autocorrelation parameters at which the + Gaussian Process model parameters should be determined. + Default uses the built-in autocorrelation parameters + (ie theta = self.theta). + Returns ------- + reduced_likelihood_function_value : double + The value of the reduced likelihood function associated to the + given autocorrelation parameters theta. + par : dict - A dictionary containing the requested Gaussian Process model parameters: + A dictionary containing the requested Gaussian Process model + parameters: + par['sigma2'] : Gaussian Process variance. - par['beta'] : Generalized least-squares regression weights for Universal Kriging or given beta0 for Ordinary Kriging. + par['beta'] : Generalized least-squares regression weights for + Universal Kriging or given beta0 for Ordinary + Kriging. par['gamma'] : Gaussian Process weights. par['C'] : Cholesky decomposition of the correlation matrix [R]. 
par['Ft'] : Solution of the linear equation system : [R] x Ft = F par['G'] : QR decomposition of the matrix Ft. - par['detR'] : Determinant of the correlation matrix raised at power 1/n_samples. - reduced_likelihood_function_value : double - The value of the reduced likelihood function associated to the given autocorrelation parameters theta. + par['detR'] : Determinant of the correlation matrix raised at power + 1/n_samples. """ - + if theta is None: # Use built-in autocorrelation parameters theta = self.theta - - reduced_likelihood_function_value = -np.inf - par = { } - + + # Initialize output + reduced_likelihood_function_value = - np.inf + par = {} + # Retrieve data n_samples = self.X_.shape[0] D = self.D ij = self.ij F = self.F - + if D is None: # Light storage mode (need to recompute D, ij and F) if self.X_.ndim > 1: n_features = self.X_.shape[1] else: n_features = 1 - mzmax = n_samples*(n_samples-1)/2 - ij = zeros([mzmax, n_features]) - D = zeros([mzmax, n_features]) - ll = array([-1]) + mzmax = n_samples * (n_samples - 1) / 2 + ij = np.zeros([mzmax, n_features]) + D = np.zeros([mzmax, n_features]) + ll = np.array([-1]) for k in range(n_samples-1): - ll = ll[-1]+1 + range(n_samples-k-1) - ij[ll,:] = np.concatenate([[repeat(k,n_samples-k-1,0)], [np.arange(k+1,n_samples).T]]).T - D[ll,:] = self.X_[k,:] - self.X_[(k+1):n_samples,:] - if (min(sum(abs(D),1)) == 0.) and (self.corr != corriid): - raise Exception, "Multiple X are not allowed" + ll = ll[-1] + 1 + range(n_samples - k - 1) + ij[ll] = \ + np.concatenate([[np.repeat(k, n_samples - k - 1, 0)], \ + [np.arange(k + 1, n_samples).T]]).T + D[ll] = self.X_[k] - self.X_[(k + 1):n_samples] + if min(sum(abs(D), 1)) == 0. and self.corr != corriid: + raise Exception("Multiple X are not allowed") F = self.regr(self.X_) - + # Set up R r = self.corr(theta, D) R = np.eye(n_samples) * (1. + self.nugget) - R[ij.astype(int)[:,0], ij.astype(int)[:,1]] = r - R[ij.astype(int)[:,1], ij.astype(int)[:,0]] = r - + R[ij.astype(int)[:, 0], ij.astype(int)[:, 1]] = r + R[ij.astype(int)[:, 1], ij.astype(int)[:, 0]] = r + # Cholesky decomposition of R try: C = linalg.cholesky(R, lower=True) except linalg.LinAlgError: return reduced_likelihood_function_value, par - + # Get generalized least squares solution Ft = linalg.solve(C, F) try: Q, G = linalg.qr(Ft, econ=True) except: - #/usr/lib/python2.6/dist-packages/scipy/linalg/decomp.py:1177: DeprecationWarning: qr econ argument will be removed after scipy 0.7. The economy transform will then be available through the mode='economic' argument. + #/usr/lib/python2.6/dist-packages/scipy/linalg/decomp.py:1177: + # DeprecationWarning: qr econ argument will be removed after scipy + # 0.7. The economy transform will then be available through the + # mode='economic' argument. Q, G = linalg.qr(Ft, mode='economic') pass - - rcondG = 1./(linalg.norm(G)*linalg.norm(linalg.inv(G))) + + rcondG = 1. / (linalg.norm(G) * linalg.norm(linalg.inv(G))) if rcondG < 1e-10: # Check F - condF = linalg.norm(F)*linalg.norm(linalg.inv(F)) + condF = linalg.norm(F) * linalg.norm(linalg.inv(F)) if condF > 1e15: - raise Exception, "F is too ill conditioned. Poor combination of regression model and observations." + raise Exception("F is too ill conditioned. 
Poor combination " \ + + "of regression model and observations.") else: - # Ft is too ill conditioned + # Ft is too ill conditioned, get out (try different theta) return reduced_likelihood_function_value, par - - Yt = linalg.solve(C,self.y_) + + Yt = linalg.solve(C, self.y_) if self.beta0 is None: # Universal Kriging - beta = linalg.solve(G, np.matrix(Q).T*np.matrix(Yt)) + beta = linalg.solve(G, np.matrix(Q).T * np.matrix(Yt)) else: # Ordinary Kriging beta = np.array(self.beta0) - + rho = np.matrix(Yt) - np.matrix(Ft) * np.matrix(beta) - normalized_sigma2 = (np.array(rho)**2.).sum(axis=0)/n_samples - # The determinant of R is equal to the squared product of the diagonal elements of its Cholesky decomposition C - detR = (np.array(np.diag(C))**(2./n_samples)).prod() + normalized_sigma2 = (np.array(rho) ** 2.).sum(axis=0) / n_samples + # The determinant of R is equal to the squared product of the diagonal + # elements of its Cholesky decomposition C + detR = (np.array(np.diag(C)) ** (2. / n_samples)).prod() + + # Compute/Organize output reduced_likelihood_function_value = - normalized_sigma2.sum() * detR - par = { 'sigma2':normalized_sigma2 * self.y_sc[1]**2, \ - 'beta':beta, \ - 'gamma':linalg.solve(C.T,rho).T, \ - 'C':C, \ - 'Ft':Ft, \ - 'G':G } - + par['sigma2'] = normalized_sigma2 * self.y_sc[1] ** 2 + par['beta'] = beta + par['gamma'] = linalg.solve(C.T, rho).T + par['C'] = C + par['Ft'] = Ft + par['G'] = G + return reduced_likelihood_function_value, par def arg_max_reduced_likelihood_function(self): """ - This function estimates the autocorrelation parameters theta as the maximizer of the reduced likelihood function. - (Minimization of the opposite reduced likelihood function is used for convenience) - + This function estimates the autocorrelation parameters theta as the + maximizer of the reduced likelihood function. + (Minimization of the opposite reduced likelihood function is used for + convenience) + Parameters ---------- self : All parameters are stored in the Gaussian Process model object. - + Returns ------- optimal_theta : array_like - The best set of autocorrelation parameters (the sought maximizer of the reduced likelihood function). + The best set of autocorrelation parameters (the sought maximizer of + the reduced likelihood function). + optimal_reduced_likelihood_function_value : double The optimal reduced likelihood function value. + optimal_par : dict The BLUP parameters associated to thetaOpt. """ - + + # Initialize output + best_optimal_theta = [] + best_optimal_rlf_value = [] + best_optimal_par = [] + if self.verbose: - print "The chosen optimizer is: "+str(self.optimizer) + print "The chosen optimizer is: " + str(self.optimizer) if self.random_start > 1: - print str(self.random_start)+" random starts are required." - + print str(self.random_start) + " random starts are required." + percent_completed = 0. - + if self.optimizer == 'fmin_cobyla': - - minus_reduced_likelihood_function = lambda log10t: - self.reduced_likelihood_function(theta = 10.**log10t)[0] - + + minus_reduced_likelihood_function = lambda log10t: \ + - self.reduced_likelihood_function(theta=10. 
** log10t)[0] + constraints = [] for i in range(self.theta0.size): - constraints.append(lambda log10t: log10t[i] - np.log10(self.thetaL[0,i])) - constraints.append(lambda log10t: np.log10(self.thetaU[0,i]) - log10t[i]) - + constraints.append(lambda log10t: \ + log10t[i] - np.log10(self.thetaL[0, i])) + constraints.append(lambda log10t: \ + np.log10(self.thetaU[0, i]) - log10t[i]) + for k in range(self.random_start): - + if k == 0: # Use specified starting point as first guess theta0 = self.theta0 else: - # Generate a random starting point log10-uniformly distributed between bounds - log10theta0 = np.log10(self.thetaL) + random.rand(self.theta0.size).reshape(self.theta0.shape) * np.log10(self.thetaU/self.thetaL) - theta0 = 10.**log10theta0 - - log10_optimal_theta = optimize.fmin_cobyla(minus_reduced_likelihood_function, np.log10(theta0), constraints, iprint=0) - optimal_theta = 10.**log10_optimal_theta - optimal_minus_reduced_likelihood_function_value, optimal_par = self.reduced_likelihood_function(theta = optimal_theta) - optimal_reduced_likelihood_function_value = - optimal_minus_reduced_likelihood_function_value - + # Generate a random starting point log10-uniformly + # distributed between bounds + log10theta0 = np.log10(self.thetaL) \ + + rand(self.theta0.size).reshape(self.theta0.shape) \ + * np.log10(self.thetaU / self.thetaL) + theta0 = 10. ** log10theta0 + + # Run Cobyla + log10_optimal_theta = \ + optimize.fmin_cobyla(minus_reduced_likelihood_function, \ + np.log10(theta0), constraints, iprint=0) + + optimal_theta = 10. ** log10_optimal_theta + optimal_minus_rlf_value, optimal_par = \ + self.reduced_likelihood_function(theta=optimal_theta) + optimal_rlf_value = - optimal_minus_rlf_value + + # Compare the new optimizer to the best previous one if k > 0: - if optimal_reduced_likelihood_function_value > best_optimal_reduced_likelihood_function_value: - best_optimal_reduced_likelihood_function_value = optimal_reduced_likelihood_function_value + if optimal_rlf_value > best_optimal_rlf_value: + best_optimal_rlf_value = optimal_rlf_value best_optimal_par = optimal_par best_optimal_theta = optimal_theta else: - best_optimal_reduced_likelihood_function_value = optimal_reduced_likelihood_function_value + best_optimal_rlf_value = optimal_rlf_value best_optimal_par = optimal_par best_optimal_theta = optimal_theta if self.verbose and self.random_start > 1: - if (20*k)/self.random_start > percent_completed: - percent_completed = (20*k)/self.random_start - print str(5*percent_completed)+"% completed" - + if (20 * k) / self.random_start > percent_completed: + percent_completed = (20 * k) / self.random_start + print str(5 * percent_completed) + "% completed" + else: - - raise NotImplementedError, "This optimizer ('%s') is not implemented yet. Please contribute!" % self.optimizer - - return best_optimal_theta, best_optimal_reduced_likelihood_function_value, best_optimal_par - + + raise NotImplementedError("This optimizer ('%s') is not " \ + + "implemented yet. Please contribute!" \ + % self.optimizer) + + return best_optimal_theta, best_optimal_rlf_value, best_optimal_par + def score(self, X_test, y_test): """ - This score function returns the deviations of the Gaussian Process model evaluated onto a test dataset. - + This score function returns the deviations of the Gaussian Process + model evaluated onto a test dataset. + Parameters ---------- X_test : array_like The feature test dataset with shape (n_tests, n_features). - + y_test : array_like The target test dataset (n_tests, ). 
- + Returns ------- score_values : array_like - The deviations between the prediction and the targets : y_pred - y_test. + The deviations between the prediction and the targets: + y_pred - y_test. """ - - return np.ravel(self.predict(X_test, eval_MSE=False)) - y_test \ No newline at end of file + + return np.ravel(self.predict(X_test, eval_MSE=False)) - y_test diff --git a/scikits/learn/gaussian_process/regression.py b/scikits/learn/gaussian_process/regression.py index 5bf8df7d32..58bf14948b 100644 --- a/scikits/learn/gaussian_process/regression.py +++ b/scikits/learn/gaussian_process/regression.py @@ -7,78 +7,88 @@ import numpy as np + ############################ # Defaut regression models # ############################ + def regpoly0(x): """ Zero order polynomial (constant, p = 1) regression model. - + regpoly0 : x --> f(x) = 1 - + Parameters ---------- x : array_like - An array with shape (n_eval, n_features) giving the locations x at which the regression model should be evaluated. - + An array with shape (n_eval, n_features) giving the locations x at + which the regression model should be evaluated. + Returns ------- f : array_like - An array with shape (n_eval, p) with the values of the regression model. + An array with shape (n_eval, p) with the values of the regression + model. """ - + x = np.asanyarray(x, dtype=np.float) n_eval = x.shape[0] - f = np.ones([n_eval,1]) - + f = np.ones([n_eval, 1]) + return f + def regpoly1(x): """ First order polynomial (hyperplane, p = n) regression model. - + regpoly1 : x --> f(x) = [ x_1, ..., x_n ].T - + Parameters ---------- x : array_like - An array with shape (n_eval, n_features) giving the locations x at which the regression model should be evaluated. - + An array with shape (n_eval, n_features) giving the locations x at + which the regression model should be evaluated. + Returns ------- f : array_like - An array with shape (n_eval, p) with the values of the regression model. + An array with shape (n_eval, p) with the values of the regression + model. """ - + x = np.asanyarray(x, dtype=np.float) n_eval = x.shape[0] - f = np.hstack([np.ones([n_eval,1]), x]) - + f = np.hstack([np.ones([n_eval, 1]), x]) + return f + def regpoly2(x): """ Second order polynomial (hyperparaboloid, p = n*(n-1)/2) regression model. - + regpoly2 : x --> f(x) = [ x_i*x_j, (i,j) = 1,...,n ].T i > j - + Parameters ---------- x : array_like - An array with shape (n_eval, n_features) giving the locations x at which the regression model should be evaluated. - + An array with shape (n_eval, n_features) giving the locations x at + which the regression model should be evaluated. + Returns ------- f : array_like - An array with shape (n_eval, p) with the values of the regression model. + An array with shape (n_eval, p) with the values of the regression + model. 
""" - + x = np.asanyarray(x, dtype=np.float) n_eval, n_features = x.shape - f = np.hstack([np.ones([n_eval,1]), x]) + f = np.hstack([np.ones([n_eval, 1]), x]) for k in range(n_features): - f = np.hstack([f, x[:,k,np.newaxis] * x[:,k:]]) - - return f \ No newline at end of file + f = np.hstack([f, x[:, k, np.newaxis] * x[:, k:]]) + + return f diff --git a/scikits/learn/tests/test_gaussian_process.py b/scikits/learn/tests/test_gaussian_process.py index bb5ec0c32c..6b0a3a79e7 100644 --- a/scikits/learn/tests/test_gaussian_process.py +++ b/scikits/learn/tests/test_gaussian_process.py @@ -6,43 +6,52 @@ import numpy as np from numpy.testing import assert_array_equal, assert_array_almost_equal, \ assert_almost_equal, assert_raises, assert_ -from .. import gaussian_process, datasets, cross_val, metrics +from .. import datasets, cross_val, metrics +from ..gaussian_process import GaussianProcess diabetes = datasets.load_diabetes() + def test_regression_1d_x_sinx(): """ MLE estimation of a Gaussian Process model with a squared exponential correlation model (correxp2). Check random start optimization. - + Test the interpolating property. """ - + f = lambda x: x * np.sin(x) X = np.array([1., 3., 5., 6., 7., 8.]) y = f(X) - gp = gaussian_process.GaussianProcess(theta0=1e-2, thetaL=1e-4, thetaU=1e-1, random_start=10, verbose=False).fit(X, y) + gp = GaussianProcess(theta0=1e-2, thetaL=1e-4, thetaU=1e-1, \ + random_start=10, verbose=False).fit(X, y) y_pred, MSE = gp.predict(X, eval_MSE=True) - - assert (np.all(np.abs((y_pred - y) / y) < 1e-6) and np.all(MSE < 1e-6)) -def test_regression_diabetes(n_jobs = 1, verbose = 0): + assert (np.all(np.abs((y_pred - y) / y) < 1e-6) and np.all(MSE < 1e-6)) + + +def test_regression_diabetes(n_jobs=1, verbose=0): """ - MLE estimation of a Gaussian Process model with an anisotropic squared exponential - correlation model. - + MLE estimation of a Gaussian Process model with an anisotropic squared + exponential correlation model. + Test the model using cross-validation module (quite time-consuming). - + Poor performance: Leave-one-out estimate of explained variance is about 0.5 at best... To be investigated! + TODO: find a dataset that would prove GP performance! """ - + X, y = diabetes['data'], diabetes['target'] - - gp = gaussian_process.GaussianProcess(corr=gaussian_process.correxp2, theta0=1e-4, nugget=1e-2, verbose=False).fit(X, y) - y_pred = cross_val.cross_val_score(gp, X, y=y, cv=cross_val.LeaveOneOut(y.size), n_jobs=n_jobs, verbose=verbose) + y + gp = GaussianProcess(theta0=1e-4, nugget=1e-2, verbose=False).fit(X, y) + + y_pred = cross_val.cross_val_score(gp, X, y=y, \ + cv=cross_val.LeaveOneOut(y.size), \ + n_jobs=n_jobs, verbose=verbose) \ + + y + Q2 = metrics.explained_variance(y_pred, y) - + assert Q2 > 0.45 -- GitLab