Skip to content
Snippets Groups Projects
Commit ce861b41 authored by trevorstephens's avatar trevorstephens Committed by Andreas Mueller
Browse files

fix rounding, adjust tests for 32 bit export_graphviz

parent 6f69e83d
No related branches found
No related tags found
No related merge requests found
......@@ -18,7 +18,6 @@ from . import _criterion
from . import _tree
def _color_brew(n):
"""Generate n colors with equally spaced hues.
......@@ -153,13 +152,14 @@ def export_graphviz(decision_tree, out_file="tree.dot", max_depth=None,
# Classification tree
color = list(colors['rgb'][np.argmax(value)])
sorted_values = sorted(value, reverse=True)
alpha = int(255 * (sorted_values[0] - sorted_values[1]) /
(1 - sorted_values[1]))
alpha = int(np.round(255 * (sorted_values[0] - sorted_values[1]) /
(1 - sorted_values[1]), 0))
else:
# Regression tree or multi-output
color = list(colors['rgb'][0])
alpha = int(255 * ((value - colors['bounds'][0]) /
(colors['bounds'][1] - colors['bounds'][0])))
alpha = int(np.round(255 * ((value - colors['bounds'][0]) /
(colors['bounds'][1] -
colors['bounds'][0])), 0))
# Return html color code in #RRGGBBAA format
color.append(alpha)
......
......@@ -16,7 +16,7 @@ from sklearn.utils.testing import assert_in
# toy sample
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
y = [-1, -1, -1, 1, 1, 1]
y2 = [[-1, 1], [-1, 2], [-1, 3], [1, 1], [1, 2], [1, 3]]
y2 = [[-1, 1], [-1, 1], [-1, 1], [1, 2], [1, 2], [1, 3]]
w = [1, 1, 1, .5, .5, .5]
......@@ -154,29 +154,22 @@ def test_graphviz_toy():
'node [shape=box, style="filled", color="black"] ;\n' \
'0 [label="X[0] <= 0.0\\nsamples = 6\\n' \
'value = [[3.0, 1.5, 0.0]\\n' \
'[1.5, 1.5, 1.5]]", fillcolor="#e5813900"] ;\n' \
'1 [label="X[1] <= -1.5\\nsamples = 3\\n' \
'value = [[3, 0, 0]\\n[1, 1, 1]]", ' \
'fillcolor="#e5813965"] ;\n' \
'[3.0, 1.0, 0.5]]", fillcolor="#e5813900"] ;\n' \
'1 [label="samples = 3\\nvalue = [[3, 0, 0]\\n' \
'[3, 0, 0]]", fillcolor="#e58139ff"] ;\n' \
'0 -> 1 [labeldistance=2.5, labelangle=45, ' \
'headlabel="True"] ;\n' \
'2 [label="samples = 1\\nvalue = [[1, 0, 0]\\n' \
'[0, 0, 1]]", fillcolor="#e58139ff"] ;\n' \
'1 -> 2 ;\n' \
'3 [label="samples = 2\\nvalue = [[2, 0, 0]\\n' \
'[1, 1, 0]]", fillcolor="#e581398c"] ;\n' \
'1 -> 3 ;\n' \
'4 [label="X[0] <= 1.5\\nsamples = 3\\n' \
'value = [[0.0, 1.5, 0.0]\\n[0.5, 0.5, 0.5]]", ' \
'fillcolor="#e5813965"] ;\n' \
'0 -> 4 [labeldistance=2.5, labelangle=-45, ' \
'2 [label="X[0] <= 1.5\\nsamples = 3\\n' \
'value = [[0.0, 1.5, 0.0]\\n' \
'[0.0, 1.0, 0.5]]", fillcolor="#e5813986"] ;\n' \
'0 -> 2 [labeldistance=2.5, labelangle=-45, ' \
'headlabel="False"] ;\n' \
'5 [label="samples = 2\\nvalue = [[0.0, 1.0, 0.0]\\n' \
'[0.5, 0.5, 0.0]]", fillcolor="#e581398c"] ;\n' \
'4 -> 5 ;\n' \
'6 [label="samples = 1\\nvalue = [[0.0, 0.5, 0.0]\\n' \
'3 [label="samples = 2\\nvalue = [[0, 1, 0]\\n' \
'[0, 1, 0]]", fillcolor="#e58139ff"] ;\n' \
'2 -> 3 ;\n' \
'4 [label="samples = 1\\nvalue = [[0.0, 0.5, 0.0]\\n' \
'[0.0, 0.0, 0.5]]", fillcolor="#e58139ff"] ;\n' \
'4 -> 6 ;\n' \
'2 -> 4 ;\n' \
'}'
assert_equal(contents1, contents2)
......@@ -199,7 +192,7 @@ def test_graphviz_toy():
'edge [fontname=helvetica] ;\n' \
'rankdir=LR ;\n' \
'0 [label="X[0] <= 0.0\\nmse = 1.0\\nsamples = 6\\n' \
'value = 0.0", fillcolor="#e581397f"] ;\n' \
'value = 0.0", fillcolor="#e5813980"] ;\n' \
'1 [label="mse = 0.0\\nsamples = 3\\nvalue = -1.0", ' \
'fillcolor="#e5813900"] ;\n' \
'0 -> 1 [labeldistance=2.5, labelangle=-45, ' \
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment