diff --git a/sklearn/discriminant_analysis.py b/sklearn/discriminant_analysis.py index 8daf18e9fe2515f0fdc12facd5182e8c1c66166b..d8dcb47a503efd65182668ee68be60d506503df3 100644 --- a/sklearn/discriminant_analysis.py +++ b/sklearn/discriminant_analysis.py @@ -55,7 +55,8 @@ def _cov(X, shrinkage=None): if shrinkage == 'auto': sc = StandardScaler() # standardize features X = sc.fit_transform(X) - s = sc.scale_ * ledoit_wolf(X)[0] * sc.scale_ # scale back + s = ledoit_wolf(X)[0] + s = sc.scale_[:, np.newaxis] * s * sc.scale_[np.newaxis, :] # rescale elif shrinkage == 'empirical': s = empirical_covariance(X) else: diff --git a/sklearn/tests/test_discriminant_analysis.py b/sklearn/tests/test_discriminant_analysis.py index f704f4b427fffbd83b0a9e4a2fb466eaacdcc693..cbd911a0f79682a8b740c595efa23ecedcc377b2 100644 --- a/sklearn/tests/test_discriminant_analysis.py +++ b/sklearn/tests/test_discriminant_analysis.py @@ -16,6 +16,7 @@ from sklearn.utils.testing import ignore_warnings from sklearn.datasets import make_blobs from sklearn.discriminant_analysis import LinearDiscriminantAnalysis from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis +from sklearn.discriminant_analysis import _cov # import reload @@ -116,7 +117,7 @@ def test_lda_priors(): priors = np.array([0.5, 0.6]) prior_norm = np.array([0.45, 0.55]) clf = LinearDiscriminantAnalysis(priors=priors) - clf.fit(X, y) + assert_warns(UserWarning, clf.fit, X, y) assert_array_almost_equal(clf.priors_, prior_norm, 2) @@ -325,3 +326,17 @@ def test_deprecated_lda_qda_deprecation(): qda = assert_warns(DeprecationWarning, import_qda_module) assert qda.QDA is QuadraticDiscriminantAnalysis + + +def test_covariance(): + x, y = make_blobs(n_samples=100, n_features=5, + centers=1, random_state=42) + + # make features correlated + x = np.dot(x, np.arange(x.shape[1] ** 2).reshape(x.shape[1], x.shape[1])) + + c_e = _cov(x, 'empirical') + assert_almost_equal(c_e, c_e.T) + + c_s = _cov(x, 'auto') + assert_almost_equal(c_s, c_s.T)