From de29f3f22db6e017aef9dc77935d8ef43d2d7b44 Mon Sep 17 00:00:00 2001
From: "Nicholas Nadeau, P.Eng., AVS" <nnadeau@users.noreply.github.com>
Date: Sun, 29 Oct 2017 12:16:26 -0400
Subject: [PATCH] [MRG+1] `MLPRegressor` quits fitting too soon due to
 `self._no_improvement_count` (#9457)

---
 doc/modules/neural_networks_supervised.rst  | 26 ++++----
 doc/whats_new/v0.20.rst                     | 20 +++++++
 .../neural_network/multilayer_perceptron.py | 59 +++++++++++++------
 sklearn/neural_network/tests/test_mlp.py    | 45 ++++++++++++++
 4 files changed, 119 insertions(+), 31 deletions(-)

diff --git a/doc/modules/neural_networks_supervised.rst b/doc/modules/neural_networks_supervised.rst
index 292ed903ee..9e5927349b 100644
--- a/doc/modules/neural_networks_supervised.rst
+++ b/doc/modules/neural_networks_supervised.rst
@@ -91,12 +91,13 @@ training samples::
     ...
     >>> clf.fit(X, y)                    # doctest: +NORMALIZE_WHITESPACE
     MLPClassifier(activation='relu', alpha=1e-05, batch_size='auto',
-           beta_1=0.9, beta_2=0.999, early_stopping=False,
-           epsilon=1e-08, hidden_layer_sizes=(5, 2), learning_rate='constant',
-           learning_rate_init=0.001, max_iter=200, momentum=0.9,
-           nesterovs_momentum=True, power_t=0.5, random_state=1, shuffle=True,
-           solver='lbfgs', tol=0.0001, validation_fraction=0.1, verbose=False,
-           warm_start=False)
+                  beta_1=0.9, beta_2=0.999, early_stopping=False,
+                  epsilon=1e-08, hidden_layer_sizes=(5, 2),
+                  learning_rate='constant', learning_rate_init=0.001,
+                  max_iter=200, momentum=0.9, n_iter_no_change=10,
+                  nesterovs_momentum=True, power_t=0.5, random_state=1,
+                  shuffle=True, solver='lbfgs', tol=0.0001,
+                  validation_fraction=0.1, verbose=False, warm_start=False)
 
 After fitting (training), the model can predict labels for new samples::
 
@@ -139,12 +140,13 @@ indices where the value is `1` represents the assigned classes of that sample::
     ...
     >>> clf.fit(X, y)                    # doctest: +NORMALIZE_WHITESPACE
     MLPClassifier(activation='relu', alpha=1e-05, batch_size='auto',
-           beta_1=0.9, beta_2=0.999, early_stopping=False,
-           epsilon=1e-08, hidden_layer_sizes=(15,), learning_rate='constant',
-           learning_rate_init=0.001, max_iter=200, momentum=0.9,
-           nesterovs_momentum=True, power_t=0.5, random_state=1, shuffle=True,
-           solver='lbfgs', tol=0.0001, validation_fraction=0.1, verbose=False,
-           warm_start=False)
+                  beta_1=0.9, beta_2=0.999, early_stopping=False,
+                  epsilon=1e-08, hidden_layer_sizes=(15,),
+                  learning_rate='constant', learning_rate_init=0.001,
+                  max_iter=200, momentum=0.9, n_iter_no_change=10,
+                  nesterovs_momentum=True, power_t=0.5, random_state=1,
+                  shuffle=True, solver='lbfgs', tol=0.0001,
+                  validation_fraction=0.1, verbose=False, warm_start=False)
     >>> clf.predict([[1., 2.]])
     array([[1, 1]])
     >>> clf.predict([[0., 0.]])
diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst
index 5af76499bc..0897f331eb 100644
--- a/doc/whats_new/v0.20.rst
+++ b/doc/whats_new/v0.20.rst
@@ -18,6 +18,9 @@ random sampling procedures.
 - :class:`decomposition.IncrementalPCA` in Python 2 (bug fix)
 - :class:`isotonic.IsotonicRegression` (bug fix)
 - :class:`metrics.roc_auc_score` (bug fix)
+- :class:`neural_network.BaseMultilayerPerceptron` (bug fix)
+- :class:`neural_network.MLPRegressor` (bug fix)
+- :class:`neural_network.MLPClassifier` (bug fix)
 
 Details are listed in the changelog below.
 
@@ -65,6 +68,13 @@ Classifiers and regressors
   :class:`sklearn.naive_bayes.GaussianNB` to give a precise control over
   variances calculation. :issue:`9681` by :user:`Dmitry Mottl <Mottl>`.
 
+- Add ``n_iter_no_change`` parameter in
+  :class:`neural_network.BaseMultilayerPerceptron`,
+  :class:`neural_network.MLPRegressor`, and
+  :class:`neural_network.MLPClassifier` to control the maximum number
+  of epochs allowed without a ``tol`` improvement before training stops.
+  :issue:`9456` by :user:`Nicholas Nadeau <nnadeau>`.
+
 - A parameter ``check_inverse`` was added to :class:`FunctionTransformer`
   to ensure that ``func`` and ``inverse_func`` are the inverse of each
   other.
@@ -96,6 +106,16 @@ Classifiers and regressors
   identical X values.
   :issue:`9432` by :user:`Dallas Card <dallascard>`
 
+- Fixed a bug in :class:`neural_network.BaseMultilayerPerceptron`,
+  :class:`neural_network.MLPRegressor`, and
+  :class:`neural_network.MLPClassifier`: the no-improvement window,
+  previously hardcoded to 2 epochs, is now ``n_iter_no_change`` (default
+  10). :issue:`9456` by :user:`Nicholas Nadeau <nnadeau>`.
+
+- Fixed a bug in :class:`neural_network.MLPRegressor` where fitting
+  quit unexpectedly early due to local minima or fluctuations.
+  :issue:`9456` by :user:`Nicholas Nadeau <nnadeau>`.
+
 - Fixed a bug in :class:`naive_bayes.GaussianNB` which incorrectly
   raised error for prior list which summed to 1.
   :issue:`10005` by :user:`Gaurav Dhingra <gxyd>`.
diff --git a/sklearn/neural_network/multilayer_perceptron.py b/sklearn/neural_network/multilayer_perceptron.py
index ae6df22c2f..c693c11614 100644
--- a/sklearn/neural_network/multilayer_perceptron.py
+++ b/sklearn/neural_network/multilayer_perceptron.py
@@ -51,7 +51,8 @@ class BaseMultilayerPerceptron(six.with_metaclass(ABCMeta, BaseEstimator)):
                  alpha, batch_size, learning_rate, learning_rate_init, power_t,
                  max_iter, loss, shuffle, random_state, tol, verbose,
                  warm_start, momentum, nesterovs_momentum, early_stopping,
-                 validation_fraction, beta_1, beta_2, epsilon):
+                 validation_fraction, beta_1, beta_2, epsilon,
+                 n_iter_no_change):
         self.activation = activation
         self.solver = solver
         self.alpha = alpha
@@ -74,6 +75,7 @@ class BaseMultilayerPerceptron(six.with_metaclass(ABCMeta, BaseEstimator)):
         self.beta_1 = beta_1
         self.beta_2 = beta_2
         self.epsilon = epsilon
+        self.n_iter_no_change = n_iter_no_change
 
     def _unpack(self, packed_parameters):
         """Extract the coefficients and intercepts from packed_parameters."""
@@ -415,6 +417,9 @@ class BaseMultilayerPerceptron(six.with_metaclass(ABCMeta, BaseEstimator)):
                              self.beta_2)
         if self.epsilon <= 0.0:
             raise ValueError("epsilon must be > 0, got %s." % self.epsilon)
+        if self.n_iter_no_change <= 0:
+            raise ValueError("n_iter_no_change must be > 0, got %s."
+                             % self.n_iter_no_change)
 
         # raise ValueError if not registered
         supported_activations = ('identity', 'logistic', 'tanh', 'relu')
@@ -537,15 +542,17 @@ class BaseMultilayerPerceptron(six.with_metaclass(ABCMeta, BaseEstimator)):
                 # for learning rate that needs to be updated at iteration end
                 self._optimizer.iteration_ends(self.t_)
 
-                if self._no_improvement_count > 2:
-                    # not better than last two iterations by tol.
+                if self._no_improvement_count > self.n_iter_no_change:
+                    # not better than last `n_iter_no_change` iterations by tol
                     # stop or decrease learning rate
                     if early_stopping:
                         msg = ("Validation score did not improve more than "
-                               "tol=%f for two consecutive epochs." % self.tol)
+                               "tol=%f for %d consecutive epochs." % (
+                                   self.tol, self.n_iter_no_change))
                     else:
                         msg = ("Training loss did not improve more than tol=%f"
-                               " for two consecutive epochs." % self.tol)
+                               " for %d consecutive epochs." % (
+                                   self.tol, self.n_iter_no_change))
 
                     is_stopping = self._optimizer.trigger_stopping(
                         msg, self.verbose)
@@ -780,9 +787,9 @@ class MLPClassifier(BaseMultilayerPerceptron, ClassifierMixin):
 
     tol : float, optional, default 1e-4
         Tolerance for the optimization. When the loss or score is not improving
-        by at least tol for two consecutive iterations, unless `learning_rate`
-        is set to 'adaptive', convergence is considered to be reached and
-        training stops.
+        by at least ``tol`` for ``n_iter_no_change`` consecutive iterations,
+        unless ``learning_rate`` is set to 'adaptive', convergence is
+        considered to be reached and training stops.
 
     verbose : bool, optional, default False
         Whether to print progress messages to stdout.
@@ -804,8 +811,8 @@ class MLPClassifier(BaseMultilayerPerceptron, ClassifierMixin):
         Whether to use early stopping to terminate training when validation
         score is not improving. If set to true, it will automatically set
         aside 10% of training data as validation and terminate training when
-        validation score is not improving by at least tol for two consecutive
-        epochs.
+        validation score is not improving by at least ``tol`` for
+        ``n_iter_no_change`` consecutive epochs.
         Only effective when solver='sgd' or 'adam'
 
     validation_fraction : float, optional, default 0.1
@@ -824,6 +831,12 @@ class MLPClassifier(BaseMultilayerPerceptron, ClassifierMixin):
     epsilon : float, optional, default 1e-8
         Value for numerical stability in adam. Only used when solver='adam'
 
+    n_iter_no_change : int, optional, default 10
+        Maximum number of epochs to not meet ``tol`` improvement.
+        Only effective when solver='sgd' or 'adam'
+
+        .. versionadded:: 0.20
+
     Attributes
     ----------
     classes_ : array or list of array of shape (n_classes,)
@@ -890,7 +903,7 @@ class MLPClassifier(BaseMultilayerPerceptron, ClassifierMixin):
                  verbose=False, warm_start=False, momentum=0.9,
                  nesterovs_momentum=True, early_stopping=False,
                  validation_fraction=0.1, beta_1=0.9, beta_2=0.999,
-                 epsilon=1e-8):
+                 epsilon=1e-8, n_iter_no_change=10):
 
         sup = super(MLPClassifier, self)
         sup.__init__(hidden_layer_sizes=hidden_layer_sizes,
@@ -903,7 +916,8 @@ class MLPClassifier(BaseMultilayerPerceptron, ClassifierMixin):
                      nesterovs_momentum=nesterovs_momentum,
                      early_stopping=early_stopping,
                      validation_fraction=validation_fraction,
-                     beta_1=beta_1, beta_2=beta_2, epsilon=epsilon)
+                     beta_1=beta_1, beta_2=beta_2, epsilon=epsilon,
+                     n_iter_no_change=n_iter_no_change)
 
     def _validate_input(self, X, y, incremental):
         X, y = check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'],
@@ -1157,9 +1171,9 @@ class MLPRegressor(BaseMultilayerPerceptron, RegressorMixin):
 
     tol : float, optional, default 1e-4
         Tolerance for the optimization. When the loss or score is not improving
-        by at least tol for two consecutive iterations, unless `learning_rate`
-        is set to 'adaptive', convergence is considered to be reached and
-        training stops.
+        by at least ``tol`` for ``n_iter_no_change`` consecutive iterations,
+        unless ``learning_rate`` is set to 'adaptive', convergence is
+        considered to be reached and training stops.
 
     verbose : bool, optional, default False
         Whether to print progress messages to stdout.
@@ -1181,8 +1195,8 @@ class MLPRegressor(BaseMultilayerPerceptron, RegressorMixin):
         Whether to use early stopping to terminate training when validation
         score is not improving. If set to true, it will automatically set
         aside 10% of training data as validation and terminate training when
-        validation score is not improving by at least tol for two consecutive
-        epochs.
+        validation score is not improving by at least ``tol`` for
+        ``n_iter_no_change`` consecutive epochs.
         Only effective when solver='sgd' or 'adam'
 
     validation_fraction : float, optional, default 0.1
@@ -1201,6 +1215,12 @@ class MLPRegressor(BaseMultilayerPerceptron, RegressorMixin):
     epsilon : float, optional, default 1e-8
         Value for numerical stability in adam. Only used when solver='adam'
 
+    n_iter_no_change : int, optional, default 10
+        Maximum number of epochs to not meet ``tol`` improvement.
+        Only effective when solver='sgd' or 'adam'
+
+        .. versionadded:: 0.20
+
     Attributes
     ----------
     loss_ : float
@@ -1265,7 +1285,7 @@ class MLPRegressor(BaseMultilayerPerceptron, RegressorMixin):
                  verbose=False, warm_start=False, momentum=0.9,
                  nesterovs_momentum=True, early_stopping=False,
                  validation_fraction=0.1, beta_1=0.9, beta_2=0.999,
-                 epsilon=1e-8):
+                 epsilon=1e-8, n_iter_no_change=10):
 
         sup = super(MLPRegressor, self)
         sup.__init__(hidden_layer_sizes=hidden_layer_sizes,
@@ -1278,7 +1298,8 @@ class MLPRegressor(BaseMultilayerPerceptron, RegressorMixin):
                      nesterovs_momentum=nesterovs_momentum,
                      early_stopping=early_stopping,
                      validation_fraction=validation_fraction,
-                     beta_1=beta_1, beta_2=beta_2, epsilon=epsilon)
+                     beta_1=beta_1, beta_2=beta_2, epsilon=epsilon,
+                     n_iter_no_change=n_iter_no_change)
 
     def predict(self, X):
         """Predict using the multi-layer perceptron model.
diff --git a/sklearn/neural_network/tests/test_mlp.py b/sklearn/neural_network/tests/test_mlp.py
index 9c42b7c930..b0d5ab587a 100644
--- a/sklearn/neural_network/tests/test_mlp.py
+++ b/sklearn/neural_network/tests/test_mlp.py
@@ -420,6 +420,7 @@ def test_params_errors():
     assert_raises(ValueError, clf(beta_2=1).fit, X, y)
     assert_raises(ValueError, clf(beta_2=-0.5).fit, X, y)
     assert_raises(ValueError, clf(epsilon=-0.5).fit, X, y)
+    assert_raises(ValueError, clf(n_iter_no_change=-1).fit, X, y)
 
     assert_raises(ValueError, clf(solver='hadoken').fit, X, y)
     assert_raises(ValueError, clf(learning_rate='converge').fit, X, y)
@@ -588,3 +589,47 @@ def test_warm_start():
                        'classes as in the previous call to fit.'
                        ' Previously got [0 1 2], `y` has %s' % np.unique(y_i))
             assert_raise_message(ValueError, message, clf.fit, X, y_i)
+
+
+def test_n_iter_no_change():
+    # test n_iter_no_change using binary data set
+    # the classification fitting process is not prone to loss curve fluctuations
+    X = X_digits_binary[:100]
+    y = y_digits_binary[:100]
+    tol = 0.01
+    max_iter = 3000
+
+    # test multiple n_iter_no_change
+    for n_iter_no_change in [2, 5, 10, 50, 100]:
+        clf = MLPClassifier(tol=tol, max_iter=max_iter, solver='sgd',
+                            n_iter_no_change=n_iter_no_change)
+        clf.fit(X, y)
+
+        # validate n_iter_no_change
+        assert_equal(clf._no_improvement_count, n_iter_no_change + 1)
+        assert_greater(max_iter, clf.n_iter_)
+
+
+@ignore_warnings(category=ConvergenceWarning)
+def test_n_iter_no_change_inf():
+    # test n_iter_no_change using binary data set
+    # the fitting process should go to max_iter iterations
+    X = X_digits_binary[:100]
+    y = y_digits_binary[:100]
+
+    # set a ridiculous tolerance
+    # this should always trigger _update_no_improvement_count()
+    tol = 1e9
+
+    # fit
+    n_iter_no_change = np.inf
+    max_iter = 3000
+    clf = MLPClassifier(tol=tol, max_iter=max_iter, solver='sgd',
+                        n_iter_no_change=n_iter_no_change)
+    clf.fit(X, y)
+
+    # validate n_iter_no_change doesn't cause early stopping
+    assert_equal(clf.n_iter_, max_iter)
+
+    # validate _update_no_improvement_count() was always triggered
+    assert_equal(clf._no_improvement_count, clf.n_iter_ - 1)
-- 
GitLab
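
The behavioural change is easiest to see with the stopping rule in
isolation. The self-contained Python sketch below uses a hypothetical
helper name, `stopping_epoch`, and is simplified from the real logic in
`BaseMultilayerPerceptron._update_no_improvement_count` and the solver
loop shown in the diff above. It reproduces the counting semantics,
including why `test_n_iter_no_change` asserts that
`_no_improvement_count` equals `n_iter_no_change + 1` at the stopping
point (training stops once the counter strictly exceeds the budget):

    # Simplified, self-contained model of the patched stopping rule.
    # `stopping_epoch` is a hypothetical helper, not scikit-learn API.
    def stopping_epoch(loss_curve, tol, n_iter_no_change):
        """Return the 1-based epoch at which fitting stops, else None."""
        best_loss = float('inf')
        no_improvement_count = 0
        for epoch, loss in enumerate(loss_curve, start=1):
            if loss > best_loss - tol:
                no_improvement_count += 1  # epoch failed to improve by tol
            else:
                no_improvement_count = 0   # sufficient improvement: reset
            if loss < best_loss:
                best_loss = loss
            if no_improvement_count > n_iter_no_change:
                return epoch               # mirrors `> self.n_iter_no_change`
        return None

    losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.5, 0.1]
    # The old hardcoded budget of 2: a short plateau ends the fit early.
    assert stopping_epoch(losses, tol=1e-4, n_iter_no_change=2) == 5
    # The new default of 10 rides out the plateau and keeps training.
    assert stopping_epoch(losses, tol=1e-4, n_iter_no_change=10) is None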
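
For completeness, a minimal usage sketch of the new parameter. This is
illustrative rather than part of the patch: it assumes scikit-learn >=
0.20 (where this change landed) and a synthetic dataset; everything else
is standard scikit-learn API:

    # Exercising the new `n_iter_no_change` parameter on synthetic data.
    from sklearn.datasets import make_regression
    from sklearn.neural_network import MLPRegressor

    X, y = make_regression(n_samples=200, n_features=10, random_state=1)

    # Before this patch, SGD/Adam fitting gave up after only 2 epochs
    # without a `tol` improvement; a larger budget survives plateaus.
    reg = MLPRegressor(solver='sgd', tol=1e-4, max_iter=3000,
                       n_iter_no_change=50, random_state=1)
    reg.fit(X, y)
    print("stopped after %d iterations" % reg.n_iter_)

Raising `n_iter_no_change` (or passing `np.inf`, as in
`test_n_iter_no_change_inf` above) trades extra epochs for robustness
against the premature convergence reported in #9456.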