diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 6058cedfda6dcc56e379a407213c6f5281fbedb5..3b72b57c4e8389f92b261d15f2542f5233a93a2d 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -705,6 +705,11 @@ class KMeans(BaseEstimator): raise ValueError("Incorrect number of features. " "Got %d features, expected %d" % ( n_features, expected_n_features)) + if not X.dtype.kind is 'f': + warnings.warn("Got data type %s, converted to float " + "to avoid overflows" % X.dtype) + X = X.astype(np.float) + return X def _check_fitted(self): diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index 3990cd5a0f67e4a94ba30f43f4586e4ab43c3728..9dd64c3f153c73aa98a1bfd38467fe6596dd5fa5 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -1,6 +1,7 @@ """Testing for K-means""" import numpy as np +import warnings from scipy import sparse as sp from numpy.testing import assert_equal from numpy.testing import assert_array_equal @@ -39,6 +40,15 @@ def test_square_norms(): x_squared_norms_from_csr, 5) +def test_kmeans_dtype(): + X = np.random.normal(size=(40, 2)) + X = (X * 10).astype(np.uint8) + km = KMeans(n_init=1).fit(X) + with warnings.catch_warnings(record=True) as w: + assert_array_equal(km.labels_, km.predict(X)) + assert_equal(len(w), 1) + + def test_labels_assignement_and_inertia(): # pure numpy implementation as easily auditable reference gold # implementation