From 2beefbc4b331149a8bf0dba5726e017651b845dc Mon Sep 17 00:00:00 2001 From: Hanmin Qin <qinhanmin2005@sina.com> Date: Sun, 23 Apr 2017 09:34:08 +0800 Subject: [PATCH] [MRG] Improve the error message of export_graphviz if a not-fitted decision tree is provided (#8776) --- sklearn/tree/export.py | 2 ++ sklearn/tree/tests/test_export.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 43e8aa11b9..db89ae25d9 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -14,6 +14,7 @@ import numpy as np import warnings from ..externals import six +from ..utils.validation import check_is_fitted from . import _criterion from . import _tree @@ -377,6 +378,7 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None, # Add edge to parent out_file.write('%d -> %d ;\n' % (parent, node_id)) + check_is_fitted(decision_tree, 'tree_') own_file = False return_string = False try: diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py index 1379a7703f..89d9cd7370 100644 --- a/sklearn/tree/tests/test_export.py +++ b/sklearn/tree/tests/test_export.py @@ -9,6 +9,7 @@ from sklearn.ensemble import GradientBoostingClassifier from sklearn.tree import export_graphviz from sklearn.externals.six import StringIO from sklearn.utils.testing import assert_in, assert_equal, assert_raises +from sklearn.exceptions import NotFittedError # toy sample X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]] @@ -210,6 +211,11 @@ def test_graphviz_toy(): def test_graphviz_errors(): # Check for errors of export_graphviz clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2) + + # Check not-fitted decision tree error + out = StringIO() + assert_raises(NotFittedError, export_graphviz, clf, out) + clf.fit(X, y) # Check feature_names error -- GitLab