import numpy as np
import scipy.sparse as sp
import warnings


def assert_all_finite(X):
    """Throw a ValueError if X contains NaN or infinity.

    Input MUST be an np.ndarray instance or a scipy.sparse matrix.

    Parameters
    ----------
    X : np.ndarray or scipy.sparse matrix
        Data to validate. Only floating-point (including complex) dtypes
        are inspected; other dtypes cannot hold NaN or infinity and are
        accepted unconditionally.

    Raises
    ------
    ValueError
        If any element of X (any explicitly stored element, for sparse
        input) is NaN or +/- infinity.
    """
    if X.dtype.char not in np.typecodes['AllFloat']:
        return

    # np.isfinite cannot be applied to a scipy.sparse matrix, so inspect the
    # explicitly stored values instead; implicit zeros are always finite.
    if sp.issparse(X):
        values = getattr(X, 'data', None)
        if not isinstance(values, np.ndarray) or values.dtype != X.dtype:
            # Formats such as LIL or DOK do not expose a flat numeric
            # ``data`` array; COO always does.
            values = X.tocoo().data
    else:
        values = X

    # First try an O(n) time, O(1) space solution for the common case that
    # everything is finite; fall back to O(n) space np.isfinite to prevent
    # false positives from overflow in the sum. Overflow / invalid-operation
    # warnings are expected here and handled by the fallback, so silence them.
    with np.errstate(over='ignore', invalid='ignore'):
        total = values.sum()

    if not np.isfinite(total) and not np.isfinite(values).all():
        raise ValueError("array contains NaN or infinity")