From defdb91bb6772c38722f9ff7ed04ac4c3eee0273 Mon Sep 17 00:00:00 2001 From: Fabian Pedregosa <fabian.pedregosa@inria.fr> Date: Tue, 14 Dec 2010 14:35:45 +0100 Subject: [PATCH] FIX lda, qda: new numpy.bincount requires integer arguments. Fixed this while allowing y to have both signed and unsigned int/long values to avoid overflow. --- scikits/learn/lda.py | 11 +++++++++-- scikits/learn/qda.py | 10 ++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/scikits/learn/lda.py b/scikits/learn/lda.py index 904a314634..01caf0e1be 100644 --- a/scikits/learn/lda.py +++ b/scikits/learn/lda.py @@ -84,12 +84,19 @@ class LDA(BaseEstimator, ClassifierMixin): self._set_params(**params) X = np.asanyarray(X) y = np.asanyarray(y) + if y.dtype.char.lower() not in ('i', 'l'): + # We need integer values to be able to use + # ndimage.measurements and np.bincount on numpy > 2.0 + y = y.astype(np.int) if X.ndim != 2: raise ValueError('X must be a 2D array') + if X.shape[0] != y.shape[0]: + raise ValueError( + 'Incompatible shapes: X has %s samples, while y ' + 'has %s' % (X.shape[0], y.shape[0])) n_samples = X.shape[0] n_features = X.shape[1] - # We need int32 to be able to use ndimage.measurements - classes = np.unique(y).astype(np.int32) + classes = np.unique(y) n_classes = classes.size if n_classes < 2: raise ValueError('y has less than 2 classes') diff --git a/scikits/learn/qda.py b/scikits/learn/qda.py index c65fcd8ee1..5cbeeb6537 100644 --- a/scikits/learn/qda.py +++ b/scikits/learn/qda.py @@ -93,9 +93,15 @@ class QDA(BaseEstimator, ClassifierMixin): if X.ndim!=2: raise exceptions.ValueError('X must be a 2D array') if X.shape[0] != y.shape[0]: - raise ValueError("Incompatible shapes") + raise ValueError( + 'Incompatible shapes: X has %s samples, while y ' + 'has %s' % (X.shape[0], y.shape[0])) + if y.dtype.char.lower() not in ('i', 'l'): + # We need integer values to be able to use + # ndimage.measurements and np.bincount on numpy > 2.0 + y = y.astype(np.int) n_samples, n_features = X.shape - classes = np.unique(y).astype(np.int32) + classes = np.unique(y) n_classes = classes.size if n_classes < 2: raise exceptions.ValueError('y has less than 2 classes') -- GitLab