From edeb3af217865f7f3339396ee7c98274a8f3973e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=20Dupr=C3=A9=20la=20Tour?= <TomDLT@users.noreply.github.com> Date: Fri, 23 Jun 2017 21:49:29 +0200 Subject: [PATCH] Deprecate n_iter in SGDClassifier and implement max_iter (#5036) --- benchmarks/bench_covertype.py | 2 +- benchmarks/bench_sgd_regression.py | 39 ++- benchmarks/bench_sparsify.py | 9 +- doc/modules/kernel_approximation.rst | 6 +- doc/modules/linear_model.rst | 3 +- doc/modules/sgd.rst | 6 +- .../text_analytics/working_with_text_data.rst | 3 +- doc/whats_new.rst | 11 +- examples/linear_model/plot_sgd_iris.py | 3 +- .../plot_sgd_separating_hyperplane.py | 2 +- .../linear_model/plot_sgd_weighted_samples.py | 4 +- .../decomposition/tests/test_kernel_pca.py | 12 +- sklearn/ensemble/tests/test_bagging.py | 6 +- sklearn/ensemble/tests/test_base.py | 21 +- .../tests/test_from_model.py | 23 +- sklearn/linear_model/passive_aggressive.py | 78 ++++- sklearn/linear_model/perceptron.py | 35 ++- sklearn/linear_model/sgd_fast.pyx | 87 ++++-- sklearn/linear_model/stochastic_gradient.py | 283 ++++++++++++------ sklearn/linear_model/tests/test_huber.py | 9 +- .../tests/test_passive_aggressive.py | 83 +++-- sklearn/linear_model/tests/test_perceptron.py | 10 +- sklearn/linear_model/tests/test_sgd.py | 228 +++++++++----- sklearn/model_selection/tests/test_search.py | 4 +- .../model_selection/tests/test_validation.py | 6 +- sklearn/tests/test_learning_curve.py | 3 +- sklearn/tests/test_multiclass.py | 14 +- sklearn/tests/test_multioutput.py | 20 +- sklearn/utils/estimator_checks.py | 88 +++--- sklearn/utils/weight_vector.pyx | 1 - 30 files changed, 698 insertions(+), 401 deletions(-) diff --git a/benchmarks/bench_covertype.py b/benchmarks/bench_covertype.py index 5d995c70ef..d5ee0c04eb 100644 --- a/benchmarks/bench_covertype.py +++ b/benchmarks/bench_covertype.py @@ -102,7 +102,7 @@ ESTIMATORS = { 'ExtraTrees': ExtraTreesClassifier(n_estimators=20), 'RandomForest': 
RandomForestClassifier(n_estimators=20), 'CART': DecisionTreeClassifier(min_samples_split=5), - 'SGD': SGDClassifier(alpha=0.001, n_iter=2), + 'SGD': SGDClassifier(alpha=0.001, max_iter=1000, tol=1e-3), 'GaussianNB': GaussianNB(), 'liblinear': LinearSVC(loss="l2", penalty="l2", C=1000, dual=False, tol=1e-3), diff --git a/benchmarks/bench_sgd_regression.py b/benchmarks/bench_sgd_regression.py index e66f656114..d0b9f43f7f 100644 --- a/benchmarks/bench_sgd_regression.py +++ b/benchmarks/bench_sgd_regression.py @@ -1,12 +1,3 @@ -""" -Benchmark for SGD regression - -Compares SGD regression against coordinate descent and Ridge -on synthetic data. -""" - -print(__doc__) - # Author: Peter Prettenhofer <peter.prettenhofer@gmail.com> # License: BSD 3 clause @@ -21,10 +12,20 @@ from sklearn.linear_model import Ridge, SGDRegressor, ElasticNet from sklearn.metrics import mean_squared_error from sklearn.datasets.samples_generator import make_regression +""" +Benchmark for SGD regression + +Compares SGD regression against coordinate descent and Ridge +on synthetic data. 
+""" + +print(__doc__) + if __name__ == "__main__": list_n_samples = np.linspace(100, 10000, 5).astype(np.int) list_n_features = [10, 100, 1000] n_test = 1000 + max_iter = 1000 noise = 0.1 alpha = 0.01 sgd_results = np.zeros((len(list_n_samples), len(list_n_features), 2)) @@ -70,30 +71,28 @@ if __name__ == "__main__": tstart = time() clf.fit(X_train, y_train) elnet_results[i, j, 0] = mean_squared_error(clf.predict(X_test), - y_test) + y_test) elnet_results[i, j, 1] = time() - tstart gc.collect() print("- benchmarking SGD") - n_iter = np.ceil(10 ** 4.0 / n_train) clf = SGDRegressor(alpha=alpha / n_train, fit_intercept=False, - n_iter=n_iter, learning_rate="invscaling", - eta0=.01, power_t=0.25) + max_iter=max_iter, learning_rate="invscaling", + eta0=.01, power_t=0.25, tol=1e-3) tstart = time() clf.fit(X_train, y_train) sgd_results[i, j, 0] = mean_squared_error(clf.predict(X_test), - y_test) + y_test) sgd_results[i, j, 1] = time() - tstart gc.collect() - print("n_iter", n_iter) + print("max_iter", max_iter) print("- benchmarking A-SGD") - n_iter = np.ceil(10 ** 4.0 / n_train) clf = SGDRegressor(alpha=alpha / n_train, fit_intercept=False, - n_iter=n_iter, learning_rate="invscaling", - eta0=.002, power_t=0.05, - average=(n_iter * n_train // 2)) + max_iter=max_iter, learning_rate="invscaling", + eta0=.002, power_t=0.05, tol=1e-3, + average=(max_iter * n_train // 2)) tstart = time() clf.fit(X_train, y_train) @@ -107,7 +106,7 @@ if __name__ == "__main__": tstart = time() clf.fit(X_train, y_train) ridge_results[i, j, 0] = mean_squared_error(clf.predict(X_test), - y_test) + y_test) ridge_results[i, j, 1] = time() - tstart # Plot results diff --git a/benchmarks/bench_sparsify.py b/benchmarks/bench_sparsify.py index 6affa4f3eb..42d7eeb891 100644 --- a/benchmarks/bench_sparsify.py +++ b/benchmarks/bench_sparsify.py @@ -63,7 +63,7 @@ print("input data sparsity: %f" % sparsity_ratio(X)) coef = 3 * np.random.randn(n_features) inds = np.arange(n_features) np.random.shuffle(inds) 
-coef[inds[n_features/2:]] = 0 # sparsify coef +coef[inds[n_features // 2:]] = 0 # sparsify coef print("true coef sparsity: %f" % sparsity_ratio(coef)) y = np.dot(X, coef) @@ -72,12 +72,13 @@ y += 0.01 * np.random.normal((n_samples,)) # Split data in train set and test set n_samples = X.shape[0] -X_train, y_train = X[:n_samples / 2], y[:n_samples / 2] -X_test, y_test = X[n_samples / 2:], y[n_samples / 2:] +X_train, y_train = X[:n_samples // 2], y[:n_samples // 2] +X_test, y_test = X[n_samples // 2:], y[n_samples // 2:] print("test data sparsity: %f" % sparsity_ratio(X_test)) ############################################################################### -clf = SGDRegressor(penalty='l1', alpha=.2, fit_intercept=True, n_iter=2000) +clf = SGDRegressor(penalty='l1', alpha=.2, fit_intercept=True, max_iter=2000, + tol=None) clf.fit(X_train, y_train) print("model sparsity: %f" % sparsity_ratio(clf.coef_)) diff --git a/doc/modules/kernel_approximation.rst b/doc/modules/kernel_approximation.rst index 72363faf66..ae7dd14dea 100644 --- a/doc/modules/kernel_approximation.rst +++ b/doc/modules/kernel_approximation.rst @@ -63,9 +63,9 @@ a linear algorithm, for example a linear SVM:: >>> clf.fit(X_features, y) SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1, eta0=0.0, fit_intercept=True, l1_ratio=0.15, - learning_rate='optimal', loss='hinge', n_iter=5, n_jobs=1, - penalty='l2', power_t=0.5, random_state=None, shuffle=True, - verbose=0, warm_start=False) + learning_rate='optimal', loss='hinge', max_iter=5, n_iter=None, + n_jobs=1, penalty='l2', power_t=0.5, random_state=None, + shuffle=True, tol=None, verbose=0, warm_start=False) >>> clf.score(X_features, y) 1.0 diff --git a/doc/modules/linear_model.rst b/doc/modules/linear_model.rst index b3e82b56a4..0696b4f9f5 100644 --- a/doc/modules/linear_model.rst +++ b/doc/modules/linear_model.rst @@ -1265,7 +1265,8 @@ This way, we can solve the XOR problem with a linear classifier:: [1, 0, 1, 0], [1, 1, 0, 0], [1, 
1, 1, 1]]) - >>> clf = Perceptron(fit_intercept=False, n_iter=10, shuffle=False).fit(X, y) + >>> clf = Perceptron(fit_intercept=False, max_iter=10, tol=None, + ... shuffle=False).fit(X, y) And the classifier "predictions" are perfect:: diff --git a/doc/modules/sgd.rst b/doc/modules/sgd.rst index e8febda201..4bdb218f88 100644 --- a/doc/modules/sgd.rst +++ b/doc/modules/sgd.rst @@ -63,9 +63,9 @@ for the training samples:: >>> clf.fit(X, y) SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1, eta0=0.0, fit_intercept=True, l1_ratio=0.15, - learning_rate='optimal', loss='hinge', n_iter=5, n_jobs=1, - penalty='l2', power_t=0.5, random_state=None, shuffle=True, - verbose=0, warm_start=False) + learning_rate='optimal', loss='hinge', max_iter=5, n_iter=None, + n_jobs=1, penalty='l2', power_t=0.5, random_state=None, + shuffle=True, tol=None, verbose=0, warm_start=False) After being fitted, the model can then be used to predict new values:: diff --git a/doc/tutorial/text_analytics/working_with_text_data.rst b/doc/tutorial/text_analytics/working_with_text_data.rst index b23d4ad98e..d7a74d5304 100644 --- a/doc/tutorial/text_analytics/working_with_text_data.rst +++ b/doc/tutorial/text_analytics/working_with_text_data.rst @@ -352,7 +352,8 @@ classifier object into our pipeline:: >>> text_clf = Pipeline([('vect', CountVectorizer()), ... ('tfidf', TfidfTransformer()), ... ('clf', SGDClassifier(loss='hinge', penalty='l2', - ... alpha=1e-3, n_iter=5, random_state=42)), + ... alpha=1e-3, random_state=42, + ... max_iter=5, tol=None)), ... ]) >>> text_clf.fit(twenty_train.data, twenty_train.target) # doctest: +ELLIPSIS Pipeline(...) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 3e1e498061..04480f7879 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -145,6 +145,15 @@ Enhancements do not set attributes on the estimator. :issue:`7533` by :user:`Ekaterina Krivich <kiote>`. 
+ - :class:`linear_model.SGDClassifier`, :class:`linear_model.SGDRegressor`, + :class:`linear_model.PassiveAggressiveClassifier`, + :class:`linear_model.PassiveAggressiveRegressor` and + :class:`linear_model.Perceptron` now expose ``max_iter`` and + ``tol`` parameters, to handle convergence more precisely. + The ``n_iter`` parameter is deprecated, and the fitted estimator exposes + an ``n_iter_`` attribute, with the actual number of iterations before + convergence. By `Tom Dupre la Tour`_. + - For sparse matrices, :func:`preprocessing.normalize` with ``return_norm=True`` will now raise a ``NotImplementedError`` with 'l1' or 'l2' norm and with norm 'max' the norms returned will be the same as for dense matrices. @@ -1334,7 +1343,6 @@ Birodkar, Vikram, Villu Ruusmann, Vinayak Mehta, walter, waterponey, Wenhua Yang, Wenjian Huang, Will Welch, wyseguy7, xyguo, yanlend, Yaroslav Halchenko, yelite, Yen, YenChenLin, Yichuan Liu, Yoav Ram, Yoshiki, Zheng RuiFeng, zivori, Óscar Nájera - .. currentmodule:: sklearn .. _changes_0_17_1: @@ -1375,6 +1383,7 @@ Bug fixes :class:`decomposition.LatentDirichletAllocation` model. See :issue:`6258` By Chyi-Kwei Yau. + .. 
_changes_0_17: Version 0.17 diff --git a/examples/linear_model/plot_sgd_iris.py b/examples/linear_model/plot_sgd_iris.py index 0da926fe69..0dddf74757 100644 --- a/examples/linear_model/plot_sgd_iris.py +++ b/examples/linear_model/plot_sgd_iris.py @@ -38,7 +38,7 @@ X = (X - mean) / std h = .02 # step size in the mesh -clf = SGDClassifier(alpha=0.001, n_iter=100).fit(X, y) +clf = SGDClassifier(alpha=0.001, max_iter=100).fit(X, y) # create a mesh to plot in x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 @@ -76,6 +76,7 @@ def plot_hyperplane(c, color): plt.plot([xmin, xmax], [line(xmin), line(xmax)], ls="--", color=color) + for i, color in zip(clf.classes_, colors): plot_hyperplane(i, color) plt.legend() diff --git a/examples/linear_model/plot_sgd_separating_hyperplane.py b/examples/linear_model/plot_sgd_separating_hyperplane.py index c47a264485..3d3967ea0d 100644 --- a/examples/linear_model/plot_sgd_separating_hyperplane.py +++ b/examples/linear_model/plot_sgd_separating_hyperplane.py @@ -18,7 +18,7 @@ from sklearn.datasets.samples_generator import make_blobs X, Y = make_blobs(n_samples=50, centers=2, random_state=0, cluster_std=0.60) # fit the model -clf = SGDClassifier(loss="hinge", alpha=0.01, n_iter=200, fit_intercept=True) +clf = SGDClassifier(loss="hinge", alpha=0.01, max_iter=200, fit_intercept=True) clf.fit(X, Y) # plot the line, the points, and the nearest vectors to the plane diff --git a/examples/linear_model/plot_sgd_weighted_samples.py b/examples/linear_model/plot_sgd_weighted_samples.py index 2f53d86166..3617d81b0a 100644 --- a/examples/linear_model/plot_sgd_weighted_samples.py +++ b/examples/linear_model/plot_sgd_weighted_samples.py @@ -27,14 +27,14 @@ plt.scatter(X[:, 0], X[:, 1], c=y, s=sample_weight, alpha=0.9, cmap=plt.cm.bone, edgecolor='black') # fit the unweighted model -clf = linear_model.SGDClassifier(alpha=0.01, n_iter=100) +clf = linear_model.SGDClassifier(alpha=0.01, max_iter=100) clf.fit(X, y) Z = 
clf.decision_function(np.c_[xx.ravel(), yy.ravel()]) Z = Z.reshape(xx.shape) no_weights = plt.contour(xx, yy, Z, levels=[0], linestyles=['solid']) # fit the weighted model -clf = linear_model.SGDClassifier(alpha=0.01, n_iter=100) +clf = linear_model.SGDClassifier(alpha=0.01, max_iter=100) clf.fit(X, y, sample_weight=sample_weight) Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]) Z = Z.reshape(xx.shape) diff --git a/sklearn/decomposition/tests/test_kernel_pca.py b/sklearn/decomposition/tests/test_kernel_pca.py index 722d7ec0e0..63281ce33d 100644 --- a/sklearn/decomposition/tests/test_kernel_pca.py +++ b/sklearn/decomposition/tests/test_kernel_pca.py @@ -178,7 +178,8 @@ def test_gridsearch_pipeline(): X, y = make_circles(n_samples=400, factor=.3, noise=.05, random_state=0) kpca = KernelPCA(kernel="rbf", n_components=2) - pipeline = Pipeline([("kernel_pca", kpca), ("Perceptron", Perceptron())]) + pipeline = Pipeline([("kernel_pca", kpca), + ("Perceptron", Perceptron(max_iter=5))]) param_grid = dict(kernel_pca__gamma=2. ** np.arange(-2, 2)) grid_search = GridSearchCV(pipeline, cv=3, param_grid=param_grid) grid_search.fit(X, y) @@ -191,8 +192,9 @@ def test_gridsearch_pipeline_precomputed(): X, y = make_circles(n_samples=400, factor=.3, noise=.05, random_state=0) kpca = KernelPCA(kernel="precomputed", n_components=2) - pipeline = Pipeline([("kernel_pca", kpca), ("Perceptron", Perceptron())]) - param_grid = dict(Perceptron__n_iter=np.arange(1, 5)) + pipeline = Pipeline([("kernel_pca", kpca), + ("Perceptron", Perceptron(max_iter=5))]) + param_grid = dict(Perceptron__max_iter=np.arange(1, 5)) grid_search = GridSearchCV(pipeline, cv=3, param_grid=param_grid) X_kernel = rbf_kernel(X, gamma=2.) 
grid_search.fit(X_kernel, y) @@ -205,7 +207,7 @@ def test_nested_circles(): random_state=0) # 2D nested circles are not linearly separable - train_score = Perceptron().fit(X, y).score(X, y) + train_score = Perceptron(max_iter=5).fit(X, y).score(X, y) assert_less(train_score, 0.8) # Project the circles data into the first 2 components of a RBF Kernel @@ -218,5 +220,5 @@ def test_nested_circles(): X_kpca = kpca.fit_transform(X) # The data is perfectly linearly separable in that space - train_score = Perceptron().fit(X_kpca, y).score(X_kpca, y) + train_score = Perceptron(max_iter=5).fit(X_kpca, y).score(X_kpca, y) assert_equal(train_score, 1.0) diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py index f4ff680a35..c0a46d6c15 100644 --- a/sklearn/ensemble/tests/test_bagging.py +++ b/sklearn/ensemble/tests/test_bagging.py @@ -65,7 +65,7 @@ def test_classification(): for base_estimator in [None, DummyClassifier(), - Perceptron(), + Perceptron(tol=1e-3), DecisionTreeClassifier(), KNeighborsClassifier(), SVC()]: @@ -519,7 +519,7 @@ def test_base_estimator(): assert_true(isinstance(ensemble.base_estimator_, DecisionTreeClassifier)) - ensemble = BaggingClassifier(Perceptron(), + ensemble = BaggingClassifier(Perceptron(tol=1e-3), n_jobs=3, random_state=0).fit(X_train, y_train) @@ -668,7 +668,7 @@ def test_oob_score_removed_on_warm_start(): def test_oob_score_consistency(): - # Make sure OOB scores are identical when random_state, estimator, and + # Make sure OOB scores are identical when random_state, estimator, and # training data are fixed and fitting is done twice X, y = make_hastie_10_2(n_samples=200, random_state=1) bagging = BaggingClassifier(KNeighborsClassifier(), max_samples=0.5, diff --git a/sklearn/ensemble/tests/test_base.py b/sklearn/ensemble/tests/test_base.py index 6b81dbf674..65ea8b62a2 100644 --- a/sklearn/ensemble/tests/test_base.py +++ b/sklearn/ensemble/tests/test_base.py @@ -24,8 +24,8 @@ from 
sklearn.feature_selection import SelectFromModel def test_base(): # Check BaseEnsemble methods. - ensemble = BaggingClassifier(base_estimator=Perceptron(random_state=None), - n_estimators=3) + ensemble = BaggingClassifier( + base_estimator=Perceptron(tol=1e-3, random_state=None), n_estimators=3) iris = load_iris() ensemble.fit(iris.data, iris.target) @@ -46,7 +46,7 @@ def test_base(): assert_true(isinstance(ensemble[2].random_state, int)) assert_not_equal(ensemble[1].random_state, ensemble[2].random_state) - np_int_ensemble = BaggingClassifier(base_estimator=Perceptron(), + np_int_ensemble = BaggingClassifier(base_estimator=Perceptron(tol=1e-3), n_estimators=np.int32(3)) np_int_ensemble.fit(iris.data, iris.target) @@ -54,7 +54,7 @@ def test_base(): def test_base_zero_n_estimators(): # Check that instantiating a BaseEnsemble with n_estimators<=0 raises # a ValueError. - ensemble = BaggingClassifier(base_estimator=Perceptron(), + ensemble = BaggingClassifier(base_estimator=Perceptron(tol=1e-3), n_estimators=0) iris = load_iris() assert_raise_message(ValueError, @@ -65,13 +65,13 @@ def test_base_zero_n_estimators(): def test_base_not_int_n_estimators(): # Check that instantiating a BaseEnsemble with a string as n_estimators # raises a ValueError demanding n_estimators to be supplied as an integer. 
- string_ensemble = BaggingClassifier(base_estimator=Perceptron(), + string_ensemble = BaggingClassifier(base_estimator=Perceptron(tol=1e-3), n_estimators='3') iris = load_iris() assert_raise_message(ValueError, "n_estimators must be an integer", string_ensemble.fit, iris.data, iris.target) - float_ensemble = BaggingClassifier(base_estimator=Perceptron(), + float_ensemble = BaggingClassifier(base_estimator=Perceptron(tol=1e-3), n_estimators=3.0) assert_raise_message(ValueError, "n_estimators must be an integer", @@ -82,7 +82,7 @@ def test_set_random_states(): # Linear Discriminant Analysis doesn't have random state: smoke test _set_random_states(LinearDiscriminantAnalysis(), random_state=17) - clf1 = Perceptron(random_state=None) + clf1 = Perceptron(tol=1e-3, random_state=None) assert_equal(clf1.random_state, None) # check random_state is None still sets _set_random_states(clf1, None) @@ -91,15 +91,16 @@ def test_set_random_states(): # check random_state fixes results in consistent initialisation _set_random_states(clf1, 3) assert_true(isinstance(clf1.random_state, int)) - clf2 = Perceptron(random_state=None) + clf2 = Perceptron(tol=1e-3, random_state=None) _set_random_states(clf2, 3) assert_equal(clf1.random_state, clf2.random_state) # nested random_state def make_steps(): - return [('sel', SelectFromModel(Perceptron(random_state=None))), - ('clf', Perceptron(random_state=None))] + return [('sel', SelectFromModel(Perceptron(tol=1e-3, + random_state=None))), + ('clf', Perceptron(tol=1e-3, random_state=None))] est1 = Pipeline(make_steps()) _set_random_states(est1, 3) diff --git a/sklearn/feature_selection/tests/test_from_model.py b/sklearn/feature_selection/tests/test_from_model.py index 6ac6b8630b..ae4d1ba433 100644 --- a/sklearn/feature_selection/tests/test_from_model.py +++ b/sklearn/feature_selection/tests/test_from_model.py @@ -24,7 +24,8 @@ rng = np.random.RandomState(0) def test_invalid_input(): - clf = SGDClassifier(alpha=0.1, n_iter=10, shuffle=True, 
random_state=None) + clf = SGDClassifier(alpha=0.1, max_iter=10, shuffle=True, + random_state=None, tol=None) for threshold in ["gobbledigook", ".5 * gobbledigook"]: model = SelectFromModel(clf, threshold=threshold) model.fit(data, y) @@ -32,9 +33,7 @@ def test_invalid_input(): def test_input_estimator_unchanged(): - """ - Test that SelectFromModel fits on a clone of the estimator. - """ + # Test that SelectFromModel fits on a clone of the estimator. est = RandomForestClassifier() transformer = SelectFromModel(estimator=est) transformer.fit(data, y) @@ -106,7 +105,8 @@ def test_feature_importances_2d_coef(): def test_partial_fit(): - est = PassiveAggressiveClassifier(random_state=0, shuffle=False) + est = PassiveAggressiveClassifier(random_state=0, shuffle=False, + max_iter=5, tol=None) transformer = SelectFromModel(estimator=est) transformer.partial_fit(data, y, classes=np.unique(y)) @@ -135,12 +135,12 @@ def test_calling_fit_reinitializes(): def test_prefit(): - """ - Test all possible combinations of the prefit parameter. - """ + # Test all possible combinations of the prefit parameter. + # Passing a prefit parameter with the selected model # and fitting a unfit model with prefit=False should give same results. - clf = SGDClassifier(alpha=0.1, n_iter=10, shuffle=True, random_state=0) + clf = SGDClassifier(alpha=0.1, max_iter=10, shuffle=True, + random_state=0, tol=None) model = SelectFromModel(clf) model.fit(data, y) X_transform = model.transform(data) @@ -173,8 +173,9 @@ def test_threshold_string(): def test_threshold_without_refitting(): - """Test that the threshold can be set without refitting the model.""" - clf = SGDClassifier(alpha=0.1, n_iter=10, shuffle=True, random_state=0) + # Test that the threshold can be set without refitting the model. 
+ clf = SGDClassifier(alpha=0.1, max_iter=10, shuffle=True, + random_state=0, tol=None) model = SelectFromModel(clf, threshold="0.1 * mean") model.fit(data, y) X_transform = model.transform(data) diff --git a/sklearn/linear_model/passive_aggressive.py b/sklearn/linear_model/passive_aggressive.py index 941f398bd6..ea5c37ad3d 100644 --- a/sklearn/linear_model/passive_aggressive.py +++ b/sklearn/linear_model/passive_aggressive.py @@ -23,7 +23,25 @@ class PassiveAggressiveClassifier(BaseSGDClassifier): n_iter : int, optional The number of passes over the training data (aka epochs). - Defaults to 5. + Defaults to None. Deprecated, will be removed in 0.21. + + .. versionchanged:: 0.19 + Deprecated + + max_iter : int, optional + The maximum number of passes over the training data (aka epochs). + It only impacts the behavior in the ``fit`` method, and not the + `partial_fit`. + Defaults to 5. Defaults to 1000 from 0.21, or if tol is not None. + + .. versionadded:: 0.19 + + tol : float or None, optional + The stopping criterion. If it is not None, the iterations will stop + when (loss > previous_loss - tol). Defaults to None. + Defaults to 1e-3 from 0.21. + + .. versionadded:: 0.19 shuffle : bool, default=True Whether or not the training data should be shuffled after each epoch. @@ -83,6 +101,10 @@ class PassiveAggressiveClassifier(BaseSGDClassifier): intercept_ : array, shape = [1] if n_classes == 2 else [n_classes] Constants in decision function. + n_iter_ : int + The actual number of iterations to reach the stopping criterion. + For multiclass fits, it is the maximum over every binary fit. + See also -------- @@ -96,14 +118,15 @@ class PassiveAggressiveClassifier(BaseSGDClassifier): K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. 
Singer - JMLR (2006) """ - - def __init__(self, C=1.0, fit_intercept=True, n_iter=5, shuffle=True, - verbose=0, loss="hinge", n_jobs=1, random_state=None, - warm_start=False, class_weight=None, average=False): + def __init__(self, C=1.0, fit_intercept=True, max_iter=None, tol=None, + shuffle=True, verbose=0, loss="hinge", n_jobs=1, + random_state=None, warm_start=False, class_weight=None, + average=False, n_iter=None): super(PassiveAggressiveClassifier, self).__init__( penalty=None, fit_intercept=fit_intercept, - n_iter=n_iter, + max_iter=max_iter, + tol=tol, shuffle=shuffle, verbose=verbose, random_state=random_state, @@ -111,7 +134,9 @@ class PassiveAggressiveClassifier(BaseSGDClassifier): warm_start=warm_start, class_weight=class_weight, average=average, - n_jobs=n_jobs) + n_jobs=n_jobs, + n_iter=n_iter) + self.C = C self.loss = loss @@ -150,7 +175,7 @@ class PassiveAggressiveClassifier(BaseSGDClassifier): "parameter.") lr = "pa1" if self.loss == "hinge" else "pa2" return self._partial_fit(X, y, alpha=1.0, C=self.C, - loss="hinge", learning_rate=lr, n_iter=1, + loss="hinge", learning_rate=lr, max_iter=1, classes=classes, sample_weight=None, coef_init=None, intercept_init=None) @@ -202,7 +227,25 @@ class PassiveAggressiveRegressor(BaseSGDRegressor): n_iter : int, optional The number of passes over the training data (aka epochs). - Defaults to 5. + Defaults to None. Deprecated, will be removed in 0.21. + + .. versionchanged:: 0.19 + Deprecated + + max_iter : int, optional + The maximum number of passes over the training data (aka epochs). + It only impacts the behavior in the ``fit`` method, and not the + `partial_fit`. + Defaults to 5. Defaults to 1000 from 0.21, or if tol is not None. + + .. versionadded:: 0.19 + + tol : float or None, optional + The stopping criterion. If it is not None, the iterations will stop + when (loss > previous_loss - tol). Defaults to None. + Defaults to 1e-3 from 0.21. + + .. 
versionadded:: 0.19 shuffle : bool, default=True Whether or not the training data should be shuffled after each epoch. @@ -245,6 +288,9 @@ class PassiveAggressiveRegressor(BaseSGDRegressor): intercept_ : array, shape = [1] if n_classes == 2 else [n_classes] Constants in decision function. + n_iter_ : int + The actual number of iterations to reach the stopping criterion. + See also -------- @@ -257,22 +303,24 @@ class PassiveAggressiveRegressor(BaseSGDRegressor): K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. Singer - JMLR (2006) """ - def __init__(self, C=1.0, fit_intercept=True, n_iter=5, shuffle=True, - verbose=0, loss="epsilon_insensitive", + def __init__(self, C=1.0, fit_intercept=True, max_iter=None, tol=None, + shuffle=True, verbose=0, loss="epsilon_insensitive", epsilon=DEFAULT_EPSILON, random_state=None, warm_start=False, - average=False): + average=False, n_iter=None): super(PassiveAggressiveRegressor, self).__init__( penalty=None, l1_ratio=0, epsilon=epsilon, eta0=1.0, fit_intercept=fit_intercept, - n_iter=n_iter, + max_iter=max_iter, + tol=tol, shuffle=shuffle, verbose=verbose, random_state=random_state, warm_start=warm_start, - average=average) + average=average, + n_iter=n_iter) self.C = C self.loss = loss @@ -294,7 +342,7 @@ class PassiveAggressiveRegressor(BaseSGDRegressor): lr = "pa1" if self.loss == "epsilon_insensitive" else "pa2" return self._partial_fit(X, y, alpha=1.0, C=self.C, loss="epsilon_insensitive", - learning_rate=lr, n_iter=1, + learning_rate=lr, max_iter=1, sample_weight=None, coef_init=None, intercept_init=None) diff --git a/sklearn/linear_model/perceptron.py b/sklearn/linear_model/perceptron.py index 0b11049fc3..0edfa28712 100644 --- a/sklearn/linear_model/perceptron.py +++ b/sklearn/linear_model/perceptron.py @@ -25,7 +25,25 @@ class Perceptron(BaseSGDClassifier): n_iter : int, optional The number of passes over the training data (aka epochs). - Defaults to 5. + Defaults to None. Deprecated, will be removed in 0.21. + + .. 
versionchanged:: 0.19 + Deprecated + + max_iter : int, optional + The maximum number of passes over the training data (aka epochs). + It only impacts the behavior in the ``fit`` method, and not the + `partial_fit`. + Defaults to 5. Defaults to 1000 from 0.21, or if tol is not None. + + .. versionadded:: 0.19 + + tol : float or None, optional + The stopping criterion. If it is not None, the iterations will stop + when (loss > previous_loss - tol). Defaults to None. + Defaults to 1e-3 from 0.21. + + .. versionadded:: 0.19 shuffle : bool, optional, default True Whether or not the training data should be shuffled after each epoch. @@ -71,6 +89,10 @@ class Perceptron(BaseSGDClassifier): intercept_ : array, shape = [1] if n_classes == 2 else [n_classes] Constants in decision function. + n_iter_ : int + The actual number of iterations to reach the stopping criterion. + For multiclass fits, it is the maximum over every binary fit. + Notes ----- @@ -89,13 +111,15 @@ class Perceptron(BaseSGDClassifier): https://en.wikipedia.org/wiki/Perceptron and references therein. 
""" def __init__(self, penalty=None, alpha=0.0001, fit_intercept=True, - n_iter=5, shuffle=True, verbose=0, eta0=1.0, n_jobs=1, - random_state=0, class_weight=None, warm_start=False): + max_iter=None, tol=None, shuffle=True, verbose=0, eta0=1.0, + n_jobs=1, random_state=0, class_weight=None, + warm_start=False, n_iter=None): super(Perceptron, self).__init__(loss="perceptron", penalty=penalty, alpha=alpha, l1_ratio=0, fit_intercept=fit_intercept, - n_iter=n_iter, + max_iter=max_iter, + tol=tol, shuffle=shuffle, verbose=verbose, random_state=random_state, @@ -104,4 +128,5 @@ class Perceptron(BaseSGDClassifier): power_t=0.5, warm_start=warm_start, class_weight=class_weight, - n_jobs=n_jobs) + n_jobs=n_jobs, + n_iter=n_iter) diff --git a/sklearn/linear_model/sgd_fast.pyx b/sklearn/linear_model/sgd_fast.pyx index 01718aaf15..1e4027b7f8 100644 --- a/sklearn/linear_model/sgd_fast.pyx +++ b/sklearn/linear_model/sgd_fast.pyx @@ -17,6 +17,7 @@ from time import time cimport cython from libc.math cimport exp, log, sqrt, pow, fabs cimport numpy as np +from numpy.math cimport INFINITY cdef extern from "sgd_fast_helpers.h": bint skl_isfinite(double) nogil @@ -38,6 +39,7 @@ DEF INVSCALING = 3 DEF PA1 = 4 DEF PA2 = 5 + # ---------------------------------------- # Extension Types for Loss Functions # ---------------------------------------- @@ -335,7 +337,7 @@ def plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, double alpha, double C, double l1_ratio, SequentialDataset dataset, - int n_iter, int fit_intercept, + int max_iter, double tol, int fit_intercept, int verbose, bint shuffle, np.uint32_t seed, double weight_pos, double weight_neg, int learning_rate, double eta0, @@ -363,8 +365,10 @@ def plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, l1_ratio=0 corresponds to L2 penalty, l1_ratio=1 to L1. dataset : SequentialDataset A concrete ``SequentialDataset`` object. - n_iter : int - The number of iterations (epochs). 
+ max_iter : int + The maximum number of iterations (epochs). + tol: double + The tolerance for the stopping criterion fit_intercept : int Whether or not to fit the intercept (1 or 0). verbose : int @@ -399,26 +403,28 @@ def plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, The fitted weight vector. intercept : float The fitted intercept term. + n_iter_ : int + The actual number of iter (epochs). """ standard_weights, standard_intercept,\ - _, _ = _plain_sgd(weights, - intercept, - None, - 0, - loss, - penalty_type, - alpha, C, - l1_ratio, - dataset, - n_iter, fit_intercept, - verbose, shuffle, seed, - weight_pos, weight_neg, - learning_rate, eta0, - power_t, - t, - intercept_decay, - 0) - return standard_weights, standard_intercept + _, _, n_iter_ = _plain_sgd(weights, + intercept, + None, + 0, + loss, + penalty_type, + alpha, C, + l1_ratio, + dataset, + max_iter, tol, fit_intercept, + verbose, shuffle, seed, + weight_pos, weight_neg, + learning_rate, eta0, + power_t, + t, + intercept_decay, + 0) + return standard_weights, standard_intercept, n_iter_ def average_sgd(np.ndarray[double, ndim=1, mode='c'] weights, @@ -430,7 +436,7 @@ def average_sgd(np.ndarray[double, ndim=1, mode='c'] weights, double alpha, double C, double l1_ratio, SequentialDataset dataset, - int n_iter, int fit_intercept, + int max_iter, double tol, int fit_intercept, int verbose, bint shuffle, np.uint32_t seed, double weight_pos, double weight_neg, int learning_rate, double eta0, @@ -463,8 +469,10 @@ def average_sgd(np.ndarray[double, ndim=1, mode='c'] weights, l1_ratio=0 corresponds to L2 penalty, l1_ratio=1 to L1. dataset : SequentialDataset A concrete ``SequentialDataset`` object. - n_iter : int - The number of iterations (epochs). + max_iter : int + The maximum number of iterations (epochs). + tol: double + The tolerance for the stopping criterion. fit_intercept : int Whether or not to fit the intercept (1 or 0). 
verbose : int @@ -506,6 +514,8 @@ def average_sgd(np.ndarray[double, ndim=1, mode='c'] weights, The averaged weights across iterations average_intercept : float The averaged intercept across iterations + n_iter_ : int + The actual number of iter (epochs). """ return _plain_sgd(weights, intercept, @@ -516,7 +526,7 @@ def average_sgd(np.ndarray[double, ndim=1, mode='c'] weights, alpha, C, l1_ratio, dataset, - n_iter, fit_intercept, + max_iter, tol, fit_intercept, verbose, shuffle, seed, weight_pos, weight_neg, learning_rate, eta0, @@ -535,7 +545,7 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, double alpha, double C, double l1_ratio, SequentialDataset dataset, - int n_iter, int fit_intercept, + int max_iter, double tol, int fit_intercept, int verbose, bint shuffle, np.uint32_t seed, double weight_pos, double weight_neg, int learning_rate, double eta0, @@ -561,6 +571,7 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, cdef double p = 0.0 cdef double update = 0.0 cdef double sumloss = 0.0 + cdef double previous_loss = np.inf cdef double y = 0.0 cdef double sample_weight cdef double class_weight = 1.0 @@ -571,6 +582,8 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, cdef double optimal_init = 0.0 cdef double dloss = 0.0 cdef double MAX_DLOSS = 1e12 + cdef double max_change = 0.0 + cdef double max_weight = 0.0 # q vector is only used for L1 regularization cdef np.ndarray[double, ndim = 1, mode = "c"] q = None @@ -596,7 +609,8 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, t_start = time() with nogil: - for epoch in range(n_iter): + for epoch in range(max_iter): + sumloss = 0 if verbose > 0: with gil: print("-- Epoch %d" % (epoch + 1)) @@ -612,8 +626,7 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, elif learning_rate == INVSCALING: eta = eta0 / pow(t, power_t) - if verbose > 0: - sumloss += loss.loss(p, y) + sumloss += loss.loss(p, y) if y > 0.0: class_weight = weight_pos @@ -677,10 +690,10 @@ 
def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, # report epoch information if verbose > 0: with gil: - print("Norm: %.2f, NNZs: %d, " - "Bias: %.6f, T: %d, Avg. loss: %.6f" + print("Norm: %.2f, NNZs: %d, Bias: %.6f, T: %d, " + "Avg. loss: %f" % (w.norm(), weights.nonzero()[0].shape[0], - intercept, count, sumloss / count)) + intercept, count, sumloss / n_samples)) print("Total training time: %.2f seconds." % (time() - t_start)) @@ -690,6 +703,14 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, infinity = True break + if tol > -INFINITY and sumloss > previous_loss - tol * n_samples: + if verbose: + with gil: + print("Convergence after %d epochs took %.2f seconds" + % (epoch + 1, time() - t_start)) + break + previous_loss = sumloss + if infinity: raise ValueError(("Floating-point under-/overflow occurred at epoch" " #%d. Scaling input data with StandardScaler or" @@ -697,7 +718,7 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, w.reset_wscale() - return weights, intercept, average_weights, average_intercept + return weights, intercept, average_weights, average_intercept, epoch + 1 cdef bint any_nonfinite(double *w, int n) nogil: diff --git a/sklearn/linear_model/stochastic_gradient.py b/sklearn/linear_model/stochastic_gradient.py index 85f2b8ef7d..13b5de535d 100644 --- a/sklearn/linear_model/stochastic_gradient.py +++ b/sklearn/linear_model/stochastic_gradient.py @@ -5,6 +5,7 @@ """Classification and regression using Stochastic Gradient Descent (SGD).""" import numpy as np +import warnings from abc import ABCMeta, abstractmethod @@ -17,6 +18,7 @@ from ..utils import check_array, check_random_state, check_X_y from ..utils.extmath import safe_sparse_dot from ..utils.multiclass import _check_partial_fit_first_call from ..utils.validation import check_is_fitted +from ..exceptions import ConvergenceWarning from ..externals import six from .sgd_fast import plain_sgd, average_sgd @@ -45,10 +47,10 @@ class 
BaseSGD(six.with_metaclass(ABCMeta, BaseEstimator, SparseCoefMixin)): """Base class for SGD classification and regression.""" def __init__(self, loss, penalty='l2', alpha=0.0001, C=1.0, - l1_ratio=0.15, fit_intercept=True, n_iter=5, shuffle=True, - verbose=0, epsilon=0.1, random_state=None, + l1_ratio=0.15, fit_intercept=True, max_iter=None, tol=None, + shuffle=True, verbose=0, epsilon=0.1, random_state=None, learning_rate="optimal", eta0=0.0, power_t=0.5, - warm_start=False, average=False): + warm_start=False, average=False, n_iter=None): self.loss = loss self.penalty = penalty self.learning_rate = learning_rate @@ -57,7 +59,6 @@ class BaseSGD(six.with_metaclass(ABCMeta, BaseEstimator, SparseCoefMixin)): self.C = C self.l1_ratio = l1_ratio self.fit_intercept = fit_intercept - self.n_iter = n_iter self.shuffle = shuffle self.random_state = random_state self.verbose = verbose @@ -66,6 +67,28 @@ class BaseSGD(six.with_metaclass(ABCMeta, BaseEstimator, SparseCoefMixin)): self.warm_start = warm_start self.average = average + if n_iter is not None: + warnings.warn("n_iter parameter is deprecated in 0.19 and will be" + " removed in 0.21. Use max_iter and tol instead.", + DeprecationWarning) + # Same behavior as before 0.19 + self.max_iter = n_iter + tol = None + + elif tol is None and max_iter is None: + warnings.warn( + "max_iter and tol parameters have been added in %s in 0.19. If" + "both are left unset, they default to max_iter=5 and tol=None." + " If tol is not None, max_iter defaults to max_iter=1000. " + "From 0.21, default max_iter will be 1000, " + "and default tol will be 1e-3." % type(self), FutureWarning) + # Before 0.19, default was n_iter=5 + self.max_iter = 5 + else: + self.max_iter = max_iter if max_iter is not None else 1000 + + self.tol = tol + self._validate_params() def set_params(self, *args, **kwargs): @@ -81,8 +104,8 @@ class BaseSGD(six.with_metaclass(ABCMeta, BaseEstimator, SparseCoefMixin)): """Validate input params. 
""" if not isinstance(self.shuffle, bool): raise ValueError("shuffle must be either True or False") - if self.n_iter <= 0: - raise ValueError("n_iter must be > zero") + if self.max_iter <= 0: + raise ValueError("max_iter must be > zero. Got %f" % self.max_iter) if not (0.0 <= self.l1_ratio <= 1.0): raise ValueError("l1_ratio must be in [0, 1]") if self.alpha < 0.0: @@ -235,7 +258,7 @@ def _prepare_fit_binary(est, y, i): return y_i, coef, intercept, average_coef, average_intercept -def fit_binary(est, i, X, y, alpha, C, learning_rate, n_iter, +def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter, pos_weight, neg_weight, sample_weight): """Fit a single binary classifier. @@ -257,35 +280,35 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, n_iter, # Windows seed = random_state.randint(0, np.iinfo(np.int32).max) + tol = est.tol if est.tol is not None else -np.inf + if not est.average: return plain_sgd(coef, intercept, est.loss_function_, penalty_type, alpha, C, est.l1_ratio, - dataset, n_iter, int(est.fit_intercept), + dataset, max_iter, tol, int(est.fit_intercept), int(est.verbose), int(est.shuffle), seed, pos_weight, neg_weight, learning_rate_type, est.eta0, est.power_t, est.t_, intercept_decay) else: - standard_coef, standard_intercept, average_coef, \ - average_intercept = average_sgd(coef, intercept, average_coef, - average_intercept, - est.loss_function_, penalty_type, - alpha, C, est.l1_ratio, dataset, - n_iter, int(est.fit_intercept), - int(est.verbose), int(est.shuffle), - seed, pos_weight, neg_weight, - learning_rate_type, est.eta0, - est.power_t, est.t_, - intercept_decay, - est.average) + standard_coef, standard_intercept, average_coef, average_intercept, \ + n_iter_ = average_sgd(coef, intercept, average_coef, + average_intercept, est.loss_function_, + penalty_type, alpha, C, est.l1_ratio, + dataset, max_iter, tol, + int(est.fit_intercept), int(est.verbose), + int(est.shuffle), seed, pos_weight, + neg_weight, learning_rate_type, est.eta0, 
+ est.power_t, est.t_, intercept_decay, + est.average) if len(est.classes_) == 2: est.average_intercept_[0] = average_intercept else: est.average_intercept_[i] = average_intercept - return standard_coef, standard_intercept + return standard_coef, standard_intercept, n_iter_ class BaseSGDClassifier(six.with_metaclass(ABCMeta, BaseSGD, @@ -305,23 +328,26 @@ class BaseSGDClassifier(six.with_metaclass(ABCMeta, BaseSGD, } @abstractmethod - def __init__(self, loss="hinge", penalty='l2', alpha=0.0001, l1_ratio=0.15, - fit_intercept=True, n_iter=5, shuffle=True, verbose=0, - epsilon=DEFAULT_EPSILON, n_jobs=1, random_state=None, - learning_rate="optimal", eta0=0.0, power_t=0.5, - class_weight=None, warm_start=False, average=False): + def __init__(self, loss="hinge", penalty='l2', alpha=0.0001, + l1_ratio=0.15, fit_intercept=True, max_iter=None, tol=None, + shuffle=True, verbose=0, epsilon=DEFAULT_EPSILON, n_jobs=1, + random_state=None, learning_rate="optimal", eta0=0.0, + power_t=0.5, class_weight=None, warm_start=False, + average=False, n_iter=None): super(BaseSGDClassifier, self).__init__(loss=loss, penalty=penalty, alpha=alpha, l1_ratio=l1_ratio, fit_intercept=fit_intercept, - n_iter=n_iter, shuffle=shuffle, + max_iter=max_iter, tol=tol, + shuffle=shuffle, verbose=verbose, epsilon=epsilon, random_state=random_state, learning_rate=learning_rate, eta0=eta0, power_t=power_t, warm_start=warm_start, - average=average) + average=average, + n_iter=n_iter) self.class_weight = class_weight self.n_jobs = int(n_jobs) @@ -332,7 +358,7 @@ class BaseSGDClassifier(six.with_metaclass(ABCMeta, BaseSGD, return self.loss_function_ def _partial_fit(self, X, y, alpha, C, - loss, learning_rate, n_iter, + loss, learning_rate, max_iter, classes, sample_weight, coef_init, intercept_init): X, y = check_X_y(X, y, 'csr', dtype=np.float64, order="C") @@ -364,11 +390,13 @@ class BaseSGDClassifier(six.with_metaclass(ABCMeta, BaseSGD, if n_classes > 2: self._fit_multiclass(X, y, alpha=alpha, C=C, 
learning_rate=learning_rate, - sample_weight=sample_weight, n_iter=n_iter) + sample_weight=sample_weight, + max_iter=max_iter) elif n_classes == 2: self._fit_binary(X, y, alpha=alpha, C=C, learning_rate=learning_rate, - sample_weight=sample_weight, n_iter=n_iter) + sample_weight=sample_weight, + max_iter=max_iter) else: raise ValueError("The number of class labels must be " "greater than one.") @@ -405,21 +433,28 @@ class BaseSGDClassifier(six.with_metaclass(ABCMeta, BaseSGD, # Clear iteration count for multiple call to fit. self.t_ = 1.0 - self._partial_fit(X, y, alpha, C, loss, learning_rate, self.n_iter, + self._partial_fit(X, y, alpha, C, loss, learning_rate, self.max_iter, classes, sample_weight, coef_init, intercept_init) + if (self.tol is not None and self.tol > -np.inf + and self.n_iter_ == self.max_iter): + warnings.warn("Maximum number of iteration reached before " + "convergence. Consider increasing max_iter to " + "improve the fit.", + ConvergenceWarning) return self def _fit_binary(self, X, y, alpha, C, sample_weight, - learning_rate, n_iter): + learning_rate, max_iter): """Fit a binary classifier on X and y. """ - coef, intercept = fit_binary(self, 1, X, y, alpha, C, - learning_rate, n_iter, - self._expanded_class_weight[1], - self._expanded_class_weight[0], - sample_weight) + coef, intercept, n_iter_ = fit_binary(self, 1, X, y, alpha, C, + learning_rate, max_iter, + self._expanded_class_weight[1], + self._expanded_class_weight[0], + sample_weight) - self.t_ += n_iter * X.shape[0] + self.t_ += n_iter_ * X.shape[0] + self.n_iter_ = n_iter_ # need to be 2d if self.average > 0: @@ -436,7 +471,7 @@ class BaseSGDClassifier(six.with_metaclass(ABCMeta, BaseSGD, self.intercept_ = np.atleast_1d(intercept) def _fit_multiclass(self, X, y, alpha, C, learning_rate, - sample_weight, n_iter): + sample_weight, max_iter): """Fit a multi-class classifier by combining binary classifiers Each binary classifier predicts one class versus all others. 
This @@ -446,14 +481,18 @@ class BaseSGDClassifier(six.with_metaclass(ABCMeta, BaseSGD, result = Parallel(n_jobs=self.n_jobs, backend="threading", verbose=self.verbose)( delayed(fit_binary)(self, i, X, y, alpha, C, learning_rate, - n_iter, self._expanded_class_weight[i], 1., - sample_weight) + max_iter, self._expanded_class_weight[i], + 1., sample_weight) for i in range(len(self.classes_))) - for i, (_, intercept) in enumerate(result): + # take the maximum of n_iter_ over every binary fit + n_iter_ = 0. + for i, (_, intercept, n_iter_i) in enumerate(result): self.intercept_[i] = intercept + n_iter_ = max(n_iter_, n_iter_i) - self.t_ += n_iter * X.shape[0] + self.t_ += n_iter_ * X.shape[0] + self.n_iter_ = n_iter_ if self.average > 0: if self.average <= self.t_ - 1.0: @@ -501,7 +540,7 @@ class BaseSGDClassifier(six.with_metaclass(ABCMeta, BaseSGD, "Pass the resulting weights as the class_weight " "parameter.".format(self.class_weight)) return self._partial_fit(X, y, alpha=self.alpha, C=1.0, loss=self.loss, - learning_rate=self.learning_rate, n_iter=1, + learning_rate=self.learning_rate, max_iter=1, classes=classes, sample_weight=sample_weight, coef_init=None, intercept_init=None) @@ -599,9 +638,26 @@ class SGDClassifier(BaseSGDClassifier): data is assumed to be already centered. Defaults to True. n_iter : int, optional - The number of passes over the training data (aka epochs). The number - of iterations is set to 1 if using partial_fit. - Defaults to 5. + The number of passes over the training data (aka epochs). + Defaults to None. Deprecated, will be removed in 0.21. + + .. versionchanged:: 0.19 + Deprecated + + max_iter : int, optional + The maximum number of passes over the training data (aka epochs). + It only impacts the behavior in the ``fit`` method, and not the + `partial_fit`. + Defaults to 5. Defaults to 1000 from 0.21, or if tol is not None. + + .. versionadded:: 0.19 + + tol : float or None, optional + The stopping criterion. 
If it is not None, the iterations will stop + when (loss > previous_loss - tol). Defaults to None. + Defaults to 1e-3 from 0.21. + + .. versionadded:: 0.19 shuffle : bool, optional Whether or not the training data should be shuffled after each epoch. @@ -677,6 +733,10 @@ class SGDClassifier(BaseSGDClassifier): intercept_ : array, shape (1,) if n_classes == 2 else (n_classes,) Constants in decision function. + n_iter_ : int + The actual number of iterations to reach the stopping criterion. + For multiclass fits, it is the maximum over every binary fit. + loss_function_ : concrete ``LossFunction`` Examples @@ -689,10 +749,11 @@ class SGDClassifier(BaseSGDClassifier): >>> clf.fit(X, Y) ... #doctest: +NORMALIZE_WHITESPACE SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1, - eta0=0.0, fit_intercept=True, l1_ratio=0.15, - learning_rate='optimal', loss='hinge', n_iter=5, n_jobs=1, - penalty='l2', power_t=0.5, random_state=None, shuffle=True, - verbose=0, warm_start=False) + eta0=0.0, fit_intercept=True, l1_ratio=0.15, + learning_rate='optimal', loss='hinge', max_iter=5, n_iter=None, + n_jobs=1, penalty='l2', power_t=0.5, random_state=None, + shuffle=True, tol=None, verbose=0, warm_start=False) + >>> print(clf.predict([[-0.8, -1]])) [1] @@ -703,17 +764,18 @@ class SGDClassifier(BaseSGDClassifier): """ def __init__(self, loss="hinge", penalty='l2', alpha=0.0001, l1_ratio=0.15, - fit_intercept=True, n_iter=5, shuffle=True, verbose=0, - epsilon=DEFAULT_EPSILON, n_jobs=1, random_state=None, - learning_rate="optimal", eta0=0.0, power_t=0.5, - class_weight=None, warm_start=False, average=False): + fit_intercept=True, max_iter=None, tol=None, shuffle=True, + verbose=0, epsilon=DEFAULT_EPSILON, n_jobs=1, + random_state=None, learning_rate="optimal", eta0=0.0, + power_t=0.5, class_weight=None, warm_start=False, + average=False, n_iter=None): super(SGDClassifier, self).__init__( loss=loss, penalty=penalty, alpha=alpha, l1_ratio=l1_ratio, - 
fit_intercept=fit_intercept, n_iter=n_iter, shuffle=shuffle, - verbose=verbose, epsilon=epsilon, n_jobs=n_jobs, + fit_intercept=fit_intercept, max_iter=max_iter, tol=tol, + shuffle=shuffle, verbose=verbose, epsilon=epsilon, n_jobs=n_jobs, random_state=random_state, learning_rate=learning_rate, eta0=eta0, power_t=power_t, class_weight=class_weight, warm_start=warm_start, - average=average) + average=average, n_iter=n_iter) def _check_proba(self): check_is_fitted(self, "t_") @@ -843,25 +905,26 @@ class BaseSGDRegressor(BaseSGD, RegressorMixin): @abstractmethod def __init__(self, loss="squared_loss", penalty="l2", alpha=0.0001, - l1_ratio=0.15, fit_intercept=True, n_iter=5, shuffle=True, - verbose=0, epsilon=DEFAULT_EPSILON, random_state=None, - learning_rate="invscaling", eta0=0.01, power_t=0.25, - warm_start=False, average=False): + l1_ratio=0.15, fit_intercept=True, max_iter=None, tol=None, + shuffle=True, verbose=0, epsilon=DEFAULT_EPSILON, + random_state=None, learning_rate="invscaling", eta0=0.01, + power_t=0.25, warm_start=False, average=False, n_iter=None): super(BaseSGDRegressor, self).__init__(loss=loss, penalty=penalty, alpha=alpha, l1_ratio=l1_ratio, fit_intercept=fit_intercept, - n_iter=n_iter, shuffle=shuffle, + max_iter=max_iter, tol=tol, + shuffle=shuffle, verbose=verbose, epsilon=epsilon, random_state=random_state, learning_rate=learning_rate, eta0=eta0, power_t=power_t, warm_start=warm_start, - average=average) + average=average, + n_iter=n_iter) def _partial_fit(self, X, y, alpha, C, loss, learning_rate, - n_iter, sample_weight, - coef_init, intercept_init): + max_iter, sample_weight, coef_init, intercept_init): X, y = check_X_y(X, y, "csr", copy=False, order='C', dtype=np.float64) y = y.astype(np.float64, copy=False) @@ -887,7 +950,7 @@ class BaseSGDRegressor(BaseSGD, RegressorMixin): order="C") self._fit_regressor(X, y, alpha, C, loss, learning_rate, - sample_weight, n_iter) + sample_weight, max_iter) return self @@ -912,9 +975,9 @@ class 
BaseSGDRegressor(BaseSGD, RegressorMixin): """ return self._partial_fit(X, y, self.alpha, C=1.0, loss=self.loss, - learning_rate=self.learning_rate, n_iter=1, - sample_weight=sample_weight, - coef_init=None, intercept_init=None) + learning_rate=self.learning_rate, max_iter=1, + sample_weight=sample_weight, coef_init=None, + intercept_init=None) def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None, intercept_init=None, sample_weight=None): @@ -936,9 +999,18 @@ class BaseSGDRegressor(BaseSGD, RegressorMixin): # Clear iteration count for multiple call to fit. self.t_ = 1.0 - return self._partial_fit(X, y, alpha, C, loss, learning_rate, - self.n_iter, sample_weight, - coef_init, intercept_init) + self._partial_fit(X, y, alpha, C, loss, learning_rate, + self.max_iter, sample_weight, coef_init, + intercept_init) + + if (self.tol is not None and self.tol > -np.inf + and self.n_iter_ == self.max_iter): + warnings.warn("Maximum number of iteration reached before " + "convergence. Consider increasing max_iter to " + "improve the fit.", + ConvergenceWarning) + + return self def fit(self, X, y, coef_init=None, intercept_init=None, sample_weight=None): @@ -1006,7 +1078,7 @@ class BaseSGDRegressor(BaseSGD, RegressorMixin): return self._decision_function(X) def _fit_regressor(self, X, y, alpha, C, loss, learning_rate, - sample_weight, n_iter): + sample_weight, max_iter): dataset, intercept_decay = make_dataset(X, y, sample_weight) loss_function = self._get_loss_function(loss) @@ -1021,9 +1093,11 @@ class BaseSGDRegressor(BaseSGD, RegressorMixin): # Windows seed = random_state.randint(0, np.iinfo(np.int32).max) + tol = self.tol if self.tol is not None else -np.inf + if self.average > 0: self.standard_coef_, self.standard_intercept_, \ - self.average_coef_, self.average_intercept_ =\ + self.average_coef_, self.average_intercept_, self.n_iter_ =\ average_sgd(self.standard_coef_, self.standard_intercept_[0], self.average_coef_, @@ -1033,7 +1107,7 @@ class 
BaseSGDRegressor(BaseSGD, RegressorMixin): alpha, C, self.l1_ratio, dataset, - n_iter, + max_iter, tol, int(self.fit_intercept), int(self.verbose), int(self.shuffle), @@ -1045,7 +1119,7 @@ class BaseSGDRegressor(BaseSGD, RegressorMixin): self.average_intercept_ = np.atleast_1d(self.average_intercept_) self.standard_intercept_ = np.atleast_1d(self.standard_intercept_) - self.t_ += n_iter * X.shape[0] + self.t_ += self.n_iter_ * X.shape[0] if self.average <= self.t_ - 1.0: self.coef_ = self.average_coef_ @@ -1055,7 +1129,7 @@ class BaseSGDRegressor(BaseSGD, RegressorMixin): self.intercept_ = self.standard_intercept_ else: - self.coef_, self.intercept_ = \ + self.coef_, self.intercept_, self.n_iter_ = \ plain_sgd(self.coef_, self.intercept_[0], loss_function, @@ -1063,7 +1137,7 @@ class BaseSGDRegressor(BaseSGD, RegressorMixin): alpha, C, self.l1_ratio, dataset, - n_iter, + max_iter, tol, int(self.fit_intercept), int(self.verbose), int(self.shuffle), @@ -1073,7 +1147,7 @@ class BaseSGDRegressor(BaseSGD, RegressorMixin): self.eta0, self.power_t, self.t_, intercept_decay) - self.t_ += n_iter * X.shape[0] + self.t_ += self.n_iter_ * X.shape[0] self.intercept_ = np.atleast_1d(self.intercept_) @@ -1128,9 +1202,26 @@ class SGDRegressor(BaseSGDRegressor): data is assumed to be already centered. Defaults to True. n_iter : int, optional - The number of passes over the training data (aka epochs). The number - of iterations is set to 1 if using partial_fit. - Defaults to 5. + The number of passes over the training data (aka epochs). + Defaults to None. Deprecated, will be removed in 0.21. + + .. versionchanged:: 0.19 + Deprecated + + max_iter : int, optional + The maximum number of passes over the training data (aka epochs). + It only impacts the behavior in the ``fit`` method, and not the + `partial_fit`. + Defaults to 5. Defaults to 1000 from 0.21, or if tol is not None. + + .. versionadded:: 0.19 + + tol : float or None, optional + The stopping criterion. 
If it is not None, the iterations will stop + when (loss > previous_loss - tol). Defaults to None. + Defaults to 1e-3 from 0.21. + + .. versionadded:: 0.19 shuffle : bool, optional Whether or not the training data should be shuffled after each epoch. @@ -1194,6 +1285,9 @@ class SGDRegressor(BaseSGDRegressor): average_intercept_ : array, shape (1,) The averaged intercept term. + n_iter_ : int + The actual number of iterations to reach the stopping criterion. + Examples -------- >>> import numpy as np @@ -1206,9 +1300,11 @@ class SGDRegressor(BaseSGDRegressor): >>> clf.fit(X, y) ... #doctest: +NORMALIZE_WHITESPACE SGDRegressor(alpha=0.0001, average=False, epsilon=0.1, eta0=0.01, - fit_intercept=True, l1_ratio=0.15, learning_rate='invscaling', - loss='squared_loss', n_iter=5, penalty='l2', power_t=0.25, - random_state=None, shuffle=True, verbose=0, warm_start=False) + fit_intercept=True, l1_ratio=0.15, learning_rate='invscaling', + loss='squared_loss', max_iter=5, n_iter=None, penalty='l2', + power_t=0.25, random_state=None, shuffle=True, tol=None, + verbose=0, warm_start=False) + See also -------- @@ -1216,18 +1312,19 @@ class SGDRegressor(BaseSGDRegressor): """ def __init__(self, loss="squared_loss", penalty="l2", alpha=0.0001, - l1_ratio=0.15, fit_intercept=True, n_iter=5, shuffle=True, - verbose=0, epsilon=DEFAULT_EPSILON, random_state=None, - learning_rate="invscaling", eta0=0.01, power_t=0.25, - warm_start=False, average=False): + l1_ratio=0.15, fit_intercept=True, max_iter=None, tol=None, + shuffle=True, verbose=0, epsilon=DEFAULT_EPSILON, + random_state=None, learning_rate="invscaling", eta0=0.01, + power_t=0.25, warm_start=False, average=False, n_iter=None): super(SGDRegressor, self).__init__(loss=loss, penalty=penalty, alpha=alpha, l1_ratio=l1_ratio, fit_intercept=fit_intercept, - n_iter=n_iter, shuffle=shuffle, + max_iter=max_iter, tol=tol, + shuffle=shuffle, verbose=verbose, epsilon=epsilon, random_state=random_state, learning_rate=learning_rate, 
eta0=eta0, power_t=power_t, warm_start=warm_start, - average=average) + average=average, n_iter=n_iter) diff --git a/sklearn/linear_model/tests/test_huber.py b/sklearn/linear_model/tests/test_huber.py index 9431e96f74..08f4fdf281 100644 --- a/sklearn/linear_model/tests/test_huber.py +++ b/sklearn/linear_model/tests/test_huber.py @@ -118,8 +118,7 @@ def test_huber_sparse(): def test_huber_scaling_invariant(): - """Test that outliers filtering is scaling independent.""" - rng = np.random.RandomState(0) + # Test that outliers filtering is scaling independent. X, y = make_regression_with_outliers() huber = HuberRegressor(fit_intercept=False, alpha=0.0, max_iter=100) huber.fit(X, y) @@ -136,7 +135,7 @@ def test_huber_scaling_invariant(): def test_huber_and_sgd_same_results(): - """Test they should converge to same coefficients for same parameters""" + # Test they should converge to same coefficients for same parameters X, y = make_regression_with_outliers(n_samples=10, n_features=2) @@ -151,8 +150,8 @@ def test_huber_and_sgd_same_results(): assert_almost_equal(huber.scale_, 1.0, 3) sgdreg = SGDRegressor( - alpha=0.0, loss="huber", shuffle=True, random_state=0, n_iter=10000, - fit_intercept=False, epsilon=1.35) + alpha=0.0, loss="huber", shuffle=True, random_state=0, max_iter=10000, + fit_intercept=False, epsilon=1.35, tol=None) sgdreg.fit(X_scale, y_scale) assert_array_almost_equal(huber.coef_, sgdreg.coef_, 1) diff --git a/sklearn/linear_model/tests/test_passive_aggressive.py b/sklearn/linear_model/tests/test_passive_aggressive.py index a1dc5c4d68..5620c29e18 100644 --- a/sklearn/linear_model/tests/test_passive_aggressive.py +++ b/sklearn/linear_model/tests/test_passive_aggressive.py @@ -71,10 +71,9 @@ def test_classifier_accuracy(): for data in (X, X_csr): for fit_intercept in (True, False): for average in (False, True): - clf = PassiveAggressiveClassifier(C=1.0, n_iter=30, - fit_intercept=fit_intercept, - random_state=0, - average=average) + clf = 
PassiveAggressiveClassifier( + C=1.0, max_iter=30, fit_intercept=fit_intercept, + random_state=0, average=average, tol=None) clf.fit(data, y) score = clf.score(data, y) assert_greater(score, 0.79) @@ -89,10 +88,9 @@ def test_classifier_partial_fit(): classes = np.unique(y) for data in (X, X_csr): for average in (False, True): - clf = PassiveAggressiveClassifier(C=1.0, - fit_intercept=True, - random_state=0, - average=average) + clf = PassiveAggressiveClassifier( + C=1.0, fit_intercept=True, random_state=0, + average=average, max_iter=5) for t in range(30): clf.partial_fit(data, y, classes) score = clf.score(data, y) @@ -106,7 +104,7 @@ def test_classifier_partial_fit(): def test_classifier_refit(): # Classifier can be retrained on different labels and features. - clf = PassiveAggressiveClassifier().fit(X, y) + clf = PassiveAggressiveClassifier(max_iter=5).fit(X, y) assert_array_equal(clf.classes_, np.unique(y)) clf.fit(X[:, :-1], iris.target_names[y]) @@ -119,24 +117,21 @@ def test_classifier_correctness(): for loss in ("hinge", "squared_hinge"): - clf1 = MyPassiveAggressive(C=1.0, - loss=loss, - fit_intercept=True, - n_iter=2) + clf1 = MyPassiveAggressive( + C=1.0, loss=loss, fit_intercept=True, n_iter=2) clf1.fit(X, y_bin) for data in (X, X_csr): - clf2 = PassiveAggressiveClassifier(C=1.0, - loss=loss, - fit_intercept=True, - n_iter=2, shuffle=False) + clf2 = PassiveAggressiveClassifier( + C=1.0, loss=loss, fit_intercept=True, max_iter=2, + shuffle=False, tol=None) clf2.fit(data, y_bin) assert_array_almost_equal(clf1.w, clf2.coef_.ravel(), decimal=2) def test_classifier_undefined_methods(): - clf = PassiveAggressiveClassifier() + clf = PassiveAggressiveClassifier(max_iter=100) for meth in ("predict_proba", "predict_log_proba", "transform"): assert_raises(AttributeError, lambda x: getattr(clf, x), meth) @@ -147,13 +142,13 @@ def test_class_weights(): [1.0, 1.0], [1.0, 0.0]]) y2 = [1, 1, 1, -1, -1] - clf = PassiveAggressiveClassifier(C=0.1, n_iter=100, 
class_weight=None, + clf = PassiveAggressiveClassifier(C=0.1, max_iter=100, class_weight=None, random_state=100) clf.fit(X2, y2) assert_array_equal(clf.predict([[0.2, -1.0]]), np.array([1])) # we give a small weights to class 1 - clf = PassiveAggressiveClassifier(C=0.1, n_iter=100, + clf = PassiveAggressiveClassifier(C=0.1, max_iter=100, class_weight={1: 0.001}, random_state=100) clf.fit(X2, y2) @@ -165,23 +160,24 @@ def test_class_weights(): def test_partial_fit_weight_class_balanced(): # partial_fit with class_weight='balanced' not supported - clf = PassiveAggressiveClassifier(class_weight="balanced") + clf = PassiveAggressiveClassifier(class_weight="balanced", max_iter=100) assert_raises(ValueError, clf.partial_fit, X, y, classes=np.unique(y)) def test_equal_class_weight(): X2 = [[1, 0], [1, 0], [0, 1], [0, 1]] y2 = [0, 0, 1, 1] - clf = PassiveAggressiveClassifier(C=0.1, n_iter=1000, class_weight=None) + clf = PassiveAggressiveClassifier( + C=0.1, max_iter=1000, tol=None, class_weight=None) clf.fit(X2, y2) # Already balanced, so "balanced" weights should have no effect - clf_balanced = PassiveAggressiveClassifier(C=0.1, n_iter=1000, - class_weight="balanced") + clf_balanced = PassiveAggressiveClassifier( + C=0.1, max_iter=1000, tol=None, class_weight="balanced") clf_balanced.fit(X2, y2) - clf_weighted = PassiveAggressiveClassifier(C=0.1, n_iter=1000, - class_weight={0: 0.5, 1: 0.5}) + clf_weighted = PassiveAggressiveClassifier( + C=0.1, max_iter=1000, tol=None, class_weight={0: 0.5, 1: 0.5}) clf_weighted.fit(X2, y2) # should be similar up to some epsilon due to learning rate schedule @@ -195,7 +191,7 @@ def test_wrong_class_weight_label(): [1.0, 1.0], [1.0, 0.0]]) y2 = [1, 1, 1, -1, -1] - clf = PassiveAggressiveClassifier(class_weight={0: 0.5}) + clf = PassiveAggressiveClassifier(class_weight={0: 0.5}, max_iter=100) assert_raises(ValueError, clf.fit, X2, y2) @@ -205,10 +201,10 @@ def test_wrong_class_weight_format(): [1.0, 1.0], [1.0, 0.0]]) y2 = [1, 1, 1, -1, 
-1] - clf = PassiveAggressiveClassifier(class_weight=[0.5]) + clf = PassiveAggressiveClassifier(class_weight=[0.5], max_iter=100) assert_raises(ValueError, clf.fit, X2, y2) - clf = PassiveAggressiveClassifier(class_weight="the larch") + clf = PassiveAggressiveClassifier(class_weight="the larch", max_iter=100) assert_raises(ValueError, clf.fit, X2, y2) @@ -219,10 +215,9 @@ def test_regressor_mse(): for data in (X, X_csr): for fit_intercept in (True, False): for average in (False, True): - reg = PassiveAggressiveRegressor(C=1.0, n_iter=50, - fit_intercept=fit_intercept, - random_state=0, - average=average) + reg = PassiveAggressiveRegressor( + C=1.0, fit_intercept=fit_intercept, + random_state=0, average=average, max_iter=5) reg.fit(data, y_bin) pred = reg.predict(data) assert_less(np.mean((pred - y_bin) ** 2), 1.7) @@ -239,10 +234,9 @@ def test_regressor_partial_fit(): for data in (X, X_csr): for average in (False, True): - reg = PassiveAggressiveRegressor(C=1.0, - fit_intercept=True, - random_state=0, - average=average) + reg = PassiveAggressiveRegressor( + C=1.0, fit_intercept=True, random_state=0, + average=average, max_iter=100) for t in range(50): reg.partial_fit(data, y_bin) pred = reg.predict(data) @@ -259,23 +253,20 @@ def test_regressor_correctness(): y_bin[y != 1] = -1 for loss in ("epsilon_insensitive", "squared_epsilon_insensitive"): - reg1 = MyPassiveAggressive(C=1.0, - loss=loss, - fit_intercept=True, - n_iter=2) + reg1 = MyPassiveAggressive( + C=1.0, loss=loss, fit_intercept=True, n_iter=2) reg1.fit(X, y_bin) for data in (X, X_csr): - reg2 = PassiveAggressiveRegressor(C=1.0, - loss=loss, - fit_intercept=True, - n_iter=2, shuffle=False) + reg2 = PassiveAggressiveRegressor( + C=1.0, tol=None, loss=loss, fit_intercept=True, max_iter=2, + shuffle=False) reg2.fit(data, y_bin) assert_array_almost_equal(reg1.w, reg2.coef_.ravel(), decimal=2) def test_regressor_undefined_methods(): - reg = PassiveAggressiveRegressor() + reg = 
PassiveAggressiveRegressor(max_iter=100) for meth in ("transform",): assert_raises(AttributeError, lambda x: getattr(reg, x), meth) diff --git a/sklearn/linear_model/tests/test_perceptron.py b/sklearn/linear_model/tests/test_perceptron.py index a5b97c431a..c6a46bb4df 100644 --- a/sklearn/linear_model/tests/test_perceptron.py +++ b/sklearn/linear_model/tests/test_perceptron.py @@ -2,7 +2,7 @@ import numpy as np import scipy.sparse as sp from sklearn.utils.testing import assert_array_almost_equal -from sklearn.utils.testing import assert_true +from sklearn.utils.testing import assert_greater from sklearn.utils.testing import assert_raises from sklearn.utils import check_random_state @@ -45,10 +45,10 @@ class MyPerceptron(object): def test_perceptron_accuracy(): for data in (X, X_csr): - clf = Perceptron(n_iter=30, shuffle=False) + clf = Perceptron(max_iter=100, tol=None, shuffle=False) clf.fit(data, y) score = clf.score(data, y) - assert_true(score >= 0.7) + assert_greater(score, 0.7) def test_perceptron_correctness(): @@ -58,13 +58,13 @@ def test_perceptron_correctness(): clf1 = MyPerceptron(n_iter=2) clf1.fit(X, y_bin) - clf2 = Perceptron(n_iter=2, shuffle=False) + clf2 = Perceptron(max_iter=2, shuffle=False, tol=None) clf2.fit(X, y_bin) assert_array_almost_equal(clf1.w, clf2.coef_.ravel()) def test_undefined_methods(): - clf = Perceptron() + clf = Perceptron(max_iter=100) for meth in ("predict_proba", "predict_log_proba"): assert_raises(AttributeError, lambda x: getattr(clf, x), meth) diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py index 8287ade2c2..addd235653 100644 --- a/sklearn/linear_model/tests/test_sgd.py +++ b/sklearn/linear_model/tests/test_sgd.py @@ -14,12 +14,17 @@ from sklearn.utils.testing import assert_raises from sklearn.utils.testing import assert_false, assert_true from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_raises_regexp +from sklearn.utils.testing import 
assert_warns +from sklearn.utils.testing import assert_warns_message +from sklearn.utils.testing import assert_no_warnings from sklearn.utils.testing import ignore_warnings from sklearn import linear_model, datasets, metrics from sklearn.base import clone from sklearn.linear_model import SGDClassifier, SGDRegressor from sklearn.preprocessing import LabelEncoder, scale, MinMaxScaler +from sklearn.preprocessing import StandardScaler +from sklearn.exceptions import ConvergenceWarning from sklearn.linear_model import sgd_fast @@ -103,6 +108,12 @@ class CommonTest(object): def factory(self, **kwargs): if "random_state" not in kwargs: kwargs["random_state"] = 42 + + if "tol" not in kwargs: + kwargs["tol"] = None + if "max_iter" not in kwargs: + kwargs["max_iter"] = 5 + return self.factory_class(**kwargs) # a simple implementation of ASGD to use for testing @@ -143,18 +154,18 @@ class CommonTest(object): def _test_warm_start(self, X, Y, lr): # Test that explicit warm restart... - clf = self.factory(alpha=0.01, eta0=0.01, n_iter=5, shuffle=False, + clf = self.factory(alpha=0.01, eta0=0.01, shuffle=False, learning_rate=lr) clf.fit(X, Y) - clf2 = self.factory(alpha=0.001, eta0=0.01, n_iter=5, shuffle=False, + clf2 = self.factory(alpha=0.001, eta0=0.01, shuffle=False, learning_rate=lr) clf2.fit(X, Y, coef_init=clf.coef_.copy(), intercept_init=clf.intercept_.copy()) # ... and implicit warm restart are equivalent. - clf3 = self.factory(alpha=0.01, eta0=0.01, n_iter=5, shuffle=False, + clf3 = self.factory(alpha=0.01, eta0=0.01, shuffle=False, warm_start=True, learning_rate=lr) clf3.fit(X, Y) @@ -178,8 +189,7 @@ class CommonTest(object): def test_input_format(self): # Input format tests. - clf = self.factory(alpha=0.01, n_iter=5, - shuffle=False) + clf = self.factory(alpha=0.01, shuffle=False) clf.fit(X, Y) Y_ = np.array(Y)[:, np.newaxis] @@ -188,12 +198,12 @@ class CommonTest(object): def test_clone(self): # Test whether clone works ok. 
- clf = self.factory(alpha=0.01, n_iter=5, penalty='l1') + clf = self.factory(alpha=0.01, penalty='l1') clf = clone(clf) clf.set_params(penalty='l2') clf.fit(X, Y) - clf2 = self.factory(alpha=0.01, n_iter=5, penalty='l2') + clf2 = self.factory(alpha=0.01, penalty='l2') clf2.fit(X, Y) assert_array_equal(clf.coef_, clf2.coef_) @@ -238,10 +248,10 @@ class CommonTest(object): clf1 = self.factory(average=7, learning_rate="constant", loss='squared_loss', eta0=eta0, - alpha=alpha, n_iter=2, shuffle=False) + alpha=alpha, max_iter=2, shuffle=False) clf2 = self.factory(average=0, learning_rate="constant", loss='squared_loss', eta0=eta0, - alpha=alpha, n_iter=1, shuffle=False) + alpha=alpha, max_iter=1, shuffle=False) clf1.fit(X, Y_encode) clf2.fit(X, Y_encode) @@ -272,7 +282,7 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): for loss in ("hinge", "squared_hinge", "log", "modified_huber"): clf = self.factory(penalty='l2', alpha=0.01, fit_intercept=True, - loss=loss, n_iter=10, shuffle=True) + loss=loss, max_iter=10, shuffle=True) clf.fit(X, Y) # assert_almost_equal(clf.coef_[0], clf.coef_[1], decimal=7) assert_array_equal(clf.predict(T), true_result) @@ -308,9 +318,9 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): self.factory(loss="foobar") @raises(ValueError) - def test_sgd_n_iter_param(self): + def test_sgd_max_iter_param(self): # Test parameter validity check - self.factory(n_iter=-10000) + self.factory(max_iter=-10000) @raises(ValueError) def test_sgd_shuffle_param(self): @@ -353,7 +363,7 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): learning_rate='constant', eta0=eta, alpha=alpha, fit_intercept=True, - n_iter=1, average=True, shuffle=False) + max_iter=1, average=True, shuffle=False) # simple linear function without noise y = np.dot(X, w) @@ -379,7 +389,7 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): @raises(ValueError) def test_sgd_at_least_two_labels(self): # Target must have at least 
two labels - self.factory(alpha=0.01, n_iter=20).fit(X2, np.ones(9)) + self.factory(alpha=0.01, max_iter=20).fit(X2, np.ones(9)) def test_partial_fit_weight_class_balanced(self): # partial_fit with class_weight='balanced' not supported""" @@ -397,7 +407,7 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): def test_sgd_multiclass(self): # Multi-class test case - clf = self.factory(alpha=0.01, n_iter=20).fit(X2, Y2) + clf = self.factory(alpha=0.01, max_iter=20).fit(X2, Y2) assert_equal(clf.coef_.shape, (3, 2)) assert_equal(clf.intercept_.shape, (3,)) assert_equal(clf.decision_function([[0, 0]]).shape, (1, 3)) @@ -412,7 +422,7 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): learning_rate='constant', eta0=eta, alpha=alpha, fit_intercept=True, - n_iter=1, average=True, shuffle=False) + max_iter=1, average=True, shuffle=False) np_Y2 = np.array(Y2) clf.fit(X2, np_Y2) @@ -429,7 +439,7 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): def test_sgd_multiclass_with_init_coef(self): # Multi-class test case - clf = self.factory(alpha=0.01, n_iter=20) + clf = self.factory(alpha=0.01, max_iter=20) clf.fit(X2, Y2, coef_init=np.zeros((3, 2)), intercept_init=np.zeros(3)) assert_equal(clf.coef_.shape, (3, 2)) @@ -439,7 +449,7 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): def test_sgd_multiclass_njobs(self): # Multi-class test case with multi-core support - clf = self.factory(alpha=0.01, n_iter=20, n_jobs=2).fit(X2, Y2) + clf = self.factory(alpha=0.01, max_iter=20, n_jobs=2).fit(X2, Y2) assert_equal(clf.coef_.shape, (3, 2)) assert_equal(clf.intercept_.shape, (3,)) assert_equal(clf.decision_function([[0, 0]]).shape, (1, 3)) @@ -470,14 +480,15 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): # Hinge loss does not allow for conditional prob estimate. # We cannot use the factory here, because it defines predict_proba # anyway. 
- clf = SGDClassifier(loss="hinge", alpha=0.01, n_iter=10).fit(X, Y) + clf = SGDClassifier(loss="hinge", alpha=0.01, + max_iter=10, tol=None).fit(X, Y) assert_false(hasattr(clf, "predict_proba")) assert_false(hasattr(clf, "predict_log_proba")) # log and modified_huber losses can output probability estimates # binary case for loss in ["log", "modified_huber"]: - clf = self.factory(loss=loss, alpha=0.01, n_iter=10) + clf = self.factory(loss=loss, alpha=0.01, max_iter=10) clf.fit(X, Y) p = clf.predict_proba([[3, 2]]) assert_true(p[0, 1] > 0.5) @@ -490,7 +501,7 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): assert_true(p[0, 1] < p[0, 0]) # log loss multiclass probability estimates - clf = self.factory(loss="log", alpha=0.01, n_iter=10).fit(X2, Y2) + clf = self.factory(loss="log", alpha=0.01, max_iter=10).fit(X2, Y2) d = clf.decision_function([[.1, -.1], [.3, .2]]) p = clf.predict_proba([[.1, -.1], [.3, .2]]) @@ -513,7 +524,7 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): # Modified Huber multiclass probability estimates; requires a separate # test because the hard zero/one probabilities may destroy the # ordering present in decision_function output. 
- clf = self.factory(loss="modified_huber", alpha=0.01, n_iter=10) + clf = self.factory(loss="modified_huber", alpha=0.01, max_iter=10) clf.fit(X2, Y2) d = clf.decision_function([[3, 2]]) p = clf.predict_proba([[3, 2]]) @@ -542,7 +553,7 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): Y = Y4[idx] clf = self.factory(penalty='l1', alpha=.2, fit_intercept=False, - n_iter=2000, shuffle=False) + max_iter=2000, tol=None, shuffle=False) clf.fit(X, Y) assert_array_equal(clf.coef_[0, 1:-1], np.zeros((4,))) pred = clf.predict(X) @@ -566,13 +577,13 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): [1.0, 1.0], [1.0, 0.0]]) y = [1, 1, 1, -1, -1] - clf = self.factory(alpha=0.1, n_iter=1000, fit_intercept=False, + clf = self.factory(alpha=0.1, max_iter=1000, fit_intercept=False, class_weight=None) clf.fit(X, y) assert_array_equal(clf.predict([[0.2, -1.0]]), np.array([1])) # we give a small weights to class 1 - clf = self.factory(alpha=0.1, n_iter=1000, fit_intercept=False, + clf = self.factory(alpha=0.1, max_iter=1000, fit_intercept=False, class_weight={1: 0.001}) clf.fit(X, y) @@ -584,12 +595,12 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): # Test if equal class weights approx. equals no class weights. X = [[1, 0], [1, 0], [0, 1], [0, 1]] y = [0, 0, 1, 1] - clf = self.factory(alpha=0.1, n_iter=1000, class_weight=None) + clf = self.factory(alpha=0.1, max_iter=1000, class_weight=None) clf.fit(X, y) X = [[1, 0], [0, 1]] y = [0, 1] - clf_weighted = self.factory(alpha=0.1, n_iter=1000, + clf_weighted = self.factory(alpha=0.1, max_iter=1000, class_weight={0: 0.5, 1: 0.5}) clf_weighted.fit(X, y) @@ -599,13 +610,13 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): @raises(ValueError) def test_wrong_class_weight_label(self): # ValueError due to not existing class label. 
- clf = self.factory(alpha=0.1, n_iter=1000, class_weight={0: 0.5}) + clf = self.factory(alpha=0.1, max_iter=1000, class_weight={0: 0.5}) clf.fit(X, Y) @raises(ValueError) def test_wrong_class_weight_format(self): # ValueError due to wrong class_weight argument type. - clf = self.factory(alpha=0.1, n_iter=1000, class_weight=[0.5]) + clf = self.factory(alpha=0.1, max_iter=1000, class_weight=[0.5]) clf.fit(X, Y) def test_weights_multiplied(self): @@ -617,8 +628,8 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): multiplied_together[Y4 == 1] *= class_weights[1] multiplied_together[Y4 == 2] *= class_weights[2] - clf1 = self.factory(alpha=0.1, n_iter=20, class_weight=class_weights) - clf2 = self.factory(alpha=0.1, n_iter=20) + clf1 = self.factory(alpha=0.1, max_iter=20, class_weight=class_weights) + clf2 = self.factory(alpha=0.1, max_iter=20) clf1.fit(X4, Y4, sample_weight=sample_weights) clf2.fit(X4, Y4, sample_weight=multiplied_together) @@ -636,17 +647,17 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): rng.shuffle(idx) X = X[idx] y = y[idx] - clf = self.factory(alpha=0.0001, n_iter=1000, + clf = self.factory(alpha=0.0001, max_iter=1000, class_weight=None, shuffle=False).fit(X, y) - assert_almost_equal(metrics.f1_score(y, clf.predict(X), average='weighted'), 0.96, - decimal=1) + f1 = metrics.f1_score(y, clf.predict(X), average='weighted') + assert_almost_equal(f1, 0.96, decimal=1) # make the same prediction using balanced class_weight - clf_balanced = self.factory(alpha=0.0001, n_iter=1000, + clf_balanced = self.factory(alpha=0.0001, max_iter=1000, class_weight="balanced", shuffle=False).fit(X, y) - assert_almost_equal(metrics.f1_score(y, clf_balanced.predict(X), average='weighted'), 0.96, - decimal=1) + f1 = metrics.f1_score(y, clf_balanced.predict(X), average='weighted') + assert_almost_equal(f1, 0.96, decimal=1) # Make sure that in the balanced case it does not change anything # to use "balanced" @@ -660,19 +671,14 @@ class 
DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): y_imbalanced = np.concatenate([y] + [y_0] * 10) # fit a model on the imbalanced data without class weight info - clf = self.factory(n_iter=1000, class_weight=None, shuffle=False) + clf = self.factory(max_iter=1000, class_weight=None, shuffle=False) clf.fit(X_imbalanced, y_imbalanced) y_pred = clf.predict(X) assert_less(metrics.f1_score(y, y_pred, average='weighted'), 0.96) # fit a model with balanced class_weight enabled - clf = self.factory(n_iter=1000, class_weight="balanced", shuffle=False) - clf.fit(X_imbalanced, y_imbalanced) - y_pred = clf.predict(X) - assert_greater(metrics.f1_score(y, y_pred, average='weighted'), 0.96) - - # fit another using a fit parameter override - clf = self.factory(n_iter=1000, class_weight="balanced", shuffle=False) + clf = self.factory(max_iter=1000, class_weight="balanced", + shuffle=False) clf.fit(X_imbalanced, y_imbalanced) y_pred = clf.predict(X) assert_greater(metrics.f1_score(y, y_pred, average='weighted'), 0.96) @@ -683,7 +689,7 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): [1.0, 1.0], [1.0, 0.0]]) y = [1, 1, 1, -1, -1] - clf = self.factory(alpha=0.1, n_iter=1000, fit_intercept=False) + clf = self.factory(alpha=0.1, max_iter=1000, fit_intercept=False) clf.fit(X, y) assert_array_equal(clf.predict([[0.2, -1.0]]), np.array([1])) @@ -697,7 +703,7 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): @raises(ValueError) def test_wrong_sample_weights(self): # Test if ValueError is raised if sample_weight has wrong shape - clf = self.factory(alpha=0.1, n_iter=1000, fit_intercept=False) + clf = self.factory(alpha=0.1, max_iter=1000, fit_intercept=False) # provided sample_weight too long clf.fit(X, Y, sample_weight=np.arange(7)) @@ -765,7 +771,7 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): def _test_partial_fit_equal_fit(self, lr): for X_, Y_, T_ in ((X, Y, T), (X2, Y2, T2)): - clf = self.factory(alpha=0.01, eta0=0.01, 
n_iter=2, + clf = self.factory(alpha=0.01, eta0=0.01, max_iter=2, learning_rate=lr, shuffle=False) clf.fit(X_, Y_) y_pred = clf.decision_function(T_) @@ -815,8 +821,7 @@ class DenseSGDClassifierTestCase(unittest.TestCase, CommonTest): def test_multiple_fit(self): # Test multiple calls of fit w/ different shaped inputs. - clf = self.factory(alpha=0.01, n_iter=5, - shuffle=False) + clf = self.factory(alpha=0.01, shuffle=False) clf.fit(X, Y) assert_true(hasattr(clf, "coef_")) @@ -841,7 +846,7 @@ class DenseSGDRegressorTestCase(unittest.TestCase, CommonTest): def test_sgd(self): # Check that SGD gives any results. - clf = self.factory(alpha=0.1, n_iter=2, + clf = self.factory(alpha=0.1, max_iter=2, fit_intercept=False) clf.fit([[0, 0], [1, 1], [2, 2]], [0, 1, 2]) assert_equal(clf.coef_[0], clf.coef_[1]) @@ -874,7 +879,7 @@ class DenseSGDRegressorTestCase(unittest.TestCase, CommonTest): learning_rate='constant', eta0=eta, alpha=alpha, fit_intercept=True, - n_iter=1, average=True, shuffle=False) + max_iter=1, average=True, shuffle=False) clf.fit(X, y) average_weights, average_intercept = self.asgd(X, y, eta, alpha) @@ -901,7 +906,7 @@ class DenseSGDRegressorTestCase(unittest.TestCase, CommonTest): learning_rate='constant', eta0=eta, alpha=alpha, fit_intercept=True, - n_iter=1, average=True, shuffle=False) + max_iter=1, average=True, shuffle=False) clf.partial_fit(X[:int(n_samples / 2)][:], y[:int(n_samples / 2)]) clf.partial_fit(X[int(n_samples / 2):][:], y[int(n_samples / 2):]) @@ -921,7 +926,7 @@ class DenseSGDRegressorTestCase(unittest.TestCase, CommonTest): learning_rate='constant', eta0=eta, alpha=alpha, fit_intercept=True, - n_iter=1, average=True, shuffle=False) + max_iter=1, average=True, shuffle=False) n_samples = Y3.shape[0] @@ -943,7 +948,7 @@ class DenseSGDRegressorTestCase(unittest.TestCase, CommonTest): # simple linear function without noise y = 0.5 * X.ravel() - clf = self.factory(loss='squared_loss', alpha=0.1, n_iter=20, + clf = 
self.factory(loss='squared_loss', alpha=0.1, max_iter=20, fit_intercept=False) clf.fit(X, y) score = clf.score(X, y) @@ -952,7 +957,7 @@ class DenseSGDRegressorTestCase(unittest.TestCase, CommonTest): # simple linear function with noise y = 0.5 * X.ravel() + rng.randn(n_samples, 1).ravel() - clf = self.factory(loss='squared_loss', alpha=0.1, n_iter=20, + clf = self.factory(loss='squared_loss', alpha=0.1, max_iter=20, fit_intercept=False) clf.fit(X, y) score = clf.score(X, y) @@ -968,7 +973,7 @@ class DenseSGDRegressorTestCase(unittest.TestCase, CommonTest): y = 0.5 * X.ravel() clf = self.factory(loss='epsilon_insensitive', epsilon=0.01, - alpha=0.1, n_iter=20, + alpha=0.1, max_iter=20, fit_intercept=False) clf.fit(X, y) score = clf.score(X, y) @@ -978,7 +983,7 @@ class DenseSGDRegressorTestCase(unittest.TestCase, CommonTest): y = 0.5 * X.ravel() + rng.randn(n_samples, 1).ravel() clf = self.factory(loss='epsilon_insensitive', epsilon=0.01, - alpha=0.1, n_iter=20, + alpha=0.1, max_iter=20, fit_intercept=False) clf.fit(X, y) score = clf.score(X, y) @@ -993,7 +998,7 @@ class DenseSGDRegressorTestCase(unittest.TestCase, CommonTest): # simple linear function without noise y = 0.5 * X.ravel() - clf = self.factory(loss="huber", epsilon=0.1, alpha=0.1, n_iter=20, + clf = self.factory(loss="huber", epsilon=0.1, alpha=0.1, max_iter=20, fit_intercept=False) clf.fit(X, y) score = clf.score(X, y) @@ -1002,7 +1007,7 @@ class DenseSGDRegressorTestCase(unittest.TestCase, CommonTest): # simple linear function with noise y = 0.5 * X.ravel() + rng.randn(n_samples, 1).ravel() - clf = self.factory(loss="huber", epsilon=0.1, alpha=0.1, n_iter=20, + clf = self.factory(loss="huber", epsilon=0.1, alpha=0.1, max_iter=20, fit_intercept=False) clf.fit(X, y) score = clf.score(X, y) @@ -1025,7 +1030,7 @@ class DenseSGDRegressorTestCase(unittest.TestCase, CommonTest): cd = linear_model.ElasticNet(alpha=alpha, l1_ratio=l1_ratio, fit_intercept=False) cd.fit(X, y) - sgd = 
self.factory(penalty='elasticnet', n_iter=50, + sgd = self.factory(penalty='elasticnet', max_iter=50, alpha=alpha, l1_ratio=l1_ratio, fit_intercept=False) sgd.fit(X, y) @@ -1052,7 +1057,7 @@ class DenseSGDRegressorTestCase(unittest.TestCase, CommonTest): assert_true(id1, id2) def _test_partial_fit_equal_fit(self, lr): - clf = self.factory(alpha=0.01, n_iter=2, eta0=0.01, + clf = self.factory(alpha=0.01, max_iter=2, eta0=0.01, learning_rate=lr, shuffle=False) clf.fit(X, Y) y_pred = clf.predict(T) @@ -1095,15 +1100,19 @@ def test_l1_ratio(): random_state=1234) # test if elasticnet with l1_ratio near 1 gives same result as pure l1 - est_en = SGDClassifier(alpha=0.001, penalty='elasticnet', - l1_ratio=0.9999999999, random_state=42).fit(X, y) - est_l1 = SGDClassifier(alpha=0.001, penalty='l1', random_state=42).fit(X, y) + est_en = SGDClassifier(alpha=0.001, penalty='elasticnet', tol=None, + max_iter=6, l1_ratio=0.9999999999, + random_state=42).fit(X, y) + est_l1 = SGDClassifier(alpha=0.001, penalty='l1', max_iter=6, + random_state=42, tol=None).fit(X, y) assert_array_almost_equal(est_en.coef_, est_l1.coef_) # test if elasticnet with l1_ratio near 0 gives same result as pure l2 - est_en = SGDClassifier(alpha=0.001, penalty='elasticnet', - l1_ratio=0.0000000001, random_state=42).fit(X, y) - est_l2 = SGDClassifier(alpha=0.001, penalty='l2', random_state=42).fit(X, y) + est_en = SGDClassifier(alpha=0.001, penalty='elasticnet', tol=None, + max_iter=6, l1_ratio=0.0000000001, + random_state=42).fit(X, y) + est_l2 = SGDClassifier(alpha=0.001, penalty='l2', max_iter=6, + random_state=42, tol=None).fit(X, y) assert_array_almost_equal(est_en.coef_, est_l2.coef_) @@ -1129,7 +1138,7 @@ def test_underflow_or_overlow(): y = (np.dot(X_scaled, ground_truth) > 0.).astype(np.int32) assert_array_equal(np.unique(y), [0, 1]) - model = SGDClassifier(alpha=0.1, loss='squared_hinge', n_iter=500) + model = SGDClassifier(alpha=0.1, loss='squared_hinge', max_iter=500) # smoke test: model is stable 
on scaled data model.fit(X_scaled, y) @@ -1145,9 +1154,9 @@ def test_underflow_or_overlow(): def test_numerical_stability_large_gradient(): # Non regression test case for numerical stability on scaled problems # where the gradient can still explode with some losses - model = SGDClassifier(loss='squared_hinge', n_iter=10, shuffle=True, + model = SGDClassifier(loss='squared_hinge', max_iter=10, shuffle=True, penalty='elasticnet', l1_ratio=0.3, alpha=0.01, - eta0=0.001, random_state=0) + eta0=0.001, random_state=0, tol=None) with np.errstate(all='raise'): model.fit(iris.data, iris.target) assert_true(np.isfinite(model.coef_).all()) @@ -1158,12 +1167,87 @@ def test_large_regularization(): # regularization parameters for penalty in ['l2', 'l1', 'elasticnet']: model = SGDClassifier(alpha=1e5, learning_rate='constant', eta0=0.1, - n_iter=5, penalty=penalty, shuffle=False) + penalty=penalty, shuffle=False, + tol=None, max_iter=6) with np.errstate(all='raise'): model.fit(iris.data, iris.target) assert_array_almost_equal(model.coef_, np.zeros_like(model.coef_)) +def test_tol_parameter(): + # Test that the tol parameter behaves as expected + X = StandardScaler().fit_transform(iris.data) + y = iris.target == 1 + + # With tol is None, the number of iteration should be equal to max_iter + max_iter = 42 + model_0 = SGDClassifier(tol=None, random_state=0, max_iter=max_iter) + model_0.fit(X, y) + assert_equal(max_iter, model_0.n_iter_) + + # If tol is not None, the number of iteration should be less than max_iter + max_iter = 2000 + model_1 = SGDClassifier(tol=0, random_state=0, max_iter=max_iter) + model_1.fit(X, y) + assert_greater(max_iter, model_1.n_iter_) + assert_greater(model_1.n_iter_, 5) + + # A larger tol should yield a smaller number of iteration + model_2 = SGDClassifier(tol=0.1, random_state=0, max_iter=max_iter) + model_2.fit(X, y) + assert_greater(model_1.n_iter_, model_2.n_iter_) + assert_greater(model_2.n_iter_, 3) + + # Strict tolerance and small max_iter should 
trigger a warning + model_3 = SGDClassifier(max_iter=3, tol=1e-3, random_state=0) + model_3 = assert_warns(ConvergenceWarning, model_3.fit, X, y) + assert_equal(model_3.n_iter_, 3) + + +def test_future_and_deprecation_warnings(): + # Test that warnings are raised. Will be removed in 0.21 + + # When all default values are used + msg_future = "max_iter and tol parameters have been added in " + assert_warns_message(FutureWarning, msg_future, SGDClassifier) + + def init(max_iter=None, tol=None, n_iter=None): + SGDClassifier(max_iter=max_iter, tol=tol, n_iter=n_iter) + + # When n_iter is specified + msg_deprecation = "n_iter parameter is deprecated" + assert_warns_message(DeprecationWarning, msg_deprecation, init, 6, 0, 5) + + # When n_iter=None, and at least one of tol and max_iter is specified + assert_no_warnings(init, 100, None, None) + assert_no_warnings(init, None, 1e-3, None) + assert_no_warnings(init, 100, 1e-3, None) + + +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) +def test_tol_and_max_iter_default_values(): + # Test that the default values are correctly changed + est = SGDClassifier() + assert_equal(est.tol, None) + assert_equal(est.max_iter, 5) + + est = SGDClassifier(n_iter=42) + assert_equal(est.tol, None) + assert_equal(est.max_iter, 42) + + est = SGDClassifier(tol=1e-2) + assert_equal(est.tol, 1e-2) + assert_equal(est.max_iter, 1000) + + est = SGDClassifier(max_iter=42) + assert_equal(est.tol, None) + assert_equal(est.max_iter, 42) + + est = SGDClassifier(max_iter=42, tol=1e-2) + assert_equal(est.tol, 1e-2) + assert_equal(est.max_iter, 42) + + def _test_gradient_common(loss_function, cases): # Test gradient of different loss functions # cases is a list of (p, y, expected) diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 1d6cf50ec1..9e6fd57ccd 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -1223,7 +1223,7 @@ def 
test_stochastic_gradient_loss_param(): } X = np.arange(24).reshape(6, -1) y = [0, 0, 0, 1, 1, 1] - clf = GridSearchCV(estimator=SGDClassifier(loss='hinge'), + clf = GridSearchCV(estimator=SGDClassifier(tol=1e-3, loss='hinge'), param_grid=param_grid) # When the estimator is not fitted, `predict_proba` is not available as the @@ -1238,7 +1238,7 @@ def test_stochastic_gradient_loss_param(): param_grid = { 'loss': ['hinge'], } - clf = GridSearchCV(estimator=SGDClassifier(loss='hinge'), + clf = GridSearchCV(estimator=SGDClassifier(tol=1e-3, loss='hinge'), param_grid=param_grid) assert_false(hasattr(clf, "predict_proba")) clf.fit(X, y) diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 5817c31f5f..3087c1f3bd 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -756,7 +756,8 @@ def test_learning_curve_batch_and_incremental_learning_are_equal(): n_redundant=0, n_classes=2, n_clusters_per_class=1, random_state=0) train_sizes = np.linspace(0.2, 1.0, 5) - estimator = PassiveAggressiveClassifier(n_iter=1, shuffle=False) + estimator = PassiveAggressiveClassifier(max_iter=1, tol=None, + shuffle=False) train_sizes_inc, train_scores_inc, test_scores_inc = \ learning_curve( @@ -827,7 +828,8 @@ def test_learning_curve_with_shuffle(): groups = np.array([1, 1, 1, 1, 1, 1, 3, 3, 3, 3, 3, 4, 4, 4, 4]) # Splits on these groups fail without shuffle as the first iteration # of the learning curve doesn't contain label 4 in the training set. 
- estimator = PassiveAggressiveClassifier(shuffle=False) + estimator = PassiveAggressiveClassifier(max_iter=5, tol=None, + shuffle=False) cv = GroupKFold(n_splits=2) train_sizes_batch, train_scores_batch, test_scores_batch = learning_curve( diff --git a/sklearn/tests/test_learning_curve.py b/sklearn/tests/test_learning_curve.py index 48cb8ce0ea..afaae84b92 100644 --- a/sklearn/tests/test_learning_curve.py +++ b/sklearn/tests/test_learning_curve.py @@ -221,7 +221,8 @@ def test_learning_curve_batch_and_incremental_learning_are_equal(): n_redundant=0, n_classes=2, n_clusters_per_class=1, random_state=0) train_sizes = np.linspace(0.2, 1.0, 5) - estimator = PassiveAggressiveClassifier(n_iter=1, shuffle=False) + estimator = PassiveAggressiveClassifier(max_iter=1, tol=None, + shuffle=False) train_sizes_inc, train_scores_inc, test_scores_inc = \ learning_curve( diff --git a/sklearn/tests/test_multiclass.py b/sklearn/tests/test_multiclass.py index 56ec67116a..7008fff41a 100644 --- a/sklearn/tests/test_multiclass.py +++ b/sklearn/tests/test_multiclass.py @@ -98,13 +98,13 @@ def test_ovr_partial_fit(): X = np.abs(np.random.randn(14, 2)) y = [1, 1, 1, 1, 2, 3, 3, 0, 0, 2, 3, 1, 2, 3] - ovr = OneVsRestClassifier(SGDClassifier(n_iter=1, shuffle=False, - random_state=0)) + ovr = OneVsRestClassifier(SGDClassifier(max_iter=1, tol=None, + shuffle=False, random_state=0)) ovr.partial_fit(X[:7], y[:7], np.unique(y)) ovr.partial_fit(X[7:], y[7:]) pred = ovr.predict(X) - ovr1 = OneVsRestClassifier(SGDClassifier(n_iter=1, shuffle=False, - random_state=0)) + ovr1 = OneVsRestClassifier(SGDClassifier(max_iter=1, tol=None, + shuffle=False, random_state=0)) pred1 = ovr1.fit(X, y).predict(X) assert_equal(np.mean(pred == y), np.mean(pred1 == y)) @@ -607,7 +607,8 @@ def test_ovo_ties(): # not defaulting to the smallest label X = np.array([[1, 2], [2, 1], [-2, 1], [-2, -1]]) y = np.array([2, 0, 1, 2]) - multi_clf = OneVsOneClassifier(Perceptron(shuffle=False)) + multi_clf = 
OneVsOneClassifier(Perceptron(shuffle=False, max_iter=4, + tol=None)) ovo_prediction = multi_clf.fit(X, y).predict(X) ovo_decision = multi_clf.decision_function(X) @@ -634,7 +635,8 @@ def test_ovo_ties2(): # cycle through labels so that each label wins once for i in range(3): y = (y_ref + i) % 3 - multi_clf = OneVsOneClassifier(Perceptron(shuffle=False)) + multi_clf = OneVsOneClassifier(Perceptron(shuffle=False, max_iter=4, + tol=None)) ovo_prediction = multi_clf.fit(X, y).predict(X) assert_equal(ovo_prediction[0], i % 3) diff --git a/sklearn/tests/test_multioutput.py b/sklearn/tests/test_multioutput.py index e48049360b..26647c3d19 100644 --- a/sklearn/tests/test_multioutput.py +++ b/sklearn/tests/test_multioutput.py @@ -50,12 +50,12 @@ def test_multi_target_regression_partial_fit(): references = np.zeros_like(y_test) half_index = 25 for n in range(3): - sgr = SGDRegressor(random_state=0) + sgr = SGDRegressor(random_state=0, max_iter=5) sgr.partial_fit(X_train[:half_index], y_train[:half_index, n]) sgr.partial_fit(X_train[half_index:], y_train[half_index:, n]) references[:, n] = sgr.predict(X_test) - sgr = MultiOutputRegressor(SGDRegressor(random_state=0)) + sgr = MultiOutputRegressor(SGDRegressor(random_state=0, max_iter=5)) sgr.partial_fit(X_train[:half_index], y_train[:half_index]) sgr.partial_fit(X_train[half_index:], y_train[half_index:]) @@ -108,12 +108,12 @@ def test_multi_target_sample_weight_partial_fit(): X = [[1, 2, 3], [4, 5, 6]] y = [[3.141, 2.718], [2.718, 3.141]] w = [2., 1.] - rgr_w = MultiOutputRegressor(SGDRegressor(random_state=0)) + rgr_w = MultiOutputRegressor(SGDRegressor(random_state=0, max_iter=5)) rgr_w.partial_fit(X, y, w) # weighted with different weights w = [2., 2.] 
- rgr = MultiOutputRegressor(SGDRegressor(random_state=0)) + rgr = MultiOutputRegressor(SGDRegressor(random_state=0, max_iter=5)) rgr.partial_fit(X, y, w) assert_not_equal(rgr.predict(X)[0][0], rgr_w.predict(X)[0][0]) @@ -152,7 +152,7 @@ classes = list(map(np.unique, (y1, y2, y3))) def test_multi_output_classification_partial_fit_parallelism(): - sgd_linear_clf = SGDClassifier(loss='log', random_state=1) + sgd_linear_clf = SGDClassifier(loss='log', random_state=1, max_iter=5) mor = MultiOutputClassifier(sgd_linear_clf, n_jobs=-1) mor.partial_fit(X, y, classes) est1 = mor.estimators_[0] @@ -166,7 +166,7 @@ def test_multi_output_classification_partial_fit(): # test if multi_target initializes correctly with base estimator and fit # assert predictions work as expected for predict - sgd_linear_clf = SGDClassifier(loss='log', random_state=1) + sgd_linear_clf = SGDClassifier(loss='log', random_state=1, max_iter=5) multi_target_linear = MultiOutputClassifier(sgd_linear_clf) # train the multi_target_linear and also get the predictions. 
@@ -193,8 +193,8 @@ def test_multi_output_classification_partial_fit(): assert_array_equal(sgd_linear_clf.predict(X), second_predictions[:, i]) -def test_multi_output_classifiation_partial_fit_no_first_classes_exception(): - sgd_linear_clf = SGDClassifier(loss='log', random_state=1) +def test_mutli_output_classifiation_partial_fit_no_first_classes_exception(): + sgd_linear_clf = SGDClassifier(loss='log', random_state=1, max_iter=5) multi_target_linear = MultiOutputClassifier(sgd_linear_clf) assert_raises_regex(ValueError, "classes must be passed on the first call " "to partial_fit.", @@ -311,14 +311,14 @@ def test_multi_output_classification_partial_fit_sample_weights(): Xw = [[1, 2, 3], [4, 5, 6], [1.5, 2.5, 3.5]] yw = [[3, 2], [2, 3], [3, 2]] w = np.asarray([2., 1., 1.]) - sgd_linear_clf = SGDClassifier(random_state=1) + sgd_linear_clf = SGDClassifier(random_state=1, max_iter=5) clf_w = MultiOutputClassifier(sgd_linear_clf) clf_w.fit(Xw, yw, w) # unweighted, but with repeated samples X = [[1, 2, 3], [1, 2, 3], [4, 5, 6], [1.5, 2.5, 3.5]] y = [[3, 2], [3, 2], [2, 3], [3, 2]] - sgd_linear_clf = SGDClassifier(random_state=1) + sgd_linear_clf = SGDClassifier(random_state=1, max_iter=5) clf = MultiOutputClassifier(sgd_linear_clf) clf.fit(X, y) X_test = [[1.5, 2.5, 3.5]] diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index a21f095941..4760253a5a 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -42,6 +42,7 @@ from sklearn.metrics import accuracy_score, adjusted_rand_score, f1_score from sklearn.random_projection import BaseRandomProjection from sklearn.feature_selection import SelectKBest from sklearn.svm.base import BaseLibSVM +from sklearn.linear_model.stochastic_gradient import BaseSGD from sklearn.pipeline import make_pipeline from sklearn.exceptions import ConvergenceWarning from sklearn.exceptions import DataConversionWarning @@ -132,7 +133,7 @@ def _yield_classifier_checks(name, 
classifier): yield check_decision_proba_consistency -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_supervised_y_no_nan(name, estimator_orig): # Checks that the Estimator targets are not NaN. estimator = clone(estimator_orig) @@ -284,7 +285,8 @@ def set_checking_parameters(estimator): # set parameters to speed up some estimators and # avoid deprecated behaviour params = estimator.get_params() - if ("n_iter" in params and estimator.__class__.__name__ != "TSNE"): + if ("n_iter" in params and estimator.__class__.__name__ != "TSNE" + and not isinstance(estimator, BaseSGD)): estimator.set_params(n_iter=5) if "max_iter" in params: warnings.simplefilter("ignore", ConvergenceWarning) @@ -363,14 +365,14 @@ def check_estimator_sparse_data(name, estimator_orig): for sparse_format in ['csr', 'csc', 'dok', 'lil', 'coo', 'dia', 'bsr']: X = X_csr.asformat(sparse_format) # catch deprecation warnings - with ignore_warnings(category=DeprecationWarning): + with ignore_warnings(category=(DeprecationWarning, FutureWarning)): if name in ['Scaler', 'StandardScaler']: estimator = clone(estimator).set_params(with_mean=False) else: estimator = clone(estimator) # fit and predict try: - with ignore_warnings(category=DeprecationWarning): + with ignore_warnings(category=(DeprecationWarning, FutureWarning)): estimator.fit(X, y) if hasattr(estimator, "predict"): pred = estimator.predict(X) @@ -392,7 +394,7 @@ def check_estimator_sparse_data(name, estimator_orig): raise -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_sample_weights_pandas_series(name, estimator_orig): # check that estimators will accept a 'sample_weight' parameter of # type pandas.Series in the 'fit' function. 
@@ -414,7 +416,7 @@ def check_sample_weights_pandas_series(name, estimator_orig): "input of type pandas.Series to class weight.") -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_sample_weights_list(name, estimator_orig): # check that estimators will accept a 'sample_weight' parameter of # type list in the 'fit' function. @@ -429,7 +431,7 @@ def check_sample_weights_list(name, estimator_orig): estimator.fit(X, y, sample_weight=sample_weight) -@ignore_warnings(category=(DeprecationWarning, UserWarning)) +@ignore_warnings(category=(DeprecationWarning, FutureWarning, UserWarning)) def check_dtype_object(name, estimator_orig): # check that estimators treat dtype object as numeric if possible rng = np.random.RandomState(0) @@ -498,7 +500,7 @@ def is_public_parameter(attr): return not (attr.startswith('_') or attr.endswith('_')) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_dont_overwrite_parameters(name, estimator_orig): # check that fit method only changes or sets private attributes if hasattr(estimator_orig.__init__, "deprecated_original"): @@ -548,7 +550,7 @@ def check_dont_overwrite_parameters(name, estimator_orig): ' %s changed' % ', '.join(attrs_changed_by_fit))) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_fit2d_predict1d(name, estimator_orig): # check by fitting a 2d array and predicting with a 1d array rnd = np.random.RandomState(0) @@ -658,7 +660,7 @@ def check_fit1d_1sample(name, estimator_orig): pass -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_transformer_general(name, transformer): X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]], random_state=0, n_features=2, cluster_std=0.1) @@ -668,7 +670,7 @@ def check_transformer_general(name, transformer): 
_check_transformer(name, transformer, X.tolist(), y.tolist()) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_transformer_data_not_an_array(name, transformer): X, y = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]], random_state=0, n_features=2, cluster_std=0.1) @@ -681,12 +683,11 @@ def check_transformer_data_not_an_array(name, transformer): _check_transformer(name, transformer, this_X, this_y) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_transformers_unfitted(name, transformer): X, y = _boston_subset() transformer = clone(transformer) - assert_raises((AttributeError, ValueError), transformer.transform, X) @@ -844,7 +845,7 @@ def check_estimators_dtypes(name, estimator_orig): getattr(estimator, method)(X_train) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_estimators_empty_data_messages(name, estimator_orig): e = clone(estimator_orig) set_random_state(e, 1) @@ -882,7 +883,7 @@ def check_estimators_nan_inf(name, estimator_orig): " transform.") for X_train in [X_train_nan, X_train_inf]: # catch deprecation warnings - with ignore_warnings(category=DeprecationWarning): + with ignore_warnings(category=(DeprecationWarning, FutureWarning)): estimator = clone(estimator_orig) set_random_state(estimator, 1) # try to fit @@ -969,7 +970,7 @@ def check_estimators_pickle(name, estimator_orig): assert_allclose_dense_sparse(result[method], unpickled_result) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_estimators_partial_fit_n_features(name, estimator_orig): # check if number of features changes between calls to partial_fit. 
if not hasattr(estimator_orig, 'partial_fit'): @@ -990,7 +991,7 @@ def check_estimators_partial_fit_n_features(name, estimator_orig): assert_raises(ValueError, estimator.partial_fit, X[:, :-1], y) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_clustering(name, clusterer_orig): clusterer = clone(clusterer_orig) X, y = make_blobs(n_samples=50, random_state=1) @@ -1050,7 +1051,7 @@ def check_classifiers_one_label(name, classifier_orig): X_test = rnd.uniform(size=(10, 3)) y = np.ones(10) # catch deprecation warnings - with ignore_warnings(category=DeprecationWarning): + with ignore_warnings(category=(DeprecationWarning, FutureWarning)): classifier = clone(classifier_orig) # try to fit try: @@ -1146,7 +1147,7 @@ def check_classifiers_train(name, classifier_orig): assert_array_equal(np.argsort(y_log_prob), np.argsort(y_prob)) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_estimators_fit_returns_self(name, estimator_orig): """Check if self is returned when calling fit""" X, y = make_blobs(random_state=0, n_samples=9, n_features=4) @@ -1193,7 +1194,7 @@ def check_estimators_unfitted(name, estimator_orig): est.predict_log_proba, X) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_supervised_y_2d(name, estimator_orig): if "MultiTask" in name: # These only work on 2d, so this test makes no sense @@ -1225,7 +1226,7 @@ def check_supervised_y_2d(name, estimator_orig): assert_allclose(y_pred.ravel(), y_pred_2d.ravel()) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_classifiers_classes(name, classifier_orig): X, y = make_blobs(n_samples=30, random_state=0, cluster_std=0.1) X, y = shuffle(X, y, random_state=7) @@ -1259,7 +1260,7 @@ def check_classifiers_classes(name, classifier_orig): 
(classifier, classes, classifier.classes_)) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_regressors_int(name, regressor_orig): X, _ = _boston_subset() X = X[:50] @@ -1287,7 +1288,7 @@ def check_regressors_int(name, regressor_orig): assert_allclose(pred1, pred2, atol=1e-2, err_msg=name) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_regressors_train(name, regressor_orig): X, y = _boston_subset() y = StandardScaler().fit_transform(y.reshape(-1, 1)) # X is already scaled @@ -1346,7 +1347,7 @@ def check_regressors_no_decision_function(name, regressor_orig): assert_warns_message(DeprecationWarning, msg, func, X) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_class_weight_classifiers(name, classifier_orig): if name == "NuSVC": # the sparse version has a parameter that doesn't do anything @@ -1372,6 +1373,8 @@ def check_class_weight_classifiers(name, classifier_orig): class_weight=class_weight) if hasattr(classifier, "n_iter"): classifier.set_params(n_iter=100) + if hasattr(classifier, "max_iter"): + classifier.set_params(max_iter=1000) if hasattr(classifier, "min_weight_fraction_leaf"): classifier.set_params(min_weight_fraction_leaf=0.01) @@ -1383,12 +1386,14 @@ def check_class_weight_classifiers(name, classifier_orig): assert_greater(np.mean(y_pred == 0), 0.87) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_class_weight_balanced_classifiers(name, classifier_orig, X_train, y_train, X_test, y_test, weights): classifier = clone(classifier_orig) if hasattr(classifier, "n_iter"): classifier.set_params(n_iter=100) + if hasattr(classifier, "max_iter"): + classifier.set_params(max_iter=1000) set_random_state(classifier) classifier.fit(X_train, y_train) @@ -1401,7 +1406,7 @@ def 
check_class_weight_balanced_classifiers(name, classifier_orig, X_train, f1_score(y_test, y_pred, average='weighted')) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_class_weight_balanced_linear_classifier(name, Classifier): """Test class weights with non-contiguous class labels.""" # this is run on classes, not instances, though this should be changed @@ -1410,10 +1415,13 @@ def check_class_weight_balanced_linear_classifier(name, Classifier): y = np.array([1, 1, 1, -1, -1]) classifier = Classifier() + if hasattr(classifier, "n_iter"): # This is a very small dataset, default n_iter are likely to prevent # convergence classifier.set_params(n_iter=1000) + if hasattr(classifier, "max_iter"): + classifier.set_params(max_iter=1000) set_random_state(classifier) # Let the model compute the class frequencies @@ -1432,7 +1440,7 @@ def check_class_weight_balanced_linear_classifier(name, Classifier): assert_allclose(coef_balanced, coef_manual) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_estimators_overwrite_params(name, estimator_orig): X, y = make_blobs(random_state=0, n_samples=9) # some want non-negative input @@ -1466,7 +1474,7 @@ def check_estimators_overwrite_params(name, estimator_orig): % (name, param_name, original_value, new_value)) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_no_fit_attributes_set_in_init(name, Estimator): """Check that Estimator.__init__ doesn't set trailing-_ attributes.""" # this check works on classes, not instances @@ -1485,7 +1493,7 @@ def check_no_fit_attributes_set_in_init(name, Estimator): 'was found in estimator {}'.format(attr, name)) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_sparsify_coefficients(name, estimator_orig): X = np.array([[-2, 
-1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-1, -2], [2, 2], [-2, -2]]) @@ -1523,9 +1531,8 @@ def check_regressor_data_not_an_array(name, estimator_orig): check_estimators_data_not_an_array(name, estimator_orig, X, y) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_estimators_data_not_an_array(name, estimator_orig, X, y): - if name in CROSS_DECOMPOSITION: raise SkipTest # separate estimators to control random seeds @@ -1550,7 +1557,7 @@ def check_parameters_default_constructible(name, Estimator): classifier = LinearDiscriminantAnalysis() # test default-constructibility # get rid of deprecation warnings - with ignore_warnings(category=DeprecationWarning): + with ignore_warnings(category=(DeprecationWarning, FutureWarning)): if name in META_ESTIMATORS: estimator = Estimator(classifier) else: @@ -1601,11 +1608,16 @@ def check_parameters_default_constructible(name, Estimator): assert_true(init_param.default is None) continue + if (issubclass(Estimator, BaseSGD) and + init_param.name in ['tol', 'max_iter']): + # To remove in 0.21, when they get their future default values + continue + param_value = params[init_param.name] if isinstance(param_value, np.ndarray): assert_array_equal(param_value, init_param.default) else: - assert_equal(param_value, init_param.default) + assert_equal(param_value, init_param.default, init_param.name) def multioutput_estimator_convert_y_2d(estimator, y): @@ -1616,7 +1628,7 @@ def multioutput_estimator_convert_y_2d(estimator, y): return y -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_non_transformer_estimators_n_iter(name, estimator_orig): # Test that estimators that are not transformers with a parameter # max_iter, return the attribute of n_iter_ at least 1. 
@@ -1655,7 +1667,7 @@ def check_non_transformer_estimators_n_iter(name, estimator_orig): assert_greater_equal(estimator.n_iter_, 1) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_transformer_n_iter(name, estimator_orig): # Test that transformers with a parameter max_iter, return the # attribute of n_iter_ at least 1. @@ -1681,7 +1693,7 @@ def check_transformer_n_iter(name, estimator_orig): assert_greater_equal(estimator.n_iter_, 1) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_get_params_invariance(name, estimator_orig): # Checks if get_params(deep=False) is a subset of get_params(deep=True) class T(BaseEstimator): @@ -1706,7 +1718,7 @@ def check_get_params_invariance(name, estimator_orig): shallow_params.items())) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_classifiers_regression_target(name, estimator_orig): # Check if classifier throws an exception when fed regression targets @@ -1717,7 +1729,7 @@ def check_classifiers_regression_target(name, estimator_orig): assert_raises_regex(ValueError, msg, e.fit, X, y) -@ignore_warnings(category=DeprecationWarning) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) def check_decision_proba_consistency(name, estimator_orig): # Check whether an estimator having both decision_function and # predict_proba methods has outputs with perfect rank correlation. diff --git a/sklearn/utils/weight_vector.pyx b/sklearn/utils/weight_vector.pyx index 8cc8d01357..bb4e852221 100644 --- a/sklearn/utils/weight_vector.pyx +++ b/sklearn/utils/weight_vector.pyx @@ -20,7 +20,6 @@ cdef extern from "cblas.h": void daxpy "cblas_daxpy" (int, double, const double*, int, double*, int) nogil - np.import_array() -- GitLab