From 2beefbc4b331149a8bf0dba5726e017651b845dc Mon Sep 17 00:00:00 2001
From: Hanmin Qin <qinhanmin2005@sina.com>
Date: Sun, 23 Apr 2017 09:34:08 +0800
Subject: [PATCH] [MRG] Improve the error message of export_graphviz if a
 not-fitted decision tree is provided (#8776)

---
 sklearn/tree/export.py            | 2 ++
 sklearn/tree/tests/test_export.py | 6 ++++++
 2 files changed, 8 insertions(+)

diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py
index 43e8aa11b9..db89ae25d9 100644
--- a/sklearn/tree/export.py
+++ b/sklearn/tree/export.py
@@ -14,6 +14,7 @@ import numpy as np
 import warnings
 
 from ..externals import six
+from ..utils.validation import check_is_fitted
 
 from . import _criterion
 from . import _tree
@@ -377,6 +378,7 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
                 # Add edge to parent
                 out_file.write('%d -> %d ;\n' % (parent, node_id))
 
+    check_is_fitted(decision_tree, 'tree_')
     own_file = False
     return_string = False
     try:
diff --git a/sklearn/tree/tests/test_export.py b/sklearn/tree/tests/test_export.py
index 1379a7703f..89d9cd7370 100644
--- a/sklearn/tree/tests/test_export.py
+++ b/sklearn/tree/tests/test_export.py
@@ -9,6 +9,7 @@ from sklearn.ensemble import GradientBoostingClassifier
 from sklearn.tree import export_graphviz
 from sklearn.externals.six import StringIO
 from sklearn.utils.testing import assert_in, assert_equal, assert_raises
+from sklearn.exceptions import NotFittedError
 
 # toy sample
 X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
@@ -210,6 +211,11 @@ def test_graphviz_toy():
 def test_graphviz_errors():
     # Check for errors of export_graphviz
     clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)
+
+    # Check not-fitted decision tree error
+    out = StringIO()
+    assert_raises(NotFittedError, export_graphviz, clf, out)
+
     clf.fit(X, y)
 
     # Check feature_names error
-- 
GitLab