diff --git a/doc/modules/gpml.rst b/doc/modules/gpml.rst
new file mode 100644
index 0000000000000000000000000000000000000000..23e99f9f85ec30958fb01311858f255af360ca93
--- /dev/null
+++ b/doc/modules/gpml.rst
@@ -0,0 +1,213 @@
+=======================================
+Gaussian Processes for Machine Learning
+=======================================
+
+.. currentmodule:: scikits.learn.gpml
+
+**Gaussian Processes for Machine Learning (GPML)** is a supervised learning
+method used for *regression*. It can also be used for *probabilistic
+classification*, but this is only a post-processing of the *regression*
+exercise.
+
+The advantages of Gaussian Processes for Machine Learning are:
+
+    - The prediction interpolates the observations (at least for regular
+      correlation models).
+
+    - The prediction is probabilistic (Gaussian), so that one can compute
+      empirical confidence intervals and exceedance probabilities that might
+      be used to refit (online fitting, adaptive fitting) the prediction in
+      some region of interest.
+
+    - Versatile: different :ref:`linear regression models <regression_models>`
+      and :ref:`correlation models <correlation_models>` can be specified.
+      Common models are provided, but it is also possible to specify custom
+      models.
+
+The disadvantages of Gaussian Processes for Machine Learning include:
+
+    - It is not sparse: it uses the whole sample/feature information to
+      perform the prediction.
+
+    - It loses efficiency in high-dimensional spaces -- namely when the
+      number of features exceeds a few dozen -- where it may give poor
+      performance and become computationally inefficient.
+
+    - Classification is only a post-processing, meaning that one first needs
+      to solve a regression problem by providing the complete scalar
+      (float-precision) output :math:`y` of the computer experiment one
+      attempts to model.
+
+Thanks to the Gaussian property of the prediction, it has found a variety of
+applications, e.g. global optimization and probabilistic classification.
+
+
+Mathematical formulation
+========================
+
+The initial assumption
+~~~~~~~~~~~~~~~~~~~~~~
+
+Suppose one wants to model the output of a computer experiment, say a
+mathematical function:
+
+.. math::
+
+        g: & \mathbb{R}^{n_{\rm features}} \rightarrow \mathbb{R} \\
+           & X \mapsto y = g(X)
+
+GPML starts with the assumption that this function is a conditional sample
+path of a Gaussian process :math:`G`, which is additionally assumed to read
+as follows:
+
+.. math::
+
+        G(X) = f(X)^T \beta + Z(X)
+
+where :math:`f(X)^T \beta` is a linear regression model and :math:`Z(X)` is a
+zero-mean Gaussian process with a fully stationary covariance function:
+
+.. math::
+
+        C(X, X') = \sigma^2 R(|X - X'|)
+
+:math:`\sigma^2` being its variance and :math:`R` being the correlation
+function, which solely depends on the absolute relative distance between
+samples -- possibly featurewise.
+
+From this basic formulation, note that GPML is nothing but an extension of a
+basic least squares linear regression problem:
+
+.. math::
+
+        g(X) \approx f(X)^T \beta
+
+except that we additionally assume some spatial coherence (correlation)
+between the samples, dictated by the correlation function. Indeed, ordinary
+least squares assumes that the correlation model :math:`R(|X - X'|)` is one
+when :math:`X = X'` and zero otherwise: a *dirac* correlation model --
+sometimes referred to as a *nugget* correlation model in the kriging
+literature.
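+
+As an illustration of this connection, the snippet below -- a minimal sketch,
+assuming the module is importable as ``scikits.learn.gpml`` -- fits a
+GaussianProcessModel with the spatial independence correlation model
+``corriid`` and a first-order regression basis, so that the resulting
+prediction reduces to the ordinary least squares trend::
+
+    import numpy as np
+    from scikits.learn.gpml import GaussianProcessModel, regpoly1, corriid
+
+    X = np.array([1., 3., 5., 6., 7., 8.])  # design sites
+    y = X * np.sin(X)                       # observations
+
+    # With the dirac (pure nugget) correlation model, the BLUP reduces to
+    # the underlying regression trend f(x)^T beta, i.e. a least squares fit.
+    ols_like = GaussianProcessModel(regr=regpoly1, corr=corriid, theta0=1e-1)
+    ols_like.fit(X, y)
+    y_trend = ols_like.predict(np.linspace(0., 10., 50))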
+
+
+The best linear unbiased prediction (BLUP)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+We now derive the *best linear unbiased prediction* of the sample path
+:math:`g` conditioned on the observations:
+
+.. math::
+
+        \hat{G}(X) \sim G(X | y_1 = g(X_1), ..., y_{n_{\rm samples}} = g(X_{n_{\rm samples}}))
+
+It is derived from its *required properties*:
+
+- It is linear (a linear combination of the observations)
+
+.. math::
+
+        \hat{G}(X) \equiv a(X)^T y
+
+- It is unbiased
+
+.. math::
+
+        \mathbb{E}[G(X) - \hat{G}(X)] = 0
+
+- It is the best (in the Mean Squared Error sense)
+
+.. math::
+
+        \hat{G}(X)^* = \arg \min\limits_{\hat{G}(X)} \;
+        \mathbb{E}[(G(X) - \hat{G}(X))^2]
+
+The optimal weight vector :math:`a(X)` is therefore the solution of the
+following equality-constrained optimization problem:
+
+.. math::
+
+        a(X)^* = \arg \min\limits_{a(X)} & \; \mathbb{E}[(G(X) - a(X)^T y)^2] \\
+                         {\rm s. t.} & \; \mathbb{E}[G(X) - a(X)^T y] = 0
+
+Rewriting this constrained optimization problem as a Lagrangian and working
+out the first-order optimality conditions, one ends up with a closed-form
+expression for the sought predictor -- see the references for the complete
+proof.
+
+In the end, the BLUP is shown to be a Gaussian random variate whose mean and
+variance expressions are given in the references.
+
+
+
+The empirical best linear unbiased predictor (EBLUP)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Until now, both the autocorrelation and regression models were assumed given.
+In practice, however, they are never known in advance, so that one has to
+make (motivated) empirical choices for these models (see
+:ref:`correlation_models` and :ref:`regression_models`).
+
+Provided these choices are made, one should estimate the remaining unknown
+parameters involved in the BLUP. To do so, one uses the set of provided
+observations in conjunction with some inference technique. The present
+implementation, which is based on the DACE Matlab toolbox, uses the
+*maximum likelihood estimation* technique.
+
+For a more comprehensive description of the theoretical aspects of Gaussian
+Processes for Machine Learning, please refer to the references below:
+
+.. topic:: References:
+
+    * *"DACE, A Matlab Kriging Toolbox"*
+      S Lophaven, HB Nielsen, J Sondergaard
+      2002,
+      <http://www2.imm.dtu.dk/~hbn/dace/>
+
+    * *"Gaussian Processes for Machine Learning"*
+      CE Rasmussen, CKI Williams
+      MIT Press, 2006 (Ed. T Dietterich)
+      <http://www.gaussianprocess.org/gpml/chapters/RW.pdf>
+
+    * *"The design and analysis of computer experiments"*
+      TJ Santner, BJ Williams, W Notz
+      Springer, 2003
+      <http://www.stat.osu.edu/~comp_exp/book.html>
+
+.. _correlation_models:
+
+Correlation Models
+==================
+
+Common correlation models match some well-known SVM kernels because they are
+mostly built on equivalent assumptions. They must fulfill Mercer's
+conditions. Note, however, that the choice of the correlation model should be
+made in agreement with the known properties of the original experiment from
+which the observations come:
+
+* If the original experiment is known to be infinitely differentiable
+  (smooth), then one should use the *squared-exponential correlation model*.
+* If it is not, then one should rather use the *exponential correlation
+  model*.
+* Note also that there exists a correlation model that takes the degree of
+  differentiability as input -- the Matern correlation model -- but it is not
+  implemented here (the chosen correlation model is simply passed to the
+  constructor, as shown below).
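+
+A minimal sketch of selecting the correlation model, assuming the module is
+importable as ``scikits.learn.gpml``::
+
+    from scikits.learn.gpml import GaussianProcessModel, correxp1, correxp2
+
+    # Smooth (infinitely differentiable) prior on the underlying function
+    smooth_model = GaussianProcessModel(corr=correxp2, theta0=1e-1)
+
+    # Rougher, merely continuous prior (Ornstein-Uhlenbeck-like process)
+    rough_model = GaussianProcessModel(corr=correxp1, theta0=1e-1)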
+
+For a more detailed discussion on the selection of the appropriate
+correlation model, see the book by Rasmussen & Williams in the references.
+
+.. _regression_models:
+
+Regression Models
+=================
+
+Common linear regression models involve zero- (constant), first- and
+second-order polynomials. One may also specify one's own in the form of a
+Python function that takes the features X as input and returns a vector
+containing the values of the functional set. The only constraint is that the
+number of functions must not exceed the number of available observations, so
+that the underlying regression problem is not *under-determined*.
+
+
+An introductory example
+=======================
+
+Say we want to build a surrogate of the function :math:`g(x) = x \sin(x)`. To
+do so, the function is evaluated on a design of experiments. Then we define a
+GaussianProcessModel, whose regression and correlation models may be
+specified using additional kwargs, and ask for the model to be fitted to the
+data. Depending on the number of parameters provided at instantiation, the
+fitting procedure may resort to maximum likelihood estimation of the
+parameters, or alternatively use the given parameters. The full script is
+provided in examples/gpml/plot_gpml_regression.py.
+
+
+
+Implementation details
+======================
+
+The present implementation is based on a transliteration of the DACE Matlab
+toolbox.
+
+.. topic:: References:
+
+    * *"DACE, A Matlab Kriging Toolbox"*
+      S Lophaven, HB Nielsen, J Sondergaard
+      2002,
+      <http://www2.imm.dtu.dk/~hbn/dace/>
diff --git a/examples/gpml/plot_gpml_probabilistic_classification_after_regression.py b/examples/gpml/plot_gpml_probabilistic_classification_after_regression.py
new file mode 100644
index 0000000000000000000000000000000000000000..553ca21f97851e528730d7285954daf18a6ed9f9
--- /dev/null
+++ b/examples/gpml/plot_gpml_probabilistic_classification_after_regression.py
@@ -0,0 +1,113 @@
+#!/usr/bin/python
+# -*- coding: utf-8 -*-
+
+"""
+===============================================
+Gaussian Processes for Machine Learning example
+===============================================
+
+A two-dimensional regression exercise with a post-processing
+allowing for probabilistic classification thanks to the
+Gaussian property of the prediction.
+"""
+# Author: Vincent Dubourg <vincent.dubourg@gmail.com>
+# License: BSD style
+
+import numpy as np
+from scipy import stats
+from scikits.learn.gpml import GaussianProcessModel
+from matplotlib import pyplot as pl
+from matplotlib import cm
+from mpl_toolkits.mplot3d import Axes3D
+
+
+# Helper that returns a fixed label string whatever the contour level
+# (used to place custom labels with pl.clabel)
+class FormatFaker(object):
+    def __init__(self, str): self.str = str
+    def __mod__(self, stuff): return self.str
+
+# Standard normal distribution functions
+Grv = stats.distributions.norm()
+phi = lambda x: Grv.pdf(x)
+PHI = lambda x: Grv.cdf(x)
+PHIinv = lambda x: Grv.ppf(x)
+
+# A few constants
+beta0 = 8
+b, kappa, e = 5, .5, .1
+
+# The function to predict (classification will then consist in predicting
+# whether g(x) <= 0 or not)
+g = lambda x: b - x[:, 1] - kappa * (x[:, 0] - e) ** 2.
+
+# Design of experiments
+X = np.array([[-4.61611719, -6.00099547],
+              [ 4.10469096,  5.32782448],
+              [ 0.00000000, -0.50000000],
+              [-6.17289014, -4.6984743 ],
+              [ 1.3109306 , -6.93271427],
+              [-5.03823144,  3.10584743],
+              [-2.87600388,  6.74310541],
+              [ 5.21301203,  4.26386883]])
+
+# Observations
+Y = g(X)
+
+# Instantiate and fit a Gaussian Process Model
+aGaussianProcessModel = GaussianProcessModel(theta0=5e-1)
+
+# Don't perform MLE or you'll get a perfect prediction for this simple example!
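+# Since only theta0 is given (no thetaL/thetaU bounds), the autocorrelation
+# parameters are kept fixed at theta0 and fit() skips Maximum Likelihood
+# Estimation altogether.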
+aGaussianProcessModel.fit(X, Y) + +# Evaluate real function, the prediction and its MSE on a grid +res = 50 +x1, x2 = np.meshgrid(np.linspace(-beta0, beta0, res), np.linspace(-beta0, beta0, res)) +xx = np.vstack([x1.reshape(x1.size), x2.reshape(x2.size)]).T + +YY = g(xx) +yy, MSE = aGaussianProcessModel.predict(xx, eval_MSE=True) +sigma = np.sqrt(MSE) +yy = yy.reshape((res,res)) +YY = YY.reshape((res,res)) +sigma = sigma.reshape((res,res)) +k = PHIinv(.975) + +# Plot the probabilistic classification iso-values using the Gaussian property of the prediction +fig = pl.figure(1) +ax = fig.add_subplot(111) +ax.axes.set_aspect('equal') +cax = pl.imshow(np.flipud(PHI(-yy/sigma)), cmap=cm.gray_r, alpha=.8, extent=(-beta0,beta0,-beta0,beta0)) +norm = pl.matplotlib.colors.Normalize(vmin=0., vmax=0.9) +cb = pl.colorbar(cax, ticks=[0.,0.2,0.4,0.6,0.8,1.], norm=norm) +cb.set_label('${\\rm \mathbb{P}}\left[\widehat{G}(\mathbf{x}) \leq 0\\right]$') +pl.plot(X[Y <= 0, 0], X[Y <= 0, 1], 'r.', markersize=12) +pl.plot(X[Y > 0, 0], X[Y > 0, 1], 'b.', markersize=12) +pl.xticks([]) +pl.yticks([]) +ax.set_xticklabels([]) +ax.set_yticklabels([]) +pl.xlabel('$x_1$') +pl.ylabel('$x_2$') +print u'Click to place label: $g(\mathbf{x})=0$' +cs = pl.contour(x1, x2, YY, [0.], colors='k', linestyles='dashdot') +pl.clabel(cs,fmt=FormatFaker(u'$g(\mathbf{x})=0$'),fontsize=11,manual=True) +print u'Click to place label: ${\\rm \mathbb{P}}\left[{\widehat{G}}(\mathbf{x}) \leq 0\\right]= 2.5\%$' +cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.025], colors='b', linestyles='solid') +pl.clabel(cs,fmt=FormatFaker(u'${\\rm \mathbb{P}}\left[{\widehat{G}}(\mathbf{x}) \leq 0\\right]= 2.5\%$'),fontsize=11,manual=True) +print u'Click to place label: $\mu_{\widehat{G}}(\mathbf{x})=0$' +cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.5], colors='k', linestyles='dashed') +pl.clabel(cs,fmt=FormatFaker(u'$\mu_{\widehat{G}}(\mathbf{x})=0$'),fontsize=11,manual=True) +print u'Click to place label: ${\\rm \mathbb{P}}\left[{\widehat{G}}(\mathbf{x}) \leq 0\\right]= 97.5\%$' +cs = pl.contour(x1, x2, PHI(-yy/sigma), [0.975], colors='r', linestyles='solid') +pl.clabel(cs,fmt=FormatFaker(u'${\\rm \mathbb{P}}\left[{\widehat{G}}(\mathbf{x}) \leq 0\\right]= 97.5\%$'),fontsize=11,manual=True) + +# Plot the prediction and the bounds of the 95% confidence interval +fig = pl.figure(2) +ax = Axes3D(fig) +ax.axes.set_aspect('equal') +ax.plot_surface(x1, x2, yy, linewidth = 0.5, rstride = 1, cstride = 1, color = 'k', alpha = .8) +ax.plot_surface(x1, x2, yy - k*sigma, linewidth = 0.5, rstride = 1, cstride = 1, color = 'b', alpha = .8) +ax.plot_surface(x1, x2, yy + k*sigma, linewidth = 0.5, rstride = 1, cstride = 1, color = 'r', alpha = .8) +ax.scatter3D(X[Y <= 0, 0], X[Y <= 0, 1], Y[Y <= 0], 'r.', s = 20) +ax.scatter3D(X[Y > 0, 0], X[Y > 0, 1], Y[Y > 0], 'b.', s = 20) +ax.set_xlabel(u'$x_1$') +ax.set_ylabel(u'$x_2$') +ax.set_zlabel(u'$\widehat{G}(x_1,\,x_2)$') + +pl.show() \ No newline at end of file diff --git a/examples/gpml/plot_gpml_regression.py b/examples/gpml/plot_gpml_regression.py new file mode 100644 index 0000000000000000000000000000000000000000..a67f35a41fe6677555f2a39fdf1523356bad8bec --- /dev/null +++ b/examples/gpml/plot_gpml_regression.py @@ -0,0 +1,78 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +""" +=============================================== +Gaussian Processes for Machine Learning example +=============================================== + +A simple one-dimensional regression exercise with +different correlation models and maximum likelihood 
+estimation of the Gaussian Process Model parameters. +""" +# Author: Vincent Dubourg <vincent.dubourg@gmail.com +# License: BSD style + +import numpy as np +from scipy import stats +from scikits.learn.gpml import GaussianProcessModel, correxp1, correxp2, corrlin, corrcubic +from matplotlib import pyplot as pl + +# The function to predict +f = lambda x: x*np.sin(x) + +# The design of experiments +X = np.array([1., 3., 5., 6., 7., 8.]) + +# Observations +Y = f(X) + +# Mesh the input space for evaluations of the real function, the prediction and its MSE +x = np.linspace(0,10,1000) + +# Loop correlation models +corrs = (correxp1, correxp2, corrlin, corrcubic) +colors = ('b', 'g', 'y', 'm') +for k in range(len(corrs)): + + # Instanciate a Gaussian Process Model with the k-th correlation model + if corrs[k] == corrlin or corrs[k] == corrcubic: + aGaussianProcessModel = GaussianProcessModel(corr=corrs[k], theta0=1e-2, thetaL=1e-4, thetaU=1e-1, random_start=100) + else: + aGaussianProcessModel = GaussianProcessModel(corr=corrs[k], theta0=1e-2, thetaL=1e-4, thetaU=1e+1, random_start=100) + + # Fit to data using Maximum Likelihood Estimation of the parameters + aGaussianProcessModel.fit(X, Y) + + # Make the prediction on the meshed x-axis (ask for MSE as well) + y, MSE = aGaussianProcessModel.predict(x, eval_MSE=True) + sigma = np.sqrt(MSE) + + # Compute the score function on a grid of the autocorrelation parameter space + theta_values = np.logspace(np.log10(aGaussianProcessModel.thetaL), np.log10(aGaussianProcessModel.thetaU), 100) + score_values = [] + for t in theta_values: + score_values.append(aGaussianProcessModel.score(theta=t)[0]) + + fig = pl.figure() + + # Plot the function, the prediction and the 95% confidence interval based on the MSE + ax = fig.add_subplot(211) + pl.plot(x, f(x), 'r:', label=u'$f(x) = x\,\sin(x)$') + pl.plot(X, Y, 'r.', markersize=10, label=u'Observations') + pl.plot(x, y, colors[k]+'-', label=u'Prediction (%s)' % corrs[k].__name__) + pl.fill(np.concatenate([x, x[::-1]]), np.concatenate([y - 1.9600 * sigma, (y + 1.9600 * sigma)[::-1]]), alpha=.5, fc=colors[k], ec='None', label='95% confidence interval') + pl.xlabel('$x$') + pl.ylabel('$f(x)$') + pl.ylim(-10, 20) + pl.legend(loc='upper left') + + # Plot the score function + ax = fig.add_subplot(212) + pl.plot(theta_values, score_values, colors[k]+'-') + pl.xlabel(u'$\\theta$') + pl.ylabel(u'Score') + pl.xscale('log') + pl.xlim(aGaussianProcessModel.thetaL[0], aGaussianProcessModel.thetaU[0]) + +pl.show() diff --git a/scikits/learn/gpml/__init__.py b/scikits/learn/gpml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bcba115a3744cd3478e15a22bfdd14285266bfe9 --- /dev/null +++ b/scikits/learn/gpml/__init__.py @@ -0,0 +1,14 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +""" + This module contains a contribution to the scikit-learn project that + implements Gaussian Process based prediction (also known as Kriging). + + The present implementation is based on a transliteration of the DACE + Matlab toolbox <http://www2.imm.dtu.dk/~hbn/dace/>. 
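+
+    The main entry point is the GaussianProcessModel class. The built-in
+    correlation models (correxp1, correxp2, correxpg, corriid, corrcubic,
+    corrlin) and regression models (regpoly0, regpoly1, regpoly2) are
+    re-exported at the package level as well.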
+""" + +from .gaussian_process_model import GaussianProcessModel +from .correlation_models import * +from .regression_models import * \ No newline at end of file diff --git a/scikits/learn/gpml/correlation_models.py b/scikits/learn/gpml/correlation_models.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d8235f22bf695cdf8a28e6252b609f19afa525 --- /dev/null +++ b/scikits/learn/gpml/correlation_models.py @@ -0,0 +1,239 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +################ +# Dependencies # +################ + +import numpy as np + +############################# +# Defaut correlation models # +############################# + +def correxp1(theta,d): + """ + Exponential autocorrelation model. + (Ornstein-Uhlenbeck stochastic process) + n + correxp1 : theta, dx --> r(theta, dx) = exp( sum - theta_i * |dx_i| ) + i = 1 + Parameters + ---------- + theta : array_like + An array with shape 1 (isotropic) or n (anisotropic) giving the autocorrelation parameter(s). + dx : array_like + An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. + + Returns + ------- + r : array_like + An array with shape (n_eval, ) containing the values of the autocorrelation model. + """ + + theta = np.asanyarray(theta, dtype=np.float) + d = np.asanyarray(d, dtype=np.float) + + if d.ndim > 1: + n_features = d.shape[1] + else: + n_features = 1 + if theta.size == 1: + theta = np.repeat(theta, n_features) + elif theta.size != n_features: + raise ValueError, "Length of theta must be 1 or "+str(n_features) + + td = - theta.reshape(1, n_features) * abs(d) + r = np.exp(np.sum(td,1)) + + return r + +def correxp2(theta,d): + """ + Squared exponential correlation model (Radial Basis Function). + (Infinitely differentiable stochastic process, very smooth) + n + correxp2 : theta, dx --> r(theta, dx) = exp( sum - theta_i * (dx_i)^2 ) + i = 1 + Parameters + ---------- + theta : array_like + An array with shape 1 (isotropic) or n (anisotropic) giving the autocorrelation parameter(s). + dx : array_like + An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. + + Returns + ------- + r : array_like + An array with shape (n_eval, ) containing the values of the autocorrelation model. + """ + + theta = np.asanyarray(theta, dtype=np.float) + d = np.asanyarray(d, dtype=np.float) + + if d.ndim > 1: + n_features = d.shape[1] + else: + n_features = 1 + if theta.size == 1: + theta = np.repeat(theta, n_features) + elif theta.size != n_features: + raise Exception, "Length of theta must be 1 or "+str(n_features) + + td = - theta.reshape(1, n_features) * d**2 + r = np.exp(np.sum(td,1)) + + return r + +def correxpg(theta,d): + """ + Generalized exponential correlation model. + (Useful when one does not know the smoothness of the function to be predicted.) + n + correxpg : theta, dx --> r(theta, dx) = exp( sum - theta_i * |dx_i|^p ) + i = 1 + Parameters + ---------- + theta : array_like + An array with shape 1+1 (isotropic) or n+1 (anisotropic) giving the autocorrelation parameter(s) (theta, p). + dx : array_like + An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. + + Returns + ------- + r : array_like + An array with shape (n_eval, ) with the values of the autocorrelation model. 
+ """ + + theta = np.asanyarray(theta, dtype=np.float) + d = np.asanyarray(d, dtype=np.float) + + if d.ndim > 1: + n_features = d.shape[1] + else: + n_features = 1 + lth = theta.size + if n_features > 1 and lth == 2: + theta = np.hstack([np.repeat(theta[0], n_features), theta[1]]) + elif lth != n_features+1: + raise Exception, "Length of theta must be 2 or "+str(n_features+1) + else: + theta = theta.reshape(1, lth) + + td = - theta[:,0:-1].reshape(1, n_features) * abs(d)**theta[:,-1] + r = np.exp(np.sum(td,1)) + + return r + +def corriid(theta,d): + """ + Spatial independence correlation model (pure nugget). + (Useful when one wants to solve an ordinary least squares problem!) + n + corriid : theta, dx --> r(theta, dx) = 1 if sum |dx_i| == 0 and 0 otherwise + i = 1 + Parameters + ---------- + theta : array_like + None. + dx : array_like + An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. + + Returns + ------- + r : array_like + An array with shape (n_eval, ) with the values of the autocorrelation model. + """ + + theta = np.asanyarray(theta, dtype=np.float) + d = np.asanyarray(d, dtype=np.float) + + n_eval = d.shape[0] + r = np.zeros(n_eval) + # The ones on the diagonal of the correlation matrix are enforced within + # the KrigingModel instanciation to allow multiple design sites in this + # ordinary least squares context. + + return r + +def corrcubic(theta,d): + """ + Cubic correlation model. + n + corrcubic : theta, dx --> r(theta, dx) = prod max(0, 1 - 3(theta_j*d_ij)^2 + 2(theta_j*d_ij)^3) , i = 1,...,m + j = 1 + Parameters + ---------- + theta : array_like + An array with shape 1 (isotropic) or n (anisotropic) giving the autocorrelation parameter(s). + dx : array_like + An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. + + Returns + ------- + r : array_like + An array with shape (n_eval, ) with the values of the autocorrelation model. + """ + + theta = np.asanyarray(theta, dtype=np.float) + d = np.asanyarray(d, dtype=np.float) + + if d.ndim > 1: + n_features = d.shape[1] + else: + n_features = 1 + lth = theta.size + if lth == 1: + theta = np.repeat(theta, n_features)[np.newaxis,:] + elif lth != n_features: + raise Exception, "Length of theta must be 1 or "+str(n_features) + else: + theta = theta.reshape(1, n_features) + + td = abs(d) * theta + td[td > 1.] = 1. + ss = 1. - td**2. * (3. - 2.*td) + r = np.prod(ss,1) + + return r + +def corrlin(theta,d): + """ + Linear correlation model. + n + corrlin : theta, dx --> r(theta, dx) = prod max(0, 1 - theta_j*d_ij) , i = 1,...,m + j = 1 + Parameters + ---------- + theta : array_like + An array with shape 1 (isotropic) or n (anisotropic) giving the autocorrelation parameter(s). + dx : array_like + An array with shape (n_eval, n_features) giving the componentwise distances between locations x and x' at which the correlation model should be evaluated. + + Returns + ------- + r : array_like + An array with shape (n_eval, ) with the values of the autocorrelation model. 
+ """ + + theta = np.asanyarray(theta, dtype=np.float) + d = np.asanyarray(d, dtype=np.float) + + if d.ndim > 1: + n_features = d.shape[1] + else: + n_features = 1 + lth = theta.size + if lth == 1: + theta = np.repeat(theta, n_features)[np.newaxis,:] + elif lth != n_features: + raise Exception, "Length of theta must be 1 or "+str(n_features) + else: + theta = theta.reshape(1, n_features) + + td = abs(d) * theta + td[td > 1.] = 1. + ss = 1. - td + r = np.prod(ss,1) + + return r \ No newline at end of file diff --git a/scikits/learn/gpml/gaussian_process_model.py b/scikits/learn/gpml/gaussian_process_model.py new file mode 100644 index 0000000000000000000000000000000000000000..368dbbb40d0d8fdc9a3d3fec084eba02c1b6fe69 --- /dev/null +++ b/scikits/learn/gpml/gaussian_process_model.py @@ -0,0 +1,653 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +################ +# Dependencies # +################ + +import numpy as np +from scipy import linalg, optimize, random +from ..base import BaseEstimator +from .regression_models import regpoly0 +from .correlation_models import correxp2 +machine_epsilon = np.finfo(np.double).eps + +def col_sum(x): + """ + Performs columnwise sums of elements in x depending on its shape. + + Parameters + ---------- + x : array_like + An array of size (p, q). + + Returns + ------- + s : array_like + An array of size (q, ) which contains the columnwise sums of the elements in x. + """ + + x = np.asanyarray(x, dtype=np.float) + + if x.ndim > 1: + s = x.sum(axis=0) + else: + s = x + + return s + +#################################### +# The Gaussian Process Model class # +#################################### + +class GaussianProcessModel(BaseEstimator): + """ + A class that implements scalar Gaussian Process based prediction (also known as Kriging). + + Example + ------- + import numpy as np + from scikits.learn.gpml import GaussianProcessModel + import pylab as pl + + f = lambda x: x*np.sin(x) + X = np.array([1., 3., 5., 6., 7., 8.]) + Y = f(X) + aGaussianProcessModel = GaussianProcessModel(regr=regpoly0, corr=correxp2, theta0=1e-1, thetaL=1e-3, thetaU=1e0, random_start=100) + aGaussianProcessModel.fit(X, Y) + + pl.figure(1) + x = np.linspace(0,10,1000) + y, MSE = aGaussianProcessModel.predict(x, eval_MSE=True) + sigma = np.sqrt(MSE) + pl.plot(x, f(x), 'r:', label=u'$f(x) = x\,\sin(x)$') + pl.plot(X, Y, 'r.', markersize=10, label=u'Observations') + pl.plot(x, y, 'k-', label=u'$\widehat{f}(x) = {\\rm BLUP}(x)$') + pl.fill(np.concatenate([x, x[::-1]]), np.concatenate([y - 1.9600 * sigma, (y + 1.9600 * sigma)[::-1]]), alpha=.5, fc='b', ec='None', label=u'95\% confidence interval') + pl.xlabel('$x$') + pl.ylabel('$f(x)$') + pl.legend(loc='upper left') + + pl.figure(2) + theta_values = np.logspace(np.log10(aGaussianProcessModel.thetaL),np.log10(aGaussianProcessModel.thetaU),100) + score_values = [] + for t in theta_values: + score_values.append(aGaussianProcessModel.score(theta=t)[0]) + pl.plot(theta_values, score_values) + pl.xlabel(u'$\\theta$') + pl.ylabel(u'Score') + pl.xscale('log') + + pl.show() + + Methods + ------- + fit(X, y) : self + Fit the model. + + predict(X) : array + Predict using the model. + + Todo + ---- + o Add the 'sparse' storage mode for which the correlation matrix is stored in its sparse eigen decomposition format instead of its full Cholesky decomposition. + + Implementation details + ---------------------- + The presentation implementation is based on a translation of the DACE Matlab toolbox. + + See references: + [1] H.B. Nielsen, S.N. Lophaven, H. B. 
Nielsen and J. Sondergaard (2002). DACE - A MATLAB Kriging Toolbox. + http://www2.imm.dtu.dk/~hbn/dace/dace.pdf + """ + + def __init__(self, regr = regpoly0, corr = correxp2, beta0 = None, storage_mode = 'full', verbose = True, theta0 = 1e-1, thetaL = None, thetaU = None, optimizer = 'fmin_cobyla', random_start = 1, normalize = True, nugget = 10.*machine_epsilon): + """ + The Gaussian Process Model constructor. + + Parameters + ---------- + regr : lambda function, optional + A regression function returning an array of outputs of the linear regression functional basis. + (The number of observations m should be greater than the size p of this basis) + Default assumes a simple constant regression trend (see regpoly0). + + corr : lambda function, optional + A stationary autocorrelation function returning the autocorrelation between two points x and x'. + Default assumes a squared-exponential autocorrelation model (see correxp2). + + beta0 : double array_like, optional + Default assumes Universal Kriging (UK) so that the vector beta of regression weights is estimated + by Maximum Likelihood. Specifying beta0 overrides estimation and performs Ordinary Kriging (OK). + + storage_mode : string, optional + A string specifying whether the Cholesky decomposition of the correlation matrix should be + stored in the class (storage_mode = 'full') or not (storage_mode = 'light'). + Default assumes storage_mode = 'full', so that the Cholesky decomposition of the correlation matrix is stored. + This might be a useful parameter when one is not interested in the MSE and only plan to estimate the BLUP, + for which the correlation matrix is not required. + + verbose : boolean, optional + A boolean specifying the verbose level. + Default is verbose = True. + + theta0 : double array_like, optional + An array with shape (n_features, ) or (1, ). + The parameters in the autocorrelation model. + If thetaL and thetaU are also specified, theta0 is considered as the starting point + for the Maximum Likelihood Estimation of the best set of parameters. + Default assumes isotropic autocorrelation model with theta0 = 1e-1. + + thetaL : double array_like, optional + An array with shape matching theta0's. + Lower bound on the autocorrelation parameters for Maximum Likelihood Estimation. + Default is None, skips Maximum Likelihood Estimation and uses theta0. + + thetaU : double array_like, optional + An array with shape matching theta0's. + Upper bound on the autocorrelation parameters for Maximum Likelihood Estimation. + Default is None and skips Maximum Likelihood Estimation and uses theta0. + + normalize : boolean, optional + Design sites X and observations y are centered and reduced wrt means and standard deviations + estimated from the n_samples observations provided. + Default is true so that data is normalized to ease MLE. + + nugget : double, optional + Introduce a nugget effect to allow smooth predictions from noisy data. + Default assumes a nugget close to machine precision for the sake of robustness (nugget = 10.*machine_epsilon). + + optimizer : string, optional + A string specifying the optimization algorithm to be used ('fmin_cobyla' is the sole algorithm implemented yet). + Default uses 'fmin_cobyla' algorithm from scipy.optimize. + + random_start : int, optional + The number of times the Maximum Likelihood Estimation should be performed from a random starting point. 
+ The first MLE always uses the specified starting point (theta0), + the next starting points are picked at random according to an exponential distribution (log-uniform in [thetaL, thetaU]). + Default does not use random starting point (random_start = 1). + + Returns + ------- + aGaussianProcessModel : self + A Gaussian Process Model object awaiting data to be fitted to. + """ + + self.regr = regr + self.corr = corr + self.beta0 = beta0 + + # Check storage mode + if storage_mode != 'full' and storage_mode != 'light': + if storage_mode == 'sparse': + raise NotImplementedError, "The 'sparse' storage mode is not supported yet. Please contribute!" + else: + raise ValueError, "Storage mode should either be 'full' or 'light'. Unknown storage mode: "+str(storage_mode) + else: + self.storage_mode = storage_mode + + self.verbose = verbose + + # Check correlation parameters + self.theta0 = np.atleast_2d(np.asanyarray(theta0, dtype=np.float)) + self.thetaL, self.thetaU = thetaL, thetaU + lth = self.theta0.size + + if self.thetaL is not None and self.thetaU is not None: + # Parameters optimization case + self.thetaL = np.atleast_2d(np.asanyarray(thetaL, dtype=np.float)) + self.thetaU = np.atleast_2d(np.asanyarray(thetaU, dtype=np.float)) + + if self.thetaL.size != lth or self.thetaU.size != lth: + raise ValueError, "theta0, thetaL and thetaU must have the same length" + if np.any(self.thetaL <= 0) or np.any(self.thetaU < self.thetaL): + raise ValueError, "The bounds must satisfy O < thetaL <= thetaU" + + elif self.thetaL is None and self.thetaU is None: + # Given parameters case + if np.any(self.theta0 <= 0): + raise ValueError, "theta0 must be strictly positive" + + elif self.thetaL is None or self.thetaU is None: + # Error + raise Exception, "thetaL and thetaU should either be both or neither specified" + + # Store other parameters + self.normalize = normalize + self.nugget = nugget + self.optimizer = optimizer + self.random_start = int(random_start) + + def fit(self, X, y): + """ + The Gaussian Process Model fitting method. + + Parameters + ---------- + X : double array_like + An array with shape (n_samples, n_features) with the design sites at which observations were made. + + y : double array_like + An array with shape (n_features, ) with the observations of the scalar output to be predicted. + + Returns + ------- + aGaussianProcessModel : self + A fitted Gaussian Process Model object awaiting data to perform predictions. + """ + + # Force data to numpy.array type (from coding guidelines) + X = np.asanyarray(X, dtype=np.float) + y = np.asanyarray(y, dtype=np.float) + + # Check design sites & observations + n_samples_X = X.shape[0] + if X.ndim > 1: + n_features = X.shape[1] + else: + n_features = 1 + X = X.reshape(n_samples_X, n_features) + + n_samples_y = y.shape[0] + if y.ndim > 1: + raise NotImplementedError, "y has more than one dimension. This is not supported yet (scalar output prediction only). Please contribute!" + y = y.reshape(n_samples_y, 1) + + if n_samples_X != n_samples_y: + raise Exception, "X and y must have the same number of rows!" + else: + n_samples = n_samples_X + + # Normalize data or don't + if self.normalize: + mean_X = np.mean(X, axis=0) + std_X = np.sqrt(1./(n_samples-1.)*np.sum((X - np.mean(X,0))**2.,0)) #np.std(y, axis=0) + mean_y = np.mean(y, axis=0) + std_y = np.sqrt(1./(n_samples-1.)*np.sum((y - np.mean(y,0))**2.,0)) #np.std(y, axis=0) + std_X[std_X == 0.] = 1. + std_y[std_y == 0.] = 1. 
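+        # With normalization disabled, identity scaling factors are used so
+        # that the same centering/reduction code path applies below.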
+ else: + mean_X = np.array([0.]) + std_X = np.array([1.]) + mean_y = np.array([0.]) + std_y = np.array([1.]) + + X_ = (X - mean_X) / std_X + y_ = (y - mean_y) / std_y + + # Calculate matrix of distances D between samples + mzmax = n_samples*(n_samples-1)/2 + ij = np.zeros([mzmax, 2]) + D = np.zeros([mzmax, n_features]) + ll = np.array([-1]) + for k in range(n_samples-1): + ll = ll[-1] + 1 + range(n_samples-k-1) + ij[ll,:] = np.concatenate([[np.repeat(k,n_samples-k-1,0)], [np.arange(k+1,n_samples).T]]).T + D[ll,:] = X_[k,:] - X_[(k+1):n_samples,:] + if (np.min(np.sum(np.abs(D),1)) == 0.) and (self.corr != corriid): + raise Exception, "Multiple X are not allowed" + + # Regression matrix and parameters + F = self.regr(X) + n_samples_F = F.shape[0] + if F.ndim > 1: + p = F.shape[1] + else: + p = 1 + if n_samples_F != n_samples: + raise Exception, "Number of rows in F and X do not match. Most likely something is wrong in the regression model." + if p > n_samples_F: + raise Exception, "Ordinary least squares problem is undetermined. n_samples=%d must be greater than the regression model size p=%d!" % (n_samples, p) + if self.beta0 is not None and (self.beta0.shape[0] != p or self.beta0.ndim > 1): + raise Exception, "Shapes of beta0 and F do not match." + + # Set attributes + self.X = X + self.y = y + self.X_ = X_ + self.y_ = y_ + self.D = D + self.ij = ij + self.F = F + self.X_sc = np.concatenate([[mean_X],[std_X]]) + self.y_sc = np.concatenate([[mean_y],[std_y]]) + + # Determine Gaussian Process Model parameters + if self.thetaL is not None and self.thetaU is not None: + # Maximum Likelihood Estimation of the parameters + if self.verbose: + print "Performing Maximum Likelihood Estimation of the autocorrelation parameters..." + self.theta, self.score_value, self.par = self.__arg_max_score__() + if np.isinf(self.score_value) : + raise Exception, "Bad parameter region. Try increasing upper bound" + else: + # Given parameters + if self.verbose: + print "Given autocorrelation parameters. Computing Gaussian Process Model parameters..." + self.theta = self.theta0 + self.score_value, self.par = self.score() + if np.isinf(self.score_value): + raise Exception, "Bad point. Try increasing theta0" + + if self.storage_mode == 'light': + # Delete heavy data (it will be computed again if required) + # (it is required only when MSE is wanted in self.predict) + if self.verbose: + print "Light storage mode specified. Flushing autocorrelation matrix..." + self.D = None + self.ij = None + self.F = None + self.par['C'] = None + self.par['Ft'] = None + self.par['G'] = None + + return self + + def predict(self, X, eval_MSE=False, batch_size=None): + """ + This function evaluates the Gaussian Process Model at x. + + Parameters + ---------- + X : array_like + An array with shape (n_eval, n_features) giving the point(s) at which the prediction(s) should be made. + eval_MSE : boolean, optional + A boolean specifying whether the Mean Squared Error should be evaluated or not. + Default assumes evalMSE = False and evaluates only the BLUP (mean prediction). + verbose : boolean, optional + A boolean specifying the verbose level. + Default is verbose = True. + batch_size : integer, optional + An integer giving the maximum number of points that can be evaluated simulatneously (depending on the available memory). + Default is None so that all given points are evaluated at the same time. + + Returns + ------- + y : array_like + An array with shape (n_eval, ) with the Best Linear Unbiased Prediction at x. 
+ MSE : array_like, optional (if eval_MSE == True) + An array with shape (n_eval, ) with the Mean Squared Error at x. + """ + + # Check + if np.any(np.isnan(self.par['beta'])): + raise Exception, "Not a valid GaussianProcessModel. Try fitting it again with different parameters theta" + + # Check design & trial sites + X = np.asanyarray(X, dtype=np.float) + n_samples = self.X_.shape[0] + if self.X_.ndim > 1: + n_features = self.X_.shape[1] + else: + n = 1 + n_eval = X.shape[0] + if X.ndim > 1: + n_features_X = X.shape[1] + else: + n_features_X = 1 + X = X.reshape(n_eval, n_features_X) + + if n_features_X != n_features: + raise ValueError, "The number of features in X (X.shape[1] = %d) should match the sample size used for fit() which is %d." % (n_features_X, n_features) + + if batch_size is None: # No memory management (evaluates all given points in a single batch run) + + # Normalize trial sites + X_ = (X - self.X_sc[0,:]) / self.X_sc[1,:] + + # Initialize output + y = np.zeros(n_eval) + if eval_MSE: + MSE = np.zeros(n_eval) + + # Get distances to design sites + dx = np.zeros([n_eval*n_samples,n_features]) + kk = np.arange(n_samples).astype(int) + for k in range(n_eval): + dx[kk,:] = X_[k,:] - self.X_ + kk = kk + n_samples + + # Get regression function and correlation + f = self.regr(X_) + r = self.corr(self.theta, dx).reshape(n_eval, n_samples).T + + # Scaled predictor + y_ = np.matrix(f) * np.matrix(self.par['beta']) + (np.matrix(self.par['gamma']) * np.matrix(r)).T + + # Predictor + y = (self.y_sc[0,:] + self.y_sc[1,:] * np.array(y_)).reshape(n_eval) + + # Mean Squared Error + if eval_MSE: + par = self.par + if par['C'] is None: + # Light storage mode (need to recompute C, F, Ft and G) + if self.verbose: + print "This GaussianProcessModel used light storage mode at instanciation. Need to recompute autocorrelation matrix..." + score_value, par = self.score() + + rt = linalg.solve(np.matrix(par['C']), np.matrix(r)) + if self.beta0 is None: + # Universal Kriging + u = linalg.solve(np.matrix(-self.par['G'].T), np.matrix(self.par['Ft']).T * np.matrix(rt) - np.matrix(f).T) + else: + # Ordinary Kriging + u = 0. * y + + MSE = self.par['sigma2'] * (1. - col_sum(np.array(rt)**2.) + col_sum(np.array(u)**2.)).T + + # Mean Squared Error might be slightly negative depending on machine precision + # Force to zero! + MSE[MSE < 0.] = 0. + + return y, MSE + + else: + + return y + + else: # Memory management + + if type(batch_size) is not int or batch_size <= 0: + raise Exception, "batch_size must be a positive integer" + + if eval_MSE: + + y, MSE = np.zeros(n_eval), np.zeros(n_eval) + for k in range(n_eval/batch_size): + y[k*batch_size:min([(k+1)*batch_size+1, n_eval+1])], MSE[k*batch_size:min([(k+1)*batch_size+1, n_eval+1])] = \ + self.predict(X[k*batch_size:min([(k+1)*batch_size+1, n_eval+1]),:], eval_MSE = eval_MSE, batch_size = None) + + return y, MSE + + else: + + y = np.zeros(n_eval) + for k in range(n_eval/batch_size): + y[k*batch_size:min([(k+1)*batch_size+1, n_eval+1])] = \ + self.__call__(x[k*batch_size:min([(k+1)*batch_size+1, n_eval+1]),:], eval_MSE = eval_MSE, batch_size = None) + + return y + + + + def score(self, theta = None): + """ + This function determines the BLUP parameters and evaluates the score function for the given autocorrelation parameters theta. + Maximizing this score function wrt the autocorrelation parameters theta is equivalent to maximizing the likelihood of the + assumed joint Gaussian distribution of the observations y evaluated onto the design of experiments X. 
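+
+        Concretely, the returned score is - sigma2 * det(R)^(1/n_samples)
+        (with sigma2 the normalized process variance), so that maximizing it
+        amounts to minimizing the determinant-penalized process variance.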
+ + Parameters + ---------- + theta : array_like, optional + An array containing the autocorrelation parameters at which the Gaussian Process Model parameters should be determined. + Default uses the built-in autocorrelation parameters (ie theta = self.theta). + + Returns + ------- + par : dict + A dictionary containing the requested Gaussian Process Model parameters: + par['sigma2'] : Gaussian Process variance. + par['beta'] : Generalized least-squares regression weights for Universal Kriging or given beta0 for Ordinary Kriging. + par['gamma'] : Gaussian Process weights. + par['C'] : Cholesky decomposition of the correlation matrix [R]. + par['Ft'] : Solution of the linear equation system : [R] x Ft = F + par['G'] : QR decomposition of the matrix Ft. + par['detR'] : Determinant of the correlation matrix raised at power 1/n_samples. + score_value : double + The value of the score function associated to the given autocorrelation parameters theta. + """ + + if theta is None: + # Use built-in autocorrelation parameters + theta = self.theta + + score_value = -np.inf + par = { } + + # Retrieve data + n_samples = self.X_.shape[0] + D = self.D + ij = self.ij + F = self.F + + if D is None: + # Light storage mode (need to recompute D, ij and F) + if self.X_.ndim > 1: + n_features = self.X_.shape[1] + else: + n_features = 1 + mzmax = n_samples*(n_samples-1)/2 + ij = zeros([mzmax, n_features]) + D = zeros([mzmax, n_features]) + ll = array([-1]) + for k in range(n_samples-1): + ll = ll[-1]+1 + range(n_samples-k-1) + ij[ll,:] = np.concatenate([[repeat(k,n_samples-k-1,0)], [np.arange(k+1,n_samples).T]]).T + D[ll,:] = self.X_[k,:] - self.X_[(k+1):n_samples,:] + if (min(sum(abs(D),1)) == 0.) and (self.corr != corriid): + raise Exception, "Multiple X are not allowed" + F = self.regr(self.X_) + + # Set up R + r = self.corr(theta, D) + R = np.eye(n_samples) * (1. + self.nugget) + R[ij.astype(int)[:,0], ij.astype(int)[:,1]] = r + R[ij.astype(int)[:,1], ij.astype(int)[:,0]] = r + + # Cholesky decomposition of R + try: + C = linalg.cholesky(R, lower=True) + except linalg.LinAlgError: + return score_value, par + + # Get generalized least squares solution + Ft = linalg.solve(C, F) + try: + Q, G = linalg.qr(Ft, econ=True) + except: + #/usr/lib/python2.6/dist-packages/scipy/linalg/decomp.py:1177: DeprecationWarning: qr econ argument will be removed after scipy 0.7. The economy transform will then be available through the mode='economic' argument. + Q, G = linalg.qr(Ft, mode='economic') + pass + + rcondG = 1./(linalg.norm(G)*linalg.norm(linalg.inv(G))) + if rcondG < 1e-10: + # Check F + condF = linalg.norm(F)*linalg.norm(linalg.inv(F)) + if condF > 1e15: + raise Exception, "F is too ill conditioned. Poor combination of regression model and observations." 
+ else: + # Ft is too ill conditioned + return score_value, par + + Yt = linalg.solve(C,self.y_) + if self.beta0 is None: + # Universal Kriging + beta = linalg.solve(G, np.matrix(Q).T*np.matrix(Yt)) + else: + # Ordinary Kriging + beta = np.array(self.beta0) + + rho = np.matrix(Yt) - np.matrix(Ft) * np.matrix(beta) + normalized_sigma2 = (np.array(rho)**2.).sum(axis=0)/n_samples + # The determinant of R is equal to the squared product of the diagonal elements of its Cholesky decomposition C + detR = (np.array(np.diag(C))**(2./n_samples)).prod() + score_value = - normalized_sigma2.sum() * detR + par = { 'sigma2':normalized_sigma2 * self.y_sc[1]**2, \ + 'beta':beta, \ + 'gamma':linalg.solve(C.T,rho).T, \ + 'C':C, \ + 'Ft':Ft, \ + 'G':G } + + return score_value, par + + def __arg_max_score__(self): + """ + This function estimates the autocorrelation parameters theta as the maximizer of the score function. + (Minimization of the opposite score function is used for convenience) + + Parameters + ---------- + self: All parameters are stored in the Gaussian Process Model object. + + Returns + ------- + optimal_theta : array_like + The best set of autocorrelation parameters (the maximizer of the score function). + optimal_score_value : double + The optimal score function value. + optimal_par : dict + The BLUP parameters associated to thetaOpt. + """ + + if self.verbose: + print "The chosen optimizer is: "+str(self.optimizer) + if self.random_start > 1: + print str(self.random_start)+" random starts are required." + + percent_completed = 0. + + if self.optimizer == 'fmin_cobyla': + + minus_score = lambda log10t: - self.score(theta = 10.**log10t)[0] + + constraints = [] + for i in range(self.theta0.size): + constraints.append(lambda log10t: log10t[i] - np.log10(self.thetaL[0,i])) + constraints.append(lambda log10t: np.log10(self.thetaU[0,i]) - log10t[i]) + + for k in range(self.random_start): + + if k == 0: + # Use specified starting point as first guess + theta0 = self.theta0 + else: + # Generate a random starting point log10-uniformly distributed between bounds + log10theta0 = np.log10(self.thetaL) + random.rand(self.theta0.size).reshape(self.theta0.shape) * np.log10(self.thetaU/self.thetaL) + theta0 = 10.**log10theta0 + + log10_optimal_theta = optimize.fmin_cobyla(minus_score, np.log10(theta0), constraints, iprint=0) + optimal_theta = 10.**log10_optimal_theta + optimal_minus_score_value, optimal_par = self.score(theta = optimal_theta) + optimal_score_value = - optimal_minus_score_value + + if k > 0: + if optimal_score_value > best_optimal_score_value: + best_optimal_score_value = optimal_score_value + best_optimal_par = optimal_par + best_optimal_theta = optimal_theta + else: + best_optimal_score_value = optimal_score_value + best_optimal_par = optimal_par + best_optimal_theta = optimal_theta + if self.verbose and self.random_start > 1: + if (20*k)/self.random_start > percent_completed: + percent_completed = (20*k)/self.random_start + print str(5*percent_completed)+"% completed" + + else: + + raise NotImplementedError, "This optimizer ('%s') is not implemented yet. Please contribute!" 
% self.optimizer + + return best_optimal_theta, best_optimal_score_value, best_optimal_par \ No newline at end of file diff --git a/scikits/learn/gpml/regression_models.py b/scikits/learn/gpml/regression_models.py new file mode 100644 index 0000000000000000000000000000000000000000..5bf8df7d3255b80f998560d6d27608b73ed99392 --- /dev/null +++ b/scikits/learn/gpml/regression_models.py @@ -0,0 +1,84 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- + +################ +# Dependencies # +################ + +import numpy as np + +############################ +# Defaut regression models # +############################ + +def regpoly0(x): + """ + Zero order polynomial (constant, p = 1) regression model. + + regpoly0 : x --> f(x) = 1 + + Parameters + ---------- + x : array_like + An array with shape (n_eval, n_features) giving the locations x at which the regression model should be evaluated. + + Returns + ------- + f : array_like + An array with shape (n_eval, p) with the values of the regression model. + """ + + x = np.asanyarray(x, dtype=np.float) + n_eval = x.shape[0] + f = np.ones([n_eval,1]) + + return f + +def regpoly1(x): + """ + First order polynomial (hyperplane, p = n) regression model. + + regpoly1 : x --> f(x) = [ x_1, ..., x_n ].T + + Parameters + ---------- + x : array_like + An array with shape (n_eval, n_features) giving the locations x at which the regression model should be evaluated. + + Returns + ------- + f : array_like + An array with shape (n_eval, p) with the values of the regression model. + """ + + x = np.asanyarray(x, dtype=np.float) + n_eval = x.shape[0] + f = np.hstack([np.ones([n_eval,1]), x]) + + return f + +def regpoly2(x): + """ + Second order polynomial (hyperparaboloid, p = n*(n-1)/2) regression model. + + regpoly2 : x --> f(x) = [ x_i*x_j, (i,j) = 1,...,n ].T + i > j + + Parameters + ---------- + x : array_like + An array with shape (n_eval, n_features) giving the locations x at which the regression model should be evaluated. + + Returns + ------- + f : array_like + An array with shape (n_eval, p) with the values of the regression model. + """ + + x = np.asanyarray(x, dtype=np.float) + n_eval, n_features = x.shape + f = np.hstack([np.ones([n_eval,1]), x]) + for k in range(n_features): + f = np.hstack([f, x[:,k,np.newaxis] * x[:,k:]]) + + return f \ No newline at end of file diff --git a/scikits/learn/tests/test_gpml.py b/scikits/learn/tests/test_gpml.py new file mode 100644 index 0000000000000000000000000000000000000000..15453dd8eee7a641eca47a70e1b0b650911959b2 --- /dev/null +++ b/scikits/learn/tests/test_gpml.py @@ -0,0 +1,49 @@ +""" +Testing for Gaussian Process for Machine Learning module (scikits.learn.gpml) +""" + +import numpy as np +from numpy.testing import assert_array_equal, assert_array_almost_equal, \ + assert_almost_equal, assert_raises, assert_ + +from .. import gpml, datasets, cross_val, metrics + +diabetes = datasets.load_diabetes() + +def test_regression_1d_x_sinx(): + """ + MLE estimation of a Gaussian Process model with a squared exponential + correlation model (correxp2). Check random start optimization. + + Test the interpolating property. 
+    """
+
+    f = lambda x: x * np.sin(x)
+    X = np.array([1., 3., 5., 6., 7., 8.])
+    y = f(X)
+    gpm = gpml.GaussianProcessModel(theta0=1e-2, thetaL=1e-4, thetaU=1e-1,
+                                    random_start=10, verbose=False).fit(X, y)
+    y_pred, MSE = gpm.predict(X, eval_MSE=True)
+
+    assert np.all(np.abs((y_pred - y) / y) < 1e-6) and np.all(MSE < 1e-6)
+
+
+def test_regression_diabetes():
+    """
+    MLE estimation of a Gaussian Process model with an anisotropic squared
+    exponential correlation model.
+
+    Test the model using the cross-validation module (quite time-consuming).
+
+    Poor performance so far: the leave-one-out estimate of explained variance
+    is only about 0.5 at best... To be investigated!
+    TODO: find a dataset that would prove GPML performance!
+    """
+
+    X, y = diabetes['data'], diabetes['target']
+
+    gpm = gpml.GaussianProcessModel(corr=gpml.correxp2, theta0=1e-4,
+                                    nugget=1e-2, verbose=False).fit(X, y)
+
+    # Setting the bounds to None makes the cross-validation fits below skip
+    # MLE and reuse theta0
+    gpm.thetaL, gpm.thetaU = None, None
+    score_func = lambda self, X_test, y_test: self.predict(X_test)[0]
+    y_pred = cross_val.cross_val_score(gpm, X, y=y, score_func=score_func,
+                                       cv=cross_val.LeaveOneOut(y.size),
+                                       n_jobs=1, verbose=0)
+
+    # Leave-one-out estimate of the explained variance (Q2)
+    Q2 = 1. - np.sum((y_pred - y) ** 2.) / np.sum((y - np.mean(y)) ** 2.)
+
+    assert Q2 > 0.
\ No newline at end of file