From defdb91bb6772c38722f9ff7ed04ac4c3eee0273 Mon Sep 17 00:00:00 2001
From: Fabian Pedregosa <fabian.pedregosa@inria.fr>
Date: Tue, 14 Dec 2010 14:35:45 +0100
Subject: [PATCH] FIX lda, qda: new numpy.bincount requires integer arguments.

Fixed this while allowing y to have both signed and unsigned int/long
values to avoid overflow.
---
 scikits/learn/lda.py | 11 +++++++++--
 scikits/learn/qda.py | 10 ++++++++--
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/scikits/learn/lda.py b/scikits/learn/lda.py
index 904a314634..01caf0e1be 100644
--- a/scikits/learn/lda.py
+++ b/scikits/learn/lda.py
@@ -84,12 +84,19 @@ class LDA(BaseEstimator, ClassifierMixin):
         self._set_params(**params)
         X = np.asanyarray(X)
         y = np.asanyarray(y)
+        if y.dtype.char.lower() not in ('i', 'l'):
+            # We need integer values to be able to use
+            # ndimage.measurements and np.bincount on numpy > 2.0
+            y = y.astype(np.int)
         if X.ndim != 2:
             raise ValueError('X must be a 2D array')
+        if X.shape[0] != y.shape[0]:
+            raise ValueError(
+                'Incompatible shapes: X has %s samples, while y '
+                'has %s' % (X.shape[0], y.shape[0]))
         n_samples = X.shape[0]
         n_features = X.shape[1]
-        # We need int32 to be able to use ndimage.measurements
-        classes = np.unique(y).astype(np.int32)
+        classes = np.unique(y)
         n_classes = classes.size
         if n_classes < 2:
             raise ValueError('y has less than 2 classes')
diff --git a/scikits/learn/qda.py b/scikits/learn/qda.py
index c65fcd8ee1..5cbeeb6537 100644
--- a/scikits/learn/qda.py
+++ b/scikits/learn/qda.py
@@ -93,9 +93,15 @@ class QDA(BaseEstimator, ClassifierMixin):
         if X.ndim!=2:
             raise exceptions.ValueError('X must be a 2D array')
         if X.shape[0] != y.shape[0]:
-            raise ValueError("Incompatible shapes")
+            raise ValueError(
+                'Incompatible shapes: X has %s samples, while y '
+                'has %s' % (X.shape[0], y.shape[0]))
+        if y.dtype.char.lower() not in ('i', 'l'):
+            # We need integer values to be able to use
+            # ndimage.measurements and np.bincount on numpy > 2.0
+            y = y.astype(np.int)
         n_samples, n_features = X.shape
-        classes = np.unique(y).astype(np.int32)
+        classes = np.unique(y)
         n_classes = classes.size
         if n_classes < 2:
             raise exceptions.ValueError('y has less than 2 classes')
-- 
GitLab