diff --git a/examples/gaussian_process/plot_gp_diabetes_dataset.py b/examples/gaussian_process/plot_gp_diabetes_dataset.py
index dafc0867f258fb02a3efda1b2d74d299197ecffc..219915c6c886cf0b8145a8a7a89a361a2c811e8a 100644
--- a/examples/gaussian_process/plot_gp_diabetes_dataset.py
+++ b/examples/gaussian_process/plot_gp_diabetes_dataset.py
@@ -2,69 +2,52 @@
 # -*- coding: utf-8 -*-
 """
-=========================================================================
+========================================================================
 Gaussian Processes regression: goodness-of-fit on the 'diabetes' dataset
-=========================================================================
+========================================================================

 This example consists in fitting a Gaussian Process model onto the diabetes
 dataset.

-WARNING: This is quite time consuming (about 2 minutes overall runtime).
-
-The corelation parameters are given in order to maximize the generalization
-capacity of the model. We assumed an anisotropic squared exponential
-correlation model with a constant regression model. We also used a
-nugget = 1e-2 in order to account for the (strong) noise in the targets.
+The correlation parameters are determined by means of maximum likelihood
+estimation (MLE). An anisotropic absolute exponential correlation model with a
+constant regression model is assumed. We also use a nugget = 1e-2 in order to
+account for the (strong) noise in the targets.

-The figure is a goodness-of-fit plot obtained using leave-one-out predictions
-of the Gaussian Process model. Based on these predictions, we compute an
-explained variance error (Q2).
+We then compute a cross-validation estimate of the coefficient of
+determination (R2) without re-performing MLE, using the set of correlation
+parameters found on the whole dataset.
""" +print __doc__ # Author: Vincent Dubourg <vincent.dubourg@gmail.com> # License: BSD style -from scikits.learn import datasets, cross_val, metrics +import numpy as np +from scikits.learn import datasets from scikits.learn.gaussian_process import GaussianProcess -from matplotlib import pyplot as pl - -# Print the docstring -print __doc__ +from scikits.learn.cross_val import cross_val_score, KFold +from scikits.learn.metrics import r2_score # Load the dataset from scikits' data sets diabetes = datasets.load_diabetes() -X, y = diabetes['data'], diabetes['target'] +X, y = diabetes.data, diabetes.target # Instanciate a GP model gp = GaussianProcess(regr='constant', corr='absolute_exponential', theta0=[1e-4] * 10, thetaL=[1e-12] * 10, - thetaU=[1e-2] * 10, nugget=1e-2, optimizer='Welch', - verbose=False) + thetaU=[1e-2] * 10, nugget=1e-2, optimizer='Welch') -# Fit the GP model to the data +# Fit the GP model to the data performing maximum likelihood estimation gp.fit(X, y) -gp.theta0 = gp.theta -gp.thetaL = None -gp.thetaU = None -gp.verbose = False - -# Estimate the leave-one-out predictions using the cross_val module -n_jobs = 2 # the distributing capacity available on the machine -y_pred = y + cross_val.cross_val_score(gp, X, y=y, - cv=cross_val.LeaveOneOut(y.size), - n_jobs=n_jobs, - ).ravel() -# Compute the empirical explained variance -Q2 = metrics.explained_variance_score(y, y_pred) +# Deactivate maximum likelihood estimation for the cross-validation loop +gp.theta0 = gp.theta # Given correlation parameter = MLE +gp.thetaL, gp.thetaU = None, None # None bounds deactivate MLE -# Goodness-of-fit plot -pl.figure() -pl.title('Goodness-of-fit plot (Q2 = %1.2e)' % Q2) -pl.plot(y, y_pred, 'r.', label='Leave-one-out') -pl.plot(y, gp.predict(X), 'k.', label='Whole dataset (nugget=1e-2)') -pl.plot([y.min(), y.max()], [y.min(), y.max()], 'k--') -pl.xlabel('Observations') -pl.ylabel('Predictions') -pl.legend(loc='upper left') -pl.axis('tight') -pl.show() +# Perform a cross-validation estimate of the coefficient of determination using +# the cross_val module using all CPUs available on the machine +K = 20 # folds +R2 = cross_val_score(gp, X, y=y, cv=KFold(y.size, K), n_jobs=-1).mean() +print("The %d-Folds estimate of the coefficient of determination is R2 = %s" + % (K, R2)) diff --git a/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py b/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py index 5c1aa06bf001e88c5e7e123143e4b3aae817414d..329f2cf9322464ca16d37451b1d376e684ecf3c2 100644 --- a/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py +++ b/examples/gaussian_process/plot_gp_probabilistic_classification_after_regression.py @@ -14,6 +14,7 @@ respect to the remaining uncertainty in the prediction. The red and blue lines corresponds to the 95% confidence interval on the prediction of the zero level set. 
""" +print __doc__ # Author: Vincent Dubourg <vincent.dubourg@gmail.com> # License: BSD style @@ -24,22 +25,19 @@ from scikits.learn.gaussian_process import GaussianProcess from matplotlib import pyplot as pl from matplotlib import cm -# Print the docstring -print __doc__ - # Standard normal distribution functions -Grv = stats.distributions.norm() -phi = lambda x: Grv.pdf(x) -PHI = lambda x: Grv.cdf(x) -PHIinv = lambda x: Grv.ppf(x) +phi = stats.distributions.norm().pdf +PHI = stats.distributions.norm().cdf +PHIinv = stats.distributions.norm().ppf # A few constants lim = 8 -b, kappa, e = 5, .5, .1 -# The function to predict (classification will then consist in predicting -# wheter g(x) <= 0 or not) -g = lambda x: b - x[:, 1] - kappa * (x[:, 0] - e) ** 2. + +def g(x): + """The function to predict (classification will then consist in predicting + whether g(x) <= 0 or not)""" + return 5. - x[:, 1] - .5 * x[:, 0] ** 2. # Design of experiments X = np.array([[-4.61611719, -6.00099547], diff --git a/examples/gaussian_process/plot_gp_regression.py b/examples/gaussian_process/plot_gp_regression.py index 90bf120ca998fba18dc5fc5def2d9388453940b6..66c27fe493080be1b48f0e714ebc534fbb6e7019 100644 --- a/examples/gaussian_process/plot_gp_regression.py +++ b/examples/gaussian_process/plot_gp_regression.py @@ -2,9 +2,9 @@ # -*- coding: utf-8 -*- """ -================================================================= +========================================================= Gaussian Processes regression: basic introductory example -================================================================= +========================================================= A simple one-dimensional regression exercise with a cubic correlation model whose parameters are estimated using the maximum likelihood principle. @@ -13,6 +13,7 @@ The figure illustrates the interpolating property of the Gaussian Process model as well as its probabilistic nature in the form of a pointwise 95% confidence interval. """ +print __doc__ # Author: Vincent Dubourg <vincent.dubourg@gmail.com> # License: BSD style @@ -21,11 +22,10 @@ import numpy as np from scikits.learn.gaussian_process import GaussianProcess from matplotlib import pyplot as pl -# Print the docstring -print __doc__ -# The function to predict -f = lambda x: x * np.sin(x) +def f(x): + """The function to predict.""" + return x * np.sin(x) # The design of experiments X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T diff --git a/scikits/learn/gaussian_process/correlation_models.py b/scikits/learn/gaussian_process/correlation_models.py index 6ad2807b21bde73038673c24ffd87506e4dd05ed..52fc008935cc90230ef69a0a1177c6fe5792880f 100644 --- a/scikits/learn/gaussian_process/correlation_models.py +++ b/scikits/learn/gaussian_process/correlation_models.py @@ -39,23 +39,20 @@ def absolute_exponential(theta, d): An array with shape (n_eval, ) containing the values of the autocorrelation model. 
""" - theta = np.asanyarray(theta, dtype=np.float) - d = np.asanyarray(d, dtype=np.float) + d = np.abs(np.asanyarray(d, dtype=np.float)) if d.ndim > 1: n_features = d.shape[1] else: n_features = 1 + if theta.size == 1: - theta = np.repeat(theta, n_features) + return np.exp(- theta[0] * np.sum(d, axis=1)) elif theta.size != n_features: - raise ValueError("Length of theta must be 1 or " + str(n_features)) - - td = - theta.reshape(1, n_features) * abs(d) - r = np.exp(np.sum(td, 1)) - - return r + raise ValueError("Length of theta must be 1 or %s" % n_features) + else: + return np.exp(- np.sum(theta.reshape(1, n_features) * d, axis=1)) def squared_exponential(theta, d): @@ -92,15 +89,13 @@ def squared_exponential(theta, d): n_features = d.shape[1] else: n_features = 1 + if theta.size == 1: - theta = np.repeat(theta, n_features) + return np.exp(- theta[0] * np.sum(d**2, axis=1)) elif theta.size != n_features: - raise Exception("Length of theta must be 1 or " + str(n_features)) - - td = - theta.reshape(1, n_features) * d ** 2 - r = np.exp(np.sum(td, 1)) - - return r + raise ValueError("Length of theta must be 1 or %s" % n_features) + else: + return np.exp(- np.sum(theta.reshape(1, n_features) * d**2, axis=1)) def generalized_exponential(theta, d): @@ -138,16 +133,17 @@ def generalized_exponential(theta, d): n_features = d.shape[1] else: n_features = 1 + lth = theta.size if n_features > 1 and lth == 2: theta = np.hstack([np.repeat(theta[0], n_features), theta[1]]) elif lth != n_features + 1: - raise Exception("Length of theta must be 2 or " + str(n_features + 1)) + raise Exception("Length of theta must be 2 or %s" % (n_features + 1)) else: theta = theta.reshape(1, lth) - td = - theta[:, 0:-1].reshape(1, n_features) * abs(d) ** theta[:, -1] - r = np.exp(np.sum(td, 1)) + td = theta[:, 0:-1].reshape(1, n_features) * np.abs(d) ** theta[:, -1] + r = np.exp(- np.sum(td, 1)) return r @@ -184,9 +180,7 @@ def pure_nugget(theta, d): n_eval = d.shape[0] r = np.zeros(n_eval) - # The ones on the diagonal of the correlation matrix are enforced within - # the KrigingModel instanciation to allow multiple design sites in this - # ordinary least squares context. + r[np.all(d == 0., axis=1)] = 1. return r @@ -225,15 +219,15 @@ def cubic(theta, d): n_features = d.shape[1] else: n_features = 1 + lth = theta.size if lth == 1: - theta = np.repeat(theta, n_features)[np.newaxis][:] + td = np.abs(d) * theta elif lth != n_features: raise Exception("Length of theta must be 1 or " + str(n_features)) else: - theta = theta.reshape(1, n_features) + td = np.abs(d) * theta.reshape(1, n_features) - td = abs(d) * theta td[td > 1.] = 1. ss = 1. - td ** 2. * (3. - 2. * td) r = np.prod(ss, 1) @@ -275,15 +269,15 @@ def linear(theta, d): n_features = d.shape[1] else: n_features = 1 + lth = theta.size - if lth == 1: - theta = np.repeat(theta, n_features)[np.newaxis][:] + if lth == 1: + td = np.abs(d) * theta elif lth != n_features: - raise Exception("Length of theta must be 1 or " + str(n_features)) + raise Exception("Length of theta must be 1 or %s" % n_features) else: - theta = theta.reshape(1, n_features) + td = np.abs(d) * theta.reshape(1, n_features) - td = abs(d) * theta td[td > 1.] = 1. ss = 1. 
     r = np.prod(ss, 1)

diff --git a/scikits/learn/gaussian_process/gaussian_process.py b/scikits/learn/gaussian_process/gaussian_process.py
index 0bce6699a6e97c61f136a64c8d1dc92ab1e1f232..3310fef044954070128a9c871d330fab060544ac 100644
--- a/scikits/learn/gaussian_process/gaussian_process.py
+++ b/scikits/learn/gaussian_process/gaussian_process.py
@@ -7,9 +7,11 @@ import numpy as np
 from scipy import linalg, optimize, rand

-from ..base import BaseEstimator
+from ..base import BaseEstimator, RegressorMixin
 from . import regression_models as regression
 from . import correlation_models as correlation
+from ..cross_val import LeaveOneOut
+from ..externals.joblib import Parallel, delayed

 MACHINE_EPSILON = np.finfo(np.double).eps
 if hasattr(linalg, 'solve_triangular'):  # only in scipy since 0.9
@@ -20,7 +22,79 @@ else:
         return linalg.solve(x, y)


-class GaussianProcess(BaseEstimator):
+def compute_componentwise_l1_cross_distances(X):
+    """
+    Computes the nonzero componentwise L1 cross-distances between the vectors
+    in X.
+
+    Parameters
+    ----------
+
+    X: array_like
+        An array with shape (n_samples, n_features)
+
+    Returns
+    -------
+
+    D: array with shape (n_samples * (n_samples - 1) / 2, n_features)
+        The array of componentwise L1 cross-distances.
+
+    ij: arrays with shape (n_samples * (n_samples - 1) / 2, 2)
+        The indices i and j of the vectors in X associated to the cross-
+        distances in D: D[k] = np.abs(X[ij[k, 0]] - X[ij[k, 1]]).
+    """
+    X = np.atleast_2d(X)
+    n_samples, n_features = X.shape
+    n_nonzero_cross_dist = n_samples * (n_samples - 1) / 2
+    ij = np.zeros([n_nonzero_cross_dist, 2])
+    D = np.zeros([n_nonzero_cross_dist, n_features])
+    ll = np.array([-1])
+    for k in range(n_samples - 1):
+        ll = ll[-1] + 1 + range(n_samples - k - 1)
+        ij[ll] = np.concatenate([[np.repeat(k, n_samples - k - 1, 0)],
+                                 [np.array(range(k + 1, n_samples)).T]]).T
+        D[ll] = np.abs(X[k] - X[(k + 1):n_samples])
+
+    return D, ij.astype(np.int)
+
+
+def compute_componentwise_l1_pairwise_distances(X, Y):
+    """
+    Computes the componentwise L1 pairwise-distances between the vectors
+    in X and Y.
+
+    Parameters
+    ----------
+
+    X: array_like
+        An array with shape (n_samples_X, n_features)
+
+    Y: array_like
+        An array with shape (n_samples_Y, n_features).
+
+    Returns
+    -------
+
+    D: array with shape (n_samples_X * n_samples_Y, n_features)
+        The array of componentwise L1 pairwise-distances.
+    """
+    X, Y = np.atleast_2d(X), np.atleast_2d(Y)
+    n_samples_X, n_features_X = X.shape
+    n_samples_Y, n_features_Y = Y.shape
+    if n_features_X != n_features_Y:
+        raise Exception("X and Y should have the same number of features!")
+    else:
+        n_features = n_features_X
+    D = np.zeros([n_samples_X * n_samples_Y, n_features])
+    kk = np.arange(n_samples_Y).astype(np.int)
+    for k in range(n_samples_X):
+        D[kk] = X[k] - Y
+        kk = kk + n_samples_Y
+
+    return D
+
+
+class GaussianProcess(BaseEstimator, RegressorMixin):
     """
     The Gaussian Process model class.

@@ -85,7 +159,7 @@ class GaussianProcess(BaseEstimator):
         it uses theta0.

     normalize : boolean, optional
-        Design sites X and observations y are centered and reduced wrt
+        Input X and observations y are centered and reduced wrt
         means and standard deviations estimated from the n_samples
         observations provided.
         Default is normalize = True so that data is normalized to ease
@@ -187,8 +261,8 @@ class GaussianProcess(BaseEstimator):
         Parameters
         ----------
         X : double array_like
-            An array with shape (n_samples, n_features) with the design sites
-            at which observations were made.
+            An array with shape (n_samples, n_features) with the input at which
+            observations were made.

         y : double array_like
             An array with shape (n_features, ) with the observations of the
@@ -206,7 +280,7 @@
         # Force data to 2D numpy.array
         X = np.atleast_2d(X)
-        y = np.asanyarray(y)[:, np.newaxis]
+        y = np.asanyarray(y).ravel()[:, np.newaxis]

         # Check shapes of DOE & observations
         n_samples_X, n_features = X.shape
@@ -219,32 +293,24 @@
         # Normalize data or don't
         if self.normalize:
-            mean_X = np.mean(X, axis=0)
-            std_X = np.std(X, axis=0)
-            mean_y = np.mean(y, axis=0)
-            std_y = np.std(y, axis=0)
-            std_X[std_X == 0.] = 1.
-            std_y[std_y == 0.] = 1.
+            X_mean = np.mean(X, axis=0)
+            X_std = np.std(X, axis=0)
+            y_mean = np.mean(y, axis=0)
+            y_std = np.std(y, axis=0)
+            X_std[X_std == 0.] = 1.
+            y_std[y_std == 0.] = 1.
+            # center and scale X and y
+            X = (X - X_mean) / X_std
+            y = (y - y_mean) / y_std
         else:
-            mean_X = np.array([0.])
-            std_X = np.array([1.])
-            mean_y = np.array([0.])
-            std_y = np.array([1.])
-
-        X = (X - mean_X) / std_X
-        y = (y - mean_y) / std_y
+            X_mean = np.zeros(1)
+            X_std = np.ones(1)
+            y_mean = np.zeros(1)
+            y_std = np.ones(1)

         # Calculate matrix of distances D between samples
-        mzmax = n_samples * (n_samples - 1) / 2
-        ij = np.zeros([mzmax, 2])
-        D = np.zeros([mzmax, n_features])
-        ll = np.array([-1])
-        for k in range(n_samples-1):
-            ll = ll[-1] + 1 + range(n_samples - k - 1)
-            ij[ll] = np.concatenate([[np.repeat(k, n_samples - k - 1, 0)],
-                                     [np.arange(k + 1, n_samples).T]]).T
-            D[ll] = X[k] - X[(k + 1):n_samples]
-        if np.min(np.sum(np.abs(D), 1)) == 0. \
+        D, ij = compute_componentwise_l1_cross_distances(X)
+        if np.min(np.sum(np.abs(D), axis=1)) == 0. \
                 and self.corr != correlation.pure_nugget:
             raise Exception("Multiple X are not allowed")

@@ -273,8 +339,8 @@
         self.D = D
         self.ij = ij
         self.F = F
-        self.X_sc = np.concatenate([[mean_X], [std_X]])
-        self.y_sc = np.concatenate([[mean_y], [std_y]])
+        self.X_mean, self.X_std = X_mean, X_std
+        self.y_mean, self.y_std = y_mean, y_std

         # Determine Gaussian Process model parameters
         if self.thetaL is not None and self.thetaU is not None:
@@ -356,7 +422,7 @@
         # Run input checks
         self._check_params()

-        # Check design & trial sites
+        # Check input shapes
         X = np.atleast_2d(X)
         n_eval, n_features_X = X.shape
         n_samples, n_features = self.X.shape
@@ -370,31 +436,26 @@
             # No memory management
             # (evaluates all given points in a single batch run)

-            # Normalize trial sites
-            X = (X - self.X_sc[0][:]) / self.X_sc[1][:]
+            # Normalize input
+            X = (X - self.X_mean) / self.X_std

             # Initialize output
             y = np.zeros(n_eval)
             if eval_MSE:
                 MSE = np.zeros(n_eval)

-            # Get distances to design sites
-            dx = np.zeros([n_eval * n_samples, n_features])
-            kk = np.arange(n_samples).astype(int)
-            for k in range(n_eval):
-                dx[kk] = X[k] - self.X
-                kk = kk + n_samples
+            # Get pairwise componentwise L1-distances to the input training set
+            dx = compute_componentwise_l1_pairwise_distances(X, self.X)

             # Get regression function and correlation
             f = self.regr(X)
             r = self.corr(self.theta, dx).reshape(n_eval, n_samples)

             # Scaled predictor
-            y_ = np.dot(f, self.beta) \
-                + np.dot(r, self.gamma)
+            y_ = np.dot(f, self.beta) + np.dot(r, self.gamma)

             # Predictor
-            y = (self.y_sc[0] + self.y_sc[1] * y_).ravel()
+            y = (self.y_mean + self.y_std * y_).ravel()

             # Mean Squared Error
             if eval_MSE:
@@ -412,7 +473,7 @@ class GaussianProcess(BaseEstimator):
         self.G = par['G']

         rt = solve_triangular(C, r.T, lower=True)
-
+
         if self.beta0 is None:
             # Universal Kriging
             u = solve_triangular(self.G.T,
@@ -518,21 +579,8 @@ class GaussianProcess(BaseEstimator):

         if D is None:
             # Light storage mode (need to recompute D, ij and F)
-            if self.X.ndim > 1:
-                n_features = self.X.shape[1]
-            else:
-                n_features = 1
-            mzmax = n_samples * (n_samples - 1) / 2
-            ij = np.zeros([mzmax, n_features])
-            D = np.zeros([mzmax, n_features])
-            ll = np.array([-1])
-            for k in range(n_samples-1):
-                ll = ll[-1] + 1 + range(n_samples - k - 1)
-                ij[ll] = \
-                    np.concatenate([[np.repeat(k, n_samples - k - 1, 0)],
-                                    [np.arange(k + 1, n_samples).T]]).T
-                D[ll] = self.X[k] - self.X[(k + 1):n_samples]
-            if min(sum(abs(D), 1)) == 0. \
+            D, ij = compute_componentwise_l1_cross_distances(self.X)
+            if np.min(np.sum(np.abs(D), axis=1)) == 0. \
                     and self.corr != correlation.pure_nugget:
                 raise Exception("Multiple X are not allowed")
             F = self.regr(self.X)
@@ -540,8 +588,8 @@ class GaussianProcess(BaseEstimator):
         # Set up R
         r = self.corr(theta, D)
         R = np.eye(n_samples) * (1. + self.nugget)
-        R[ij.astype(int)[:, 0], ij.astype(int)[:, 1]] = r
-        R[ij.astype(int)[:, 1], ij.astype(int)[:, 0]] = r
+        R[ij[:, 0], ij[:, 1]] = r
+        R[ij[:, 1], ij[:, 0]] = r

         # Cholesky decomposition of R
         try:
@@ -590,7 +638,7 @@ class GaussianProcess(BaseEstimator):

         # Compute/Organize output
         reduced_likelihood_function_value = - sigma2.sum() * detR
-        par['sigma2'] = sigma2 * self.y_sc[1] ** 2.
+        par['sigma2'] = sigma2 * self.y_std ** 2.
         par['beta'] = beta
         par['gamma'] = solve_triangular(C.T, rho)
         par['C'] = C
@@ -641,8 +689,9 @@ class GaussianProcess(BaseEstimator):

         if self.optimizer == 'fmin_cobyla':

-            minus_reduced_likelihood_function = lambda log10t: \
-                - self.reduced_likelihood_function(theta=10. ** log10t)[0]
+            def minus_reduced_likelihood_function(log10t):
+                return - self.reduced_likelihood_function(theta=10.
+                                                          ** log10t)[0]

             constraints = []
             for i in range(self.theta0.size):
@@ -687,7 +736,7 @@ class GaussianProcess(BaseEstimator):
                     if self.verbose and self.random_start > 1:
                         if (20 * k) / self.random_start > percent_completed:
                             percent_completed = (20 * k) / self.random_start
-                            print str(5 * percent_completed) + "% completed"
+                            print "%s%% completed" % (5 * percent_completed)

             optimal_rlf_value = best_optimal_rlf_value
             optimal_par = best_optimal_par
@@ -723,11 +772,14 @@ class GaussianProcess(BaseEstimator):
                 self.theta0 = np.atleast_2d(theta_iso)
                 self.thetaL = np.atleast_2d(thetaL[0, i])
                 self.thetaU = np.atleast_2d(thetaU[0, i])
-                self.corr = lambda t, d: \
-                    corr(np.atleast_2d(np.hstack([optimal_theta[0][0:i], t[0], optimal_theta[0][(i + 1)::]])), d)
+
+                def corr_cut(t, d):
+                    return corr(np.atleast_2d(np.hstack([optimal_theta[0][0:i], t[0], optimal_theta[0][(i + 1)::]])), d)
+
+                self.corr = corr_cut

                 optimal_theta[0, i], optimal_rlf_value, optimal_par = \
                     self.arg_max_reduced_likelihood_function()
@@ -745,28 +797,6 @@ class GaussianProcess(BaseEstimator):

         return optimal_theta, optimal_rlf_value, optimal_par

-    def score(self, X_test, y_test):
-        """
-        This score function returns the mean deviation of the Gaussian Process
-        model evaluated onto a test dataset.
-
-        Parameters
-        ----------
-        X_test : array_like
-            The feature test dataset with shape (n_tests, n_features).
-
-        y_test : array_like
-            The target test dataset (n_tests, ).
-
-        Returns
-        -------
-        score_value : array_like
-            The mean of the deviations between the prediction and the targets:
-            mean(y_pred - y_test).
-        """
-
-        return (self.predict(X_test, eval_MSE=False) - y_test).mean()
-
     def _check_params(self):

         # Check regression model
diff --git a/scikits/learn/gaussian_process/regression_models.py b/scikits/learn/gaussian_process/regression_models.py
index 4ba931b8e3a59b66cab678a647da42df41d22c9f..ea0eda50677b04a39abd65424aa105e30a11826b 100644
--- a/scikits/learn/gaussian_process/regression_models.py
+++ b/scikits/learn/gaussian_process/regression_models.py
@@ -31,11 +31,9 @@ def constant(x):
         An array with shape (n_eval, p) with the values of the regression
         model.
     """
-    x = np.asanyarray(x, dtype=np.float)
    n_eval = x.shape[0]
     f = np.ones([n_eval, 1])
-
     return f


@@ -57,11 +55,9 @@ def linear(x):
         An array with shape (n_eval, p) with the values of the regression
         model.
     """
-    x = np.asanyarray(x, dtype=np.float)
     n_eval = x.shape[0]
     f = np.hstack([np.ones([n_eval, 1]), x])
-
     return f


@@ -88,7 +84,7 @@ def quadratic(x):
     x = np.asanyarray(x, dtype=np.float)
     n_eval, n_features = x.shape
     f = np.hstack([np.ones([n_eval, 1]), x])
-    for k in range(n_features):
+    for k in range(n_features):
         f = np.hstack([f, x[:, k, np.newaxis] * x[:, k:]])

     return f
diff --git a/scikits/learn/gaussian_process/tests/test_gaussian_process.py b/scikits/learn/gaussian_process/tests/test_gaussian_process.py
index 2f71070fa999dbd11f61ee9a5fcf012265b58d10..c285022769e272960388abb71e34c53cf898cf16 100644
--- a/scikits/learn/gaussian_process/tests/test_gaussian_process.py
+++ b/scikits/learn/gaussian_process/tests/test_gaussian_process.py
@@ -20,7 +20,6 @@ def test_1d(regr=regression.constant, corr=correlation.squared_exponential,
     Test the interpolating property.
     """
-
     f = lambda x: x * np.sin(x)
     X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T
     y = f(X).ravel()
@@ -40,7 +39,6 @@ def test_2d(regr=regression.constant, corr=correlation.squared_exponential,
     Test the interpolating property.
     """
-
     b, kappa, e = 5., .5, .1
     g = lambda x: b - x[:, 1] - kappa * (x[:, 0] - e) ** 2.
     X = np.array([[-4.61611719, -6.00099547],
@@ -80,7 +78,6 @@ def test_ordinary_kriging():
     Repeat test_1d and test_2d with given regression weights (beta0) for
     different regression models (Ordinary Kriging).
     """
-
     test_1d(regr='linear', beta0=[0., 0.5])
     test_1d(regr='quadratic', beta0=[0., 0.5, 0.5])
     test_2d(regr='linear', beta0=[0., 0.5, 0.5])
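
A note on the new distance helper (not part of the patch): the behaviour of
compute_componentwise_l1_cross_distances introduced above is easier to see on a
toy input. The sketch below re-implements the same pair enumeration with a
plain double loop over all pairs i < j; the array values are illustrative only.

import numpy as np

# Toy input with 3 samples and 2 features.
X = np.array([[0., 0.],
              [1., 2.],
              [3., 5.]])

D = []   # componentwise |X[i] - X[j]| for every pair i < j
ij = []  # the corresponding (i, j) index pairs
n_samples = X.shape[0]
for i in range(n_samples - 1):
    for j in range(i + 1, n_samples):
        D.append(np.abs(X[i] - X[j]))
        ij.append((i, j))
D, ij = np.array(D), np.array(ij)

# D  == [[1., 2.], [3., 5.], [2., 3.]]  -> one row per pair, one column per feature
# ij == [[0, 1], [0, 2], [1, 2]]        -> D[k] = np.abs(X[ij[k, 0]] - X[ij[k, 1]])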
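
Similarly, here is a minimal sketch (again not part of the patch) of the
anisotropic absolute exponential correlation that the rewritten
absolute_exponential computes from such componentwise distances,
r(d) = exp(- sum_k theta_k * |d_k|). The theta and d values below are made up
for illustration.

import numpy as np

theta = np.array([0.5, 2.0])    # one correlation parameter per feature
d = np.array([[1., 2.],         # componentwise distances, one row per pair
              [3., 5.]])

# r(d) = exp(- sum_k theta_k * |d_k|), evaluated row by row
r = np.exp(- np.sum(theta.reshape(1, -1) * np.abs(d), axis=1))
# r == exp([-(0.5*1 + 2*2), -(0.5*3 + 2*5)]) == exp([-4.5, -11.5])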