diff --git a/scikits/learn/lda.py b/scikits/learn/lda.py index 904a314634814727fd8580c4f35518d76ab37d82..01caf0e1bea59dc9e6d911a063222cf4cce5b7db 100644 --- a/scikits/learn/lda.py +++ b/scikits/learn/lda.py @@ -84,12 +84,19 @@ class LDA(BaseEstimator, ClassifierMixin): self._set_params(**params) X = np.asanyarray(X) y = np.asanyarray(y) + if y.dtype.char.lower() not in ('i', 'l'): + # We need integer values to be able to use + # ndimage.measurements and np.bincount on numpy > 2.0 + y = y.astype(np.int) if X.ndim != 2: raise ValueError('X must be a 2D array') + if X.shape[0] != y.shape[0]: + raise ValueError( + 'Incompatible shapes: X has %s samples, while y ' + 'has %s' % (X.shape[0], y.shape[0])) n_samples = X.shape[0] n_features = X.shape[1] - # We need int32 to be able to use ndimage.measurements - classes = np.unique(y).astype(np.int32) + classes = np.unique(y) n_classes = classes.size if n_classes < 2: raise ValueError('y has less than 2 classes') diff --git a/scikits/learn/qda.py b/scikits/learn/qda.py index c65fcd8ee18f643cf1cb989163184acfa6c21301..5cbeeb6537785ea75eb6717feba4b26d694254d8 100644 --- a/scikits/learn/qda.py +++ b/scikits/learn/qda.py @@ -93,9 +93,15 @@ class QDA(BaseEstimator, ClassifierMixin): if X.ndim!=2: raise exceptions.ValueError('X must be a 2D array') if X.shape[0] != y.shape[0]: - raise ValueError("Incompatible shapes") + raise ValueError( + 'Incompatible shapes: X has %s samples, while y ' + 'has %s' % (X.shape[0], y.shape[0])) + if y.dtype.char.lower() not in ('i', 'l'): + # We need integer values to be able to use + # ndimage.measurements and np.bincount on numpy > 2.0 + y = y.astype(np.int) n_samples, n_features = X.shape - classes = np.unique(y).astype(np.int32) + classes = np.unique(y) n_classes = classes.size if n_classes < 2: raise exceptions.ValueError('y has less than 2 classes')