From 027de8f9685ce9ab2481314f4db67342f24b3a0e Mon Sep 17 00:00:00 2001 From: Fabian Pedregosa <fabian.pedregosa@inria.fr> Date: Wed, 15 Dec 2010 09:11:33 +0100 Subject: [PATCH] FIX: backwards compatibility for scipy <= 0.8 ndimage.measurement.sum does not accept int64 on these versions. --- scikits/learn/lda.py | 14 +++++++++----- scikits/learn/qda.py | 14 +++++++++----- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/scikits/learn/lda.py b/scikits/learn/lda.py index 01caf0e1be..ebd189a3f8 100644 --- a/scikits/learn/lda.py +++ b/scikits/learn/lda.py @@ -84,10 +84,14 @@ class LDA(BaseEstimator, ClassifierMixin): self._set_params(**params) X = np.asanyarray(X) y = np.asanyarray(y) - if y.dtype.char.lower() not in ('i', 'l'): + if y.dtype.char.lower() not in ('b', 'h', 'i'): # We need integer values to be able to use - # ndimage.measurements and np.bincount on numpy > 2.0 - y = y.astype(np.int) + # ndimage.measurements and np.bincount on numpy >= 2.0. + # We currently support (u)int8, (u)int16 and (u)int32. + # Note that versions of scipy >= 0.8 can also accept + # (u)int64. We however don't support it for backwards + # compatibility. + y = y.astype(np.int32) if X.ndim != 2: raise ValueError('X must be a 2D array') if X.shape[0] != y.shape[0]: @@ -102,8 +106,8 @@ class LDA(BaseEstimator, ClassifierMixin): raise ValueError('y has less than 2 classes') classes_indices = [(y == c).ravel() for c in classes] if self.priors is None: - counts = np.array(ndimage.measurements.sum(np.ones(len(y)), - y, index=classes)) + counts = np.array(ndimage.measurements.sum( + np.ones(n_samples, dtype=y.dtype), y, index=classes)) self.priors_ = counts / float(n_samples) else: self.priors_ = self.priors diff --git a/scikits/learn/qda.py b/scikits/learn/qda.py index 5cbeeb6537..40bdfe70e6 100644 --- a/scikits/learn/qda.py +++ b/scikits/learn/qda.py @@ -96,10 +96,14 @@ class QDA(BaseEstimator, ClassifierMixin): raise ValueError( 'Incompatible shapes: X has %s samples, while y ' 'has %s' % (X.shape[0], y.shape[0])) - if y.dtype.char.lower() not in ('i', 'l'): + if y.dtype.char.lower() not in ('b', 'h', 'i'): # We need integer values to be able to use - # ndimage.measurements and np.bincount on numpy > 2.0 - y = y.astype(np.int) + # ndimage.measurements and np.bincount on numpy >= 2.0. + # We currently support (u)int8, (u)int16 and (u)int32. + # Note that versions of scipy >= 0.8 can also accept + # (u)int64. We however don't support it for backwards + # compatibility. + y = y.astype(np.int32) n_samples, n_features = X.shape classes = np.unique(y) n_classes = classes.size @@ -107,8 +111,8 @@ class QDA(BaseEstimator, ClassifierMixin): raise exceptions.ValueError('y has less than 2 classes') classes_indices = [(y == c).ravel() for c in classes] if self.priors is None: - counts = np.array(ndimage.measurements.sum(np.ones(len(y)), - y, index=classes)) + counts = np.array(ndimage.measurements.sum( + np.ones(n_samples, dtype=y.dtype), y, index=classes)) self.priors_ = counts / float(n_samples) else: self.priors_ = self.priors -- GitLab