diff --git a/examples/gaussian_process/plot_gp_diabetes_dataset.py b/examples/gaussian_process/plot_gp_diabetes_dataset.py index 21b85a0eb18b5e5cb883856ab1bfc839ab4a4d86..8416c954469c91828c7d39ca2b1b02a070bcd764 100644 --- a/examples/gaussian_process/plot_gp_diabetes_dataset.py +++ b/examples/gaussian_process/plot_gp_diabetes_dataset.py @@ -2,9 +2,9 @@ # -*- coding: utf-8 -*- """ -=============================================== -Gaussian Processes for Machine Learning example -=============================================== +============================================================= +Gaussian Processes regression example: the 'diabetes' dataset +============================================================= This example consists in fitting a Gaussian Process model onto the diabetes dataset. @@ -39,7 +39,6 @@ gp = GaussianProcess(theta0=1e-4, nugget=1e-2, verbose=False) # Fit the GP model to the data gp.fit(X, y) -gp.par['beta'] # Estimate the leave-one-out predictions using the cross_val module n_jobs = 2 # the distributing capacity available on the machine diff --git a/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py b/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py index 17e2e63919dee277e682b9e5caa7d415f5323708..b2ed888e16044f9fd3c84fcd8625c28942fb4ed2 100644 --- a/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py +++ b/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py @@ -2,9 +2,9 @@ # -*- coding: utf-8 -*- """ -=============================================== -Gaussian Processes for Machine Learning example -=============================================== +============================================================================== +Gaussian Processes classification example: exploiting the probabilistic output +============================================================================== A two-dimensional regression exercise with a post-processing allowing for probabilistic classification thanks to the Gaussian property of the prediction. 
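For readers skimming the diff, here is a minimal sketch of what the reworked diabetes example boils down to: fit a GaussianProcess with the hyperparameters used in plot_gp_diabetes_dataset.py and estimate leave-one-out predictions. The dataset loader usage and the explicit leave-one-out loop are assumptions for illustration only; the shipped example delegates that step to the cross_val module with n_jobs = 2.

    # Illustrative sketch; loader access and the hand-rolled loop are assumed.
    import numpy as np
    from scikits.learn import datasets
    from scikits.learn.gaussian_process import GaussianProcess

    diabetes = datasets.load_diabetes()
    X, y = diabetes.data, diabetes.target

    # Same settings as the example: fixed theta0, explicit nugget, quiet mode.
    gp = GaussianProcess(theta0=1e-4, nugget=1e-2, verbose=False).fit(X, y)

    # Leave-one-out predictions written out explicitly
    # (the example computes these with the cross_val module, n_jobs=2).
    y_loo = np.empty(y.shape)
    for i in range(X.shape[0]):
        mask = np.arange(X.shape[0]) != i
        gp_i = GaussianProcess(theta0=1e-4, nugget=1e-2, verbose=False)
        y_loo[i] = gp_i.fit(X[mask], y[mask]).predict(X[i:i + 1])[0]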
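The probabilistic post-processing this classification example relies on is nothing more than the Gaussian predictive distribution: given the BLUP y(x) and its MSE, the probability that the underlying function lies below zero is Phi(-y / sigma), which is exactly what the example draws over the (x1, x2) grid. A minimal sketch of that step, assuming a GaussianProcess gp already fitted to observations of the function g and an evaluation grid x_grid:

    import numpy as np
    from scipy import stats

    # BLUP and its Mean Squared Error on the evaluation grid.
    y_pred, MSE = gp.predict(x_grid, eval_MSE=True)
    sigma = np.sqrt(MSE)  # ~0 at the design sites, where the GP interpolates

    # Gaussian property of the prediction: P[G(x) <= 0] = Phi(-y_pred / sigma)
    proba_negative = stats.norm.cdf(- y_pred / sigma)
    labels = proba_negative > 0.5   # classify as "G(x) <= 0" when more likely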
@@ -34,7 +34,7 @@ PHI = lambda x: Grv.cdf(x) PHIinv = lambda x: Grv.ppf(x) # A few constants -beta0 = 8 +lim = 8 b, kappa, e = 5, .5, .1 # The function to predict (classification will then consist in predicting @@ -62,8 +62,8 @@ gp.fit(X, Y) # Evaluate real function, the prediction and its MSE on a grid res = 50 -x1, x2 = np.meshgrid(np.linspace(- beta0, beta0, res), \ - np.linspace(- beta0, beta0, res)) +x1, x2 = np.meshgrid(np.linspace(- lim, lim, res), \ + np.linspace(- lim, lim, res)) xx = np.vstack([x1.reshape(x1.size), x2.reshape(x2.size)]).T YY = g(xx) @@ -87,7 +87,7 @@ pl.xlabel('$x_1$') pl.ylabel('$x_2$') cax = pl.imshow(np.flipud(PHI(- yy / sigma)), cmap=cm.gray_r, alpha=0.8, \ - extent=(- beta0, beta0, - beta0, beta0)) + extent=(- lim, lim, - lim, lim)) norm = pl.matplotlib.colors.Normalize(vmin=0., vmax=0.9) cb = pl.colorbar(cax, ticks=[0., 0.2, 0.4, 0.6, 0.8, 1.], norm=norm) cb.set_label('${\\rm \mathbb{P}}\left[\widehat{G}(\mathbf{x}) \leq 0\\right]$') diff --git a/examples/gaussian_process/plot_gp_regression.py b/examples/gaussian_process/plot_gp_regression.py index 747e62691cc8eb2585652ee5bf96480f0aa85a4d..ededab866140102d788334b22dc61db9e41886fe 100644 --- a/examples/gaussian_process/plot_gp_regression.py +++ b/examples/gaussian_process/plot_gp_regression.py @@ -2,9 +2,9 @@ # -*- coding: utf-8 -*- """ -=============================================== -Gaussian Processes for Machine Learning example -=============================================== +================================================================= +Gaussian Processes regression example: basic introductory example +================================================================= A simple one-dimensional regression exercise with a cubic correlation model whose parameters are estimated using the maximum likelihood principle. @@ -18,7 +18,7 @@ confidence interval. # License: BSD style import numpy as np -from scikits.learn.gaussian_process import GaussianProcess, corrcubic +from scikits.learn.gaussian_process import GaussianProcess from matplotlib import pyplot as pl # Print the docstring @@ -38,7 +38,7 @@ Y = f(X).ravel() x = np.atleast_2d(np.linspace(0, 10, 1000)).T # Instanciate a Gaussian Process model -gp = GaussianProcess(corr=corrcubic, theta0=1e-2, thetaL=1e-4, thetaU=1e-1, \ +gp = GaussianProcess(corr='cubic', theta0=1e-2, thetaL=1e-4, thetaU=1e-1, \ random_start=100) # Fit to data using Maximum Likelihood Estimation of the parameters diff --git a/scikits/learn/gaussian_process/__init__.py b/scikits/learn/gaussian_process/__init__.py index f2bb68562f77ed4ec5b2bf93f6dec1b357b24703..e6ca5e2bfc809bc111b805cdd6d792cad520224f 100644 --- a/scikits/learn/gaussian_process/__init__.py +++ b/scikits/learn/gaussian_process/__init__.py @@ -1,15 +1,32 @@ #!/usr/bin/python # -*- coding: utf-8 -*- +# Author: Vincent Dubourg <vincent.dubourg@gmail.com> +# (mostly translation, see implementation details) +# License: BSD style + """ - This module contains a contribution to the scikit-learn project that - implements Gaussian Process based prediction (also known as Kriging). +A module that implements scalar Gaussian Process based prediction (also +known as Kriging). + +Contains +-------- +GaussianProcess: The main class of the module that implements the Gaussian + Process prediction theory. +regression_models: A submodule that contains the built-in regression models. +correlation_models: A submodule that contains the built-in correlation models. 
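Since plot_gp_regression.py now selects the cubic correlation model by name, the probabilistic part of that example reduces to the sketch below; the 1.96 factor for a roughly 95% pointwise band is an illustration, not something fixed by the API.

    import numpy as np
    from scikits.learn.gaussian_process import GaussianProcess

    f = lambda x: x * np.sin(x)
    X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T
    y = f(X).ravel()
    x = np.atleast_2d(np.linspace(0, 10, 1000)).T

    # Same settings as the example, with the correlation model given by name.
    gp = GaussianProcess(corr='cubic', theta0=1e-2, thetaL=1e-4, thetaU=1e-1,
                         random_start=100).fit(X, y)

    # The MSE of the BLUP yields a pointwise confidence band around it.
    y_pred, MSE = gp.predict(x, eval_MSE=True)
    sigma = np.sqrt(MSE)
    lower, upper = y_pred - 1.96 * sigma, y_pred + 1.96 * sigma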
+ +Implementation details +---------------------- +The presentation implementation is based on a translation of the DACE +Matlab toolbox. - The present implementation is based on a transliteration of the DACE - Matlab toolbox <http://www2.imm.dtu.dk/~hbn/dace/>. +See references: +[1] H.B. Nielsen, S.N. Lophaven, H. B. Nielsen and J. Sondergaard (2002). + DACE - A MATLAB Kriging Toolbox. + http://www2.imm.dtu.dk/~hbn/dace/dace.pdf """ from .gaussian_process import GaussianProcess -from .correlation import corrlin, corrcubic, correxp1, \ - correxp2, correxpg, corriid -from .regression import regpoly0, regpoly1, regpoly2 +from . import correlation_models +from . import regression_models diff --git a/scikits/learn/gaussian_process/correlation.py b/scikits/learn/gaussian_process/correlation_models.py similarity index 85% rename from scikits/learn/gaussian_process/correlation.py rename to scikits/learn/gaussian_process/correlation_models.py index 65dcc5b825b2ed21c144b0a407253c02afdfaa18..97b436ed9de6f07a2b8afdbe1b5d324e796e2866 100644 --- a/scikits/learn/gaussian_process/correlation.py +++ b/scikits/learn/gaussian_process/correlation_models.py @@ -1,6 +1,14 @@ #!/usr/bin/python # -*- coding: utf-8 -*- +# Author: Vincent Dubourg <vincent.dubourg@gmail.com> +# (mostly translation, see implementation details) +# License: BSD style + +""" +The built-in regression models submodule for the gaussian_process module. +""" + ################ # Dependencies # ################ @@ -13,14 +21,14 @@ import numpy as np ############################# -def correxp1(theta, d): +def absolute_exponential(theta, d): """ - Exponential autocorrelation model. + Absolute exponential autocorrelation model. (Ornstein-Uhlenbeck stochastic process) - n - correxp1 : theta, dx --> r(theta, dx) = exp( sum - theta_i * |dx_i| ) - i = 1 + n + theta, dx --> r(theta, dx) = exp( sum - theta_i * |dx_i| ) + i = 1 Parameters ---------- @@ -58,14 +66,14 @@ def correxp1(theta, d): return r -def correxp2(theta, d): +def squared_exponential(theta, d): """ Squared exponential correlation model (Radial Basis Function). (Infinitely differentiable stochastic process, very smooth) - n - correxp2 : theta, dx --> r(theta, dx) = exp( sum - theta_i * (dx_i)^2 ) - i = 1 + n + theta, dx --> r(theta, dx) = exp( sum - theta_i * (dx_i)^2 ) + i = 1 Parameters ---------- @@ -103,15 +111,15 @@ def correxp2(theta, d): return r -def correxpg(theta, d): +def generalized_exponential(theta, d): """ Generalized exponential correlation model. (Useful when one does not know the smoothness of the function to be predicted.) - n - correxpg : theta, dx --> r(theta, dx) = exp( sum - theta_i * |dx_i|^p ) - i = 1 + n + theta, dx --> r(theta, dx) = exp( sum - theta_i * |dx_i|^p ) + i = 1 Parameters ---------- @@ -152,15 +160,15 @@ def correxpg(theta, d): return r -def corriid(theta, d): +def pure_nugget(theta, d): """ Spatial independence correlation model (pure nugget). (Useful when one wants to solve an ordinary least squares problem!) - n - corriid : theta, dx --> r(theta, dx) = 1 if sum |dx_i| == 0 - i = 1 - 0 otherwise + n + theta, dx --> r(theta, dx) = 1 if sum |dx_i| == 0 + i = 1 + 0 otherwise Parameters ---------- @@ -191,11 +199,11 @@ def corriid(theta, d): return r -def corrcubic(theta, d): +def cubic(theta, d): """ Cubic correlation model. 
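As a quick sanity check on the renamed correlation models, note that they are plain functions of (theta, d), where d holds componentwise differences x - x'; they can therefore be evaluated directly against the documented formulas. A small sketch, assuming the calling convention documented above:

    import numpy as np
    from scikits.learn.gaussian_process import correlation_models

    theta = np.array([0.5, 2.0])
    # Componentwise differences between pairs of points, shape (n_eval, n_features).
    d = np.array([[0.0, 0.0],
                  [1.0, -1.0]])

    r = correlation_models.squared_exponential(theta, d)
    # Should match the documented formula exp( sum_i - theta_i * (dx_i)^2 ).
    expected = np.exp(-np.sum(theta * d ** 2, axis=1))
    print(np.allclose(np.ravel(r), expected))   # True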
- corrcubic : theta, dx --> r(theta, dx) = + theta, dx --> r(theta, dx) = n prod max(0, 1 - 3(theta_j*d_ij)^2 + 2(theta_j*d_ij)^3) , i = 1,...,m j = 1 @@ -241,11 +249,11 @@ def corrcubic(theta, d): return r -def corrlin(theta, d): +def linear(theta, d): """ Linear correlation model. - corrlin : theta, dx --> r(theta, dx) = + theta, dx --> r(theta, dx) = n prod max(0, 1 - theta_j*d_ij) , i = 1,...,m j = 1 diff --git a/scikits/learn/gaussian_process/gaussian_process.py b/scikits/learn/gaussian_process/gaussian_process.py index f5ba3e62fb829d0b573c956625ddc4f1a552a8bf..2622406ac3d525049fbe897ff57511ea320ff43e 100644 --- a/scikits/learn/gaussian_process/gaussian_process.py +++ b/scikits/learn/gaussian_process/gaussian_process.py @@ -1,6 +1,10 @@ #!/usr/bin/python # -*- coding: utf-8 -*- +# Author: Vincent Dubourg <vincent.dubourg@gmail.com> +# (mostly translation, see implementation details) +# License: BSD style + ################ # Dependencies # ################ @@ -8,9 +12,9 @@ import numpy as np from scipy import linalg, optimize, rand from ..base import BaseEstimator -from .regression import regpoly0 -from .correlation import correxp2, corriid -machine_epsilon = np.finfo(np.double).eps +from . import regression_models as regression +from . import correlation_models as correlation +MACHINE_EPSILON = np.finfo(np.double).eps if hasattr(linalg, 'solve_triangular'): # only in scipy since 0.9 solve_triangular = linalg.solve_triangular @@ -26,33 +30,104 @@ else: class GaussianProcess(BaseEstimator): """ - A class that implements scalar Gaussian Process based prediction (also - known as Kriging). + The Gaussian Process model class. + + Parameters + ---------- + regr : string or callable, optional + A regression function returning an array of outputs of the linear + regression functional basis. The number of observations n_samples + should be greater than the size p of this basis. + Default assumes a simple constant regression trend. + Here is the list of built-in regression models: + 'constant', 'linear', 'quadratic' + + corr : string or callable, optional + A stationary autocorrelation function returning the autocorrelation + between two points x and x'. + Default assumes a squared-exponential autocorrelation model. + Here is the list of built-in correlation models: + 'absolute_exponential', 'squared_exponential', + 'generalized_exponential', 'cubic', 'linear' + + beta0 : double array_like, optional + The regression weight vector to perform Ordinary Kriging (OK). + Default assumes Universal Kriging (UK) so that the vector beta of + regression weights is estimated using the maximum likelihood + principle. + + storage_mode : string, optional + A string specifying whether the Cholesky decomposition of the + correlation matrix should be stored in the class (storage_mode = + 'full') or not (storage_mode = 'light'). + Default assumes storage_mode = 'full', so that the + Cholesky decomposition of the correlation matrix is stored. + This might be a useful parameter when one is not interested in the + MSE and only plan to estimate the BLUP, for which the correlation + matrix is not required. + + verbose : boolean, optional + A boolean specifying the verbose level. + Default is verbose = False. + + theta0 : double array_like, optional + An array with shape (n_features, ) or (1, ). + The parameters in the autocorrelation model. + If thetaL and thetaU are also specified, theta0 is considered as + the starting point for the maximum likelihood rstimation of the + best set of parameters. 
+ Default assumes isotropic autocorrelation model with theta0 = 1e-1. + + thetaL : double array_like, optional + An array with shape matching theta0's. + Lower bound on the autocorrelation parameters for maximum + likelihood estimation. + Default is None, so that it skips maximum likelihood estimation and + it uses theta0. + + thetaU : double array_like, optional + An array with shape matching theta0's. + Upper bound on the autocorrelation parameters for maximum + likelihood estimation. + Default is None, so that it skips maximum likelihood estimation and + it uses theta0. + + normalize : boolean, optional + Design sites X and observations y are centered and reduced wrt + means and standard deviations estimated from the n_samples + observations provided. + Default is normalize = True so that data is normalized to ease + maximum likelihood estimation. + + nugget : double, optional + Introduce a nugget effect to allow smooth predictions from noisy + data. + Default assumes a nugget close to machine precision for the sake of + robustness (nugget = 10. * MACHINE_EPSILON). + + optimizer : string, optional + A string specifying the optimization algorithm to be used + ('fmin_cobyla' is the sole algorithm implemented yet). + Default uses 'fmin_cobyla' algorithm from scipy.optimize. + + random_start : int, optional + The number of times the Maximum Likelihood Estimation should be + performed from a random starting point. + The first MLE always uses the specified starting point (theta0), + the next starting points are picked at random according to an + exponential distribution (log-uniform on [thetaL, thetaU]). + Default does not use random starting point (random_start = 1). Example ------- - import numpy as np - from scikits.learn.gaussian_process import GaussianProcess - import pylab as pl - - f = lambda x: x * np.sin(x) - X = np.array([1., 3., 5., 6., 7., 8.]) - Y = f(X) - gp = GaussianProcess(theta0=1e-1, thetaL=1e-3, thetaU=1e0, \ - random_start=100) - gp.fit(X, Y) - x = np.linspace(0, 10, 1000) - y_pred, MSE = gp.predict(x, eval_MSE=True) - - pl.show() - - Methods - ------- - fit(X, y) : self - Fit the model. - - predict(X) : array - Predict using the model. + >>> import numpy as np + >>> from scikits.learn.gaussian_process import GaussianProcess + >>> f = lambda x: x * np.sin(x) + >>> X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T + >>> Y = f(X).ravel() + >>> gp = GaussianProcess(theta0=1e-1, thetaL=1e-3, thetaU=1e0).fit(X, Y) + >>> x = np.atleast_2d(np.linspace(0, 10, 1000)).T + >>> y_pred, MSE = gp.predict(x, eval_MSE=True) Implementation details ---------------------- @@ -65,151 +140,42 @@ class GaussianProcess(BaseEstimator): http://www2.imm.dtu.dk/~hbn/dace/dace.pdf """ - def __init__(self, regr=regpoly0, corr=correxp2, beta0=None, \ - storage_mode='full', verbose=True, theta0=1e-1, \ - thetaL=None, thetaU=None, optimizer='fmin_cobyla', \ - random_start=1, normalize=True, \ - nugget=10. * machine_epsilon): - """ - The Gaussian Process model constructor. + _regression_types = { + 'constant': regression.constant, + 'linear': regression.linear, + 'quadratic': regression.quadratic} - Parameters - ---------- - regr : function, optional - A regression function returning an array of outputs of the linear - regression functional basis. The number of observations n_samples - should be greater than the size p of this basis. - Default assumes a simple constant regression trend (see regpoly0). 
- - corr : function, optional - A stationary autocorrelation function returning the autocorrelation - between two points x and x'. - Default assumes a squared-exponential autocorrelation model (see - correxp2). - - beta0 : double array_like, optional - The regression weight vector to perform Ordinary Kriging (OK). - Default assumes Universal Kriging (UK) so that the vector beta of - regression weights is estimated using the maximum likelihood - principle. - - storage_mode : string, optional - A string specifying whether the Cholesky decomposition of the - correlation matrix should be stored in the class (storage_mode = - 'full') or not (storage_mode = 'light'). - Default assumes storage_mode = 'full', so that the - Cholesky decomposition of the correlation matrix is stored. - This might be a useful parameter when one is not interested in the - MSE and only plan to estimate the BLUP, for which the correlation - matrix is not required. - - verbose : boolean, optional - A boolean specifying the verbose level. - Default is verbose = True. - - theta0 : double array_like, optional - An array with shape (n_features, ) or (1, ). - The parameters in the autocorrelation model. - If thetaL and thetaU are also specified, theta0 is considered as - the starting point for the maximum likelihood rstimation of the - best set of parameters. - Default assumes isotropic autocorrelation model with theta0 = 1e-1. - - thetaL : double array_like, optional - An array with shape matching theta0's. - Lower bound on the autocorrelation parameters for maximum - likelihood estimation. - Default is None, so that it skips maximum likelihood estimation and - it uses theta0. - - thetaU : double array_like, optional - An array with shape matching theta0's. - Upper bound on the autocorrelation parameters for maximum - likelihood estimation. - Default is None, so that it skips maximum likelihood estimation and - it uses theta0. - - normalize : boolean, optional - Design sites X and observations y are centered and reduced wrt - means and standard deviations estimated from the n_samples - observations provided. - Default is normalize = True so that data is normalized to ease - maximum likelihood estimation. - - nugget : double, optional - Introduce a nugget effect to allow smooth predictions from noisy - data. - Default assumes a nugget close to machine precision for the sake of - robustness (nugget = 10.*machine_epsilon). - - optimizer : string, optional - A string specifying the optimization algorithm to be used - ('fmin_cobyla' is the sole algorithm implemented yet). - Default uses 'fmin_cobyla' algorithm from scipy.optimize. - - random_start : int, optional - The number of times the Maximum Likelihood Estimation should be - performed from a random starting point. - The first MLE always uses the specified starting point (theta0), - the next starting points are picked at random according to an - exponential distribution (log-uniform on [thetaL, thetaU]). - Default does not use random starting point (random_start = 1). + _correlation_types = { + 'absolute_exponential': correlation.absolute_exponential, + 'squared_exponential': correlation.squared_exponential, + 'generalized_exponential': correlation.generalized_exponential, + 'cubic': correlation.cubic, + 'linear': correlation.linear} - Returns - ------- - gp : self - A Gaussian Process model object awaiting data to be fitted to. 
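With the lookup tables above, regr and corr can now be given either as the documented strings or as callables from the regression_models and correlation_models submodules; both spellings configure the same estimator. A short sketch of the two equivalent spellings:

    from scikits.learn.gaussian_process import (GaussianProcess,
                                                correlation_models,
                                                regression_models)

    # Strings are resolved through the class-level lookup tables in _check_params.
    gp_by_name = GaussianProcess(regr='constant', corr='cubic', theta0=1e-1)

    # Passing the callables directly yields the same model.
    gp_by_func = GaussianProcess(regr=regression_models.constant,
                                 corr=correlation_models.cubic, theta0=1e-1)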
- """ + _optimizer_types = [ + 'fmin_cobyla'] + + def __init__(self, regr='constant', corr='squared_exponential', beta0=None, + storage_mode='full', verbose=False, theta0=1e-1, + thetaL=None, thetaU=None, optimizer='fmin_cobyla', + random_start=1, normalize=True, + nugget=10. * MACHINE_EPSILON): self.regr = regr self.corr = corr self.beta0 = beta0 - - # Check storage mode - if storage_mode != 'full' and storage_mode != 'light': - if storage_mode == 'sparse': - raise NotImplementedError("The 'sparse' storage mode is not " \ - + "supported yet. Please contribute!") - else: - raise ValueError("Storage mode should either be 'full' or " \ - + "'light'. Unknown storage mode: " + str(storage_mode)) - else: - self.storage_mode = storage_mode - + self.storage_mode = storage_mode self.verbose = verbose - - # Check correlation parameters - self.theta0 = np.atleast_2d(theta0) - self.thetaL, self.thetaU = thetaL, thetaU - lth = self.theta0.size - - if self.thetaL is not None and self.thetaU is not None: - # Parameters optimization case - self.thetaL = np.atleast_2d(thetaL) - self.thetaU = np.atleast_2d(thetaU) - - if self.thetaL.size != lth or self.thetaU.size != lth: - raise ValueError("theta0, thetaL and thetaU must have the " \ - + "same length") - if np.any(self.thetaL <= 0) or np.any(self.thetaU < self.thetaL): - raise ValueError("The bounds must satisfy O < thetaL <= " \ - + "thetaU") - - elif self.thetaL is None and self.thetaU is None: - # Given parameters case - if np.any(self.theta0 <= 0): - raise ValueError("theta0 must be strictly positive") - - elif self.thetaL is None or self.thetaU is None: - # Error - raise Exception("thetaL and thetaU should either be both or " \ - + "neither specified") - - # Store other parameters + self.theta0 = theta0 + self.thetaL = thetaL + self.thetaU = thetaU self.normalize = normalize self.nugget = nugget self.optimizer = optimizer - self.random_start = int(random_start) + self.random_start = random_start + + # Run input checks + self._check_params() def fit(self, X, y): """ @@ -232,6 +198,9 @@ class GaussianProcess(BaseEstimator): predictions. """ + # Run input checks + self._check_params() + # Force data to 2D numpy.array X = np.atleast_2d(X) y = np.asanyarray(y)[:, np.newaxis] @@ -241,7 +210,7 @@ class GaussianProcess(BaseEstimator): n_samples_y = y.shape[0] if n_samples_X != n_samples_y: - raise Exception("X and y must have the same number of rows!") + raise Exception("X and y must have the same number of rows.") else: n_samples = n_samples_X @@ -269,10 +238,11 @@ class GaussianProcess(BaseEstimator): ll = np.array([-1]) for k in range(n_samples-1): ll = ll[-1] + 1 + range(n_samples - k - 1) - ij[ll] = np.concatenate([[np.repeat(k, n_samples - k - 1, 0)], \ + ij[ll] = np.concatenate([[np.repeat(k, n_samples - k - 1, 0)], [np.arange(k + 1, n_samples).T]]).T D[ll] = X[k] - X[(k + 1):n_samples] - if np.min(np.sum(np.abs(D), 1)) == 0. and self.corr != corriid: + if np.min(np.sum(np.abs(D), 1)) == 0. \ + and self.corr != correlation.pure_nugget: raise Exception("Multiple X are not allowed") # Regression matrix and parameters @@ -283,16 +253,18 @@ class GaussianProcess(BaseEstimator): else: p = 1 if n_samples_F != n_samples: - raise Exception("Number of rows in F and X do not match. Most " \ - + "likely something is going wrong with the " \ + raise Exception("Number of rows in F and X do not match. 
Most " + + "likely something is going wrong with the " + "regression model.") if p > n_samples_F: - raise Exception("Ordinary least squares problem is undetermined " \ - + "n_samples=%d must be greater than the " \ - + "regression model size p=%d!" % (n_samples, p)) - if self.beta0 is not None \ - and (self.beta0.shape[0] != p or self.beta0.ndim > 1): - raise Exception("Shapes of beta0 and F do not match.") + raise Exception(("Ordinary least squares problem is undetermined " + + "n_samples=%d must be greater than the " + + "regression model size p=%d.") % (n_samples, p)) + if self.beta0 is not None: + if self.beta0.shape[0] != p: + import pdb + pdb.set_trace() + raise Exception("Shapes of beta0 and F do not match.") # Set attributes self.X = X @@ -307,37 +279,44 @@ class GaussianProcess(BaseEstimator): if self.thetaL is not None and self.thetaU is not None: # Maximum Likelihood Estimation of the parameters if self.verbose: - print "Performing Maximum Likelihood Estimation of the " \ - + "autocorrelation parameters..." - self.theta, self.reduced_likelihood_function_value, self.par = \ + print("Performing Maximum Likelihood Estimation of the " + + "autocorrelation parameters...") + self.theta, self.reduced_likelihood_function_value, par = \ self.arg_max_reduced_likelihood_function() if np.isinf(self.reduced_likelihood_function_value): - raise Exception("Bad parameter region. " \ + raise Exception("Bad parameter region. " + "Try increasing upper bound") else: # Given parameters if self.verbose: - print "Given autocorrelation parameters. " \ - + "Computing Gaussian Process model parameters..." + print("Given autocorrelation parameters. " + + "Computing Gaussian Process model parameters...") self.theta = self.theta0 - self.reduced_likelihood_function_value, self.par = \ + self.reduced_likelihood_function_value, par = \ self.reduced_likelihood_function() if np.isinf(self.reduced_likelihood_function_value): - raise Exception("Bad point. Try increasing theta0") + raise Exception("Bad point. Try increasing theta0.") + + self.beta = par['beta'] + self.gamma = par['gamma'] + self.sigma2 = par['sigma2'] + self.C = par['C'] + self.Ft = par['Ft'] + self.G = par['G'] if self.storage_mode == 'light': # Delete heavy data (it will be computed again if required) # (it is required only when MSE is wanted in self.predict) if self.verbose: - print "Light storage mode specified. " \ - + "Flushing autocorrelation matrix..." + print("Light storage mode specified. " + + "Flushing autocorrelation matrix...") self.D = None self.ij = None self.F = None - self.par['C'] = None - self.par['Ft'] = None - self.par['G'] = None + self.C = None + self.Ft = None + self.G = None return self @@ -357,10 +336,6 @@ class GaussianProcess(BaseEstimator): Default assumes evalMSE = False and evaluates only the BLUP (mean prediction). - verbose : boolean, optional - A boolean specifying the verbose level. - Default is verbose = True. - batch_size : integer, optional An integer giving the maximum number of points that can be evaluated simulatneously (depending on the available memory). @@ -377,11 +352,8 @@ class GaussianProcess(BaseEstimator): An array with shape (n_eval, ) with the Mean Squared Error at x. """ - # Check itself - if np.any(np.isnan(self.par['beta'])): - raise Exception("Not a valid GaussianProcess. 
" \ - + "Try fitting it again with different parameters " \ - + "theta.") + # Run input checks + self._check_params() # Check design & trial sites X = np.atleast_2d(X) @@ -389,9 +361,9 @@ class GaussianProcess(BaseEstimator): n_samples, n_features = self.X.shape if n_features_X != n_features: - raise ValueError("The number of features in X (X.shape[1] = %d) " \ - % n_features_X + "should match the sample size " \ - + "used for fit() which is %d." % n_features) + raise ValueError(("The number of features in X (X.shape[1] = %d) " + + "should match the sample size used for fit() " + + "which is %d.") % (n_features_X, n_features)) if batch_size is None: # No memory management @@ -417,35 +389,38 @@ class GaussianProcess(BaseEstimator): r = self.corr(self.theta, dx).reshape(n_eval, n_samples) # Scaled predictor - y_ = np.dot(f, self.par['beta']) \ - + np.dot(r, self.par['gamma']) + y_ = np.dot(f, self.beta) \ + + np.dot(r, self.gamma) # Predictor y = (self.y_sc[0] + self.y_sc[1] * y_).ravel() # Mean Squared Error if eval_MSE: - par = self.par - if par['C'] is None: + C = self.C + if C is None: # Light storage mode (need to recompute C, F, Ft and G) if self.verbose: - print "This GaussianProcess used light storage mode " \ - + "at instanciation. Need to recompute " \ - + "autocorrelation matrix..." + print("This GaussianProcess used 'light' storage mode " + + "at instanciation. Need to recompute " + + "autocorrelation matrix...") reduced_likelihood_function_value, par = \ self.reduced_likelihood_function() + self.C = par['C'] + self.Ft = par['Ft'] + self.G = par['G'] - rt = solve_triangular(par['C'], r.T) + rt = solve_triangular(C, r.T) if self.beta0 is None: # Universal Kriging - u = solve_triangular(self.par['G'].T, \ - np.dot(self.par['Ft'].T, rt) - f.T) + u = solve_triangular(self.G.T, + np.dot(self.Ft.T, rt) - f.T) else: # Ordinary Kriging u = np.zeros(y.shape) - MSE = self.par['sigma2'] * (1. - (rt ** 2.).sum(axis=0) \ - + (u ** 2.).sum(axis=0)) + MSE = self.sigma2 * (1. - (rt ** 2.).sum(axis=0) + + (u ** 2.).sum(axis=0)) # Mean Squared Error might be slightly negative depending on # machine precision: force to zero! @@ -465,29 +440,25 @@ class GaussianProcess(BaseEstimator): if eval_MSE: - y, MSE = np.array([]), np.array([]) + y, MSE = np.zeros(n_eval), np.zeros(n_eval) for k in range(n_eval / batch_size): batch_from = k * batch_size batch_to = min([(k + 1) * batch_size + 1, n_eval + 1]) - X_batch = X[batch_from:batch_to][:] - y_batch, MSE_batch = \ - self.predict(X_batch, eval_MSE=eval_MSE, \ - batch_size=None) - y.append(y_batch) - MSE.append(MSE_batch) + y[batch_from:batch_to], MSE[batch_from:batch_to] = \ + self.predict(X[batch_from:batch_to][:], + eval_MSE=eval_MSE, batch_size=None) return y, MSE else: - y = np.array([]) + y = np.zeros(n_eval) for k in range(n_eval / batch_size): batch_from = k * batch_size batch_to = min([(k + 1) * batch_size + 1, n_eval + 1]) - X_batch = X[batch_from:batch_to][:] - y_batch = \ - self.predict(X_batch, eval_MSE=eval_MSE, \ - batch_size=None) + y[batch_from:batch_to] = \ + self.predict(X[batch_from:batch_to][:], + eval_MSE=eval_MSE, batch_size=None) return y @@ -527,8 +498,6 @@ class GaussianProcess(BaseEstimator): par['C'] : Cholesky decomposition of the correlation matrix [R]. par['Ft'] : Solution of the linear equation system : [R] x Ft = F par['G'] : QR decomposition of the matrix Ft. - par['detR'] : Determinant of the correlation matrix raised at power - 1/n_samples. 
""" if theta is None: @@ -558,10 +527,11 @@ class GaussianProcess(BaseEstimator): for k in range(n_samples-1): ll = ll[-1] + 1 + range(n_samples - k - 1) ij[ll] = \ - np.concatenate([[np.repeat(k, n_samples - k - 1, 0)], \ + np.concatenate([[np.repeat(k, n_samples - k - 1, 0)], [np.arange(k + 1, n_samples).T]]).T D[ll] = self.X[k] - self.X[(k + 1):n_samples] - if min(sum(abs(D), 1)) == 0. and self.corr != corriid: + if min(sum(abs(D), 1)) == 0. \ + and self.corr != correlation.pure_nugget: raise Exception("Multiple X are not allowed") F = self.regr(self.X) @@ -596,7 +566,7 @@ class GaussianProcess(BaseEstimator): sv = linalg.svd(F, compute_uv=False) condF = sv[0] / sv[-1] if condF > 1e15: - raise Exception("F is too ill conditioned. Poor combination " \ + raise Exception("F is too ill conditioned. Poor combination " + "of regression model and observations.") else: # Ft is too ill conditioned, get out (try different theta) @@ -690,7 +660,7 @@ class GaussianProcess(BaseEstimator): # Run Cobyla log10_optimal_theta = \ - optimize.fmin_cobyla(minus_reduced_likelihood_function, \ + optimize.fmin_cobyla(minus_reduced_likelihood_function, np.log10(theta0), constraints, iprint=0) optimal_theta = 10. ** log10_optimal_theta @@ -715,15 +685,15 @@ class GaussianProcess(BaseEstimator): else: - raise NotImplementedError("This optimizer ('%s') is not " \ - + "implemented yet. Please contribute!" \ + raise NotImplementedError("This optimizer ('%s') is not " + + "implemented yet. Please contribute!" % self.optimizer) return best_optimal_theta, best_optimal_rlf_value, best_optimal_par def score(self, X_test, y_test): """ - This score function returns the deviations of the Gaussian Process + This score function returns the mean deviation of the Gaussian Process model evaluated onto a test dataset. Parameters @@ -736,9 +706,81 @@ class GaussianProcess(BaseEstimator): Returns ------- - score_values : array_like - The deviations between the prediction and the targets: - y_pred - y_test. + score_value : array_like + The mean of the deviations between the prediction and the targets: + mean(y_pred - y_test). """ - return np.ravel(self.predict(X_test, eval_MSE=False)) - y_test + return (self.predict(X_test, eval_MSE=False) - y_test).mean() + + def _check_params(self): + + # Check regression model + if not callable(self.regr): + if self.regr in self._regression_types: + self.regr = self._regression_types[self.regr] + else: + raise ValueError(("regr should be one of %s or callable, " + + "%s was given.") + % (self._regression_types.keys(), self.regr)) + + # Check regression weights if given (Ordinary Kriging) + if self.beta0 is not None: + self.beta0 = np.atleast_2d(self.beta0) + if self.beta0.shape[1] != 1: + # Force to column vector + self.beta0 = self.beta0.T + + # Check correlation model + if not callable(self.corr): + if self.corr in self._correlation_types: + self.corr = self._correlation_types[self.corr] + else: + raise ValueError(("corr should be one of %s or callable, " + + "%s was given.") + % (self._correlation_types.keys(), self.corr)) + + # Check storage mode + if self.storage_mode != 'full' and self.storage_mode != 'light': + raise ValueError("Storage mode should either be 'full' or " + + "'light', %s was given." 
% self.storage_mode) + + # Check correlation parameters + self.theta0 = np.atleast_2d(self.theta0) + lth = self.theta0.size + + if self.thetaL is not None and self.thetaU is not None: + self.thetaL = np.atleast_2d(self.thetaL) + self.thetaU = np.atleast_2d(self.thetaU) + if self.thetaL.size != lth or self.thetaU.size != lth: + raise ValueError("theta0, thetaL and thetaU must have the " + + "same length.") + if np.any(self.thetaL <= 0) or np.any(self.thetaU < self.thetaL): + raise ValueError("The bounds must satisfy O < thetaL <= " + + "thetaU.") + + elif self.thetaL is None and self.thetaU is None: + if np.any(self.theta0 <= 0): + raise ValueError("theta0 must be strictly positive.") + + elif self.thetaL is None or self.thetaU is None: + raise ValueError("thetaL and thetaU should either be both or " + + "neither specified.") + + # Force verbose type to bool + self.verbose = bool(self.verbose) + + # Force normalize type to bool + self.normalize = bool(self.normalize) + + # Check nugget value + if self.nugget < 0.: + raise ValueError("nugget must be positive or zero.") + + # Check optimizer + if not self.optimizer in self._optimizer_types: + raise ValueError("optimizer should be one of %s" + % self._optimizer_types) + + # Force random_start type to int + self.random_start = int(self.random_start) diff --git a/scikits/learn/gaussian_process/regression.py b/scikits/learn/gaussian_process/regression_models.py similarity index 73% rename from scikits/learn/gaussian_process/regression.py rename to scikits/learn/gaussian_process/regression_models.py index 58bf14948bcd400088ef49b325f4a5e7946de172..68c5a196f77e92a88fcc5e96eac30b289b8cf9d0 100644 --- a/scikits/learn/gaussian_process/regression.py +++ b/scikits/learn/gaussian_process/regression_models.py @@ -1,6 +1,14 @@ #!/usr/bin/python # -*- coding: utf-8 -*- +# Author: Vincent Dubourg <vincent.dubourg@gmail.com> +# (mostly translation, see implementation details) +# License: BSD style + +""" +The built-in regression models submodule for the gaussian_process module. +""" + ################ # Dependencies # ################ @@ -13,11 +21,11 @@ import numpy as np ############################ -def regpoly0(x): +def constant(x): """ Zero order polynomial (constant, p = 1) regression model. - regpoly0 : x --> f(x) = 1 + x --> f(x) = 1 Parameters ---------- @@ -39,11 +47,11 @@ def regpoly0(x): return f -def regpoly1(x): +def linear(x): """ - First order polynomial (hyperplane, p = n) regression model. + First order polynomial (linear, p = n+1) regression model. - regpoly1 : x --> f(x) = [ x_1, ..., x_n ].T + x --> f(x) = [ 1, x_1, ..., x_n ].T Parameters ---------- @@ -65,12 +73,12 @@ def regpoly1(x): return f -def regpoly2(x): +def quadratic(x): """ - Second order polynomial (hyperparaboloid, p = n*(n-1)/2) regression model. + Second order polynomial (quadratic, p = n*(n-1)/2+n+1) regression model. 
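To make the basis sizes p above concrete, here is a minimal sketch of what the renamed regression models return for a single two-dimensional site; the exact ordering of the quadratic cross terms follows the DACE convention and is an assumption here.

    import numpy as np
    from scikits.learn.gaussian_process import regression_models

    x = np.atleast_2d([1., 2.])          # one sample, n = 2 features

    regression_models.constant(x)        # expected [[1.]]          -> p = 1
    regression_models.linear(x)          # expected [[1., 1., 2.]]  -> p = n + 1 = 3
    regression_models.quadratic(x)       # expected [[1., 1., 2., 1., 2., 4.]]
                                         # -> p = n*(n-1)/2 + n + 1 = 6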
- regpoly2 : x --> f(x) = [ x_i*x_j, (i,j) = 1,...,n ].T - i > j + x --> f(x) = [ 1, { x_i, i = 1,...,n }, { x_i * x_j, (i,j) = 1,...,n } ].T + i > j Parameters ---------- diff --git a/scikits/learn/gaussian_process/tests/__init__.py b/scikits/learn/gaussian_process/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scikits/learn/gaussian_process/tests/test_gaussian_process.py b/scikits/learn/gaussian_process/tests/test_gaussian_process.py new file mode 100644 index 0000000000000000000000000000000000000000..2f71070fa999dbd11f61ee9a5fcf012265b58d10 --- /dev/null +++ b/scikits/learn/gaussian_process/tests/test_gaussian_process.py @@ -0,0 +1,87 @@ +""" +Testing for Gaussian Process module (scikits.learn.gaussian_process) +""" + +# Author: Vincent Dubourg <vincent.dubourg@gmail.com> +# License: BSD style + +import numpy as np + +from .. import GaussianProcess +from .. import regression_models as regression +from .. import correlation_models as correlation + + +def test_1d(regr=regression.constant, corr=correlation.squared_exponential, + random_start=10, beta0=None): + """ + MLE estimation of a one-dimensional Gaussian Process model. + Check random start optimization. + + Test the interpolating property. + """ + + f = lambda x: x * np.sin(x) + X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T + y = f(X).ravel() + gp = GaussianProcess(regr=regr, corr=corr, beta0=beta0, + theta0=1e-2, thetaL=1e-4, thetaU=1e-1, + random_start=random_start, verbose=False).fit(X, y) + y_pred, MSE = gp.predict(X, eval_MSE=True) + + assert np.allclose(y_pred, y) and np.allclose(MSE, 0.) + + +def test_2d(regr=regression.constant, corr=correlation.squared_exponential, + random_start=10, beta0=None): + """ + MLE estimation of a two-dimensional Gaussian Process model accounting for + anisotropy. Check random start optimization. + + Test the interpolating property. + """ + + b, kappa, e = 5., .5, .1 + g = lambda x: b - x[:, 1] - kappa * (x[:, 0] - e) ** 2. + X = np.array([[-4.61611719, -6.00099547], + [4.10469096, 5.32782448], + [0.00000000, -0.50000000], + [-6.17289014, -4.6984743], + [1.3109306, -6.93271427], + [-5.03823144, 3.10584743], + [-2.87600388, 6.74310541], + [5.21301203, 4.26386883]]) + y = g(X).ravel() + gp = GaussianProcess(regr=regr, corr=corr, beta0=beta0, + theta0=[1e-2] * 2, thetaL=[1e-4] * 2, + thetaU=[1e-1] * 2, + random_start=random_start, verbose=False) + gp.fit(X, y) + y_pred, MSE = gp.predict(X, eval_MSE=True) + + assert np.allclose(y_pred, y) and np.allclose(MSE, 0.) + + +def test_more_builtin_correlation_models(random_start=1): + """ + Repeat test_1d and test_2d for several built-in correlation + models specified as strings. + """ + all_corr = ['absolute_exponential', 'squared_exponential', 'cubic', + 'linear'] + + for corr in all_corr: + test_1d(regr='constant', corr=corr, random_start=random_start) + test_2d(regr='constant', corr=corr, random_start=random_start) + + +def test_ordinary_kriging(): + """ + Repeat test_1d and test_2d with given regression weights (beta0) for + different regression models (Ordinary Kriging). 
+ """ + + test_1d(regr='linear', beta0=[0., 0.5]) + test_1d(regr='quadratic', beta0=[0., 0.5, 0.5]) + test_2d(regr='linear', beta0=[0., 0.5, 0.5]) + test_2d(regr='quadratic', beta0=[0., 0.5, 0.5, 0.5, 0.5, 0.5]) diff --git a/scikits/learn/tests/test_gaussian_process.py b/scikits/learn/tests/test_gaussian_process.py deleted file mode 100644 index 99e81556979a7f219f34b71bf8a967164754a385..0000000000000000000000000000000000000000 --- a/scikits/learn/tests/test_gaussian_process.py +++ /dev/null @@ -1,27 +0,0 @@ -""" -Testing for Gaussian Process module (scikits.learn.gaussian_process) -""" - -import numpy as np -from numpy.testing import assert_array_equal, assert_array_almost_equal, \ - assert_almost_equal, assert_raises, assert_ - -from ..gaussian_process import GaussianProcess - - -def test_regression_1d_x_sinx(): - """ - MLE estimation of a Gaussian Process model with a squared exponential - correlation model (correxp2). Check random start optimization. - - Test the interpolating property. - """ - - f = lambda x: x * np.sin(x) - X = np.array([1., 3., 5., 6., 7., 8.]) - y = f(X) - gp = GaussianProcess(theta0=1e-2, thetaL=1e-4, thetaU=1e-1, \ - random_start=10, verbose=False).fit(X, y) - y_pred, MSE = gp.predict(X, eval_MSE=True) - - assert (np.all(np.abs((y_pred - y) / y) < 1e-6) and np.all(MSE < 1e-6))
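Finally, a hedged end-to-end sketch of the Ordinary Kriging path exercised by test_ordinary_kriging: when beta0 is given, the regression weights are fixed and only the autocorrelation parameters are estimated by maximum likelihood. The values below mirror the 1D test case.

    import numpy as np
    from scikits.learn.gaussian_process import GaussianProcess

    f = lambda x: x * np.sin(x)
    X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T
    y = f(X).ravel()

    # Ordinary Kriging: the weights of the 'linear' trend (p = 2 in 1D) are
    # fixed through beta0 instead of being estimated from the data.
    gp = GaussianProcess(regr='linear', beta0=[0., 0.5],
                         theta0=1e-2, thetaL=1e-4, thetaU=1e-1,
                         random_start=10).fit(X, y)

    x = np.atleast_2d(np.linspace(0, 10, 1000)).T
    y_pred, MSE = gp.predict(x, eval_MSE=True)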