Skip to content

Commit ac9bfdb

Browse files
committed
MAINT make export_graphviz more exception-safe
1 parent 494a91b commit ac9bfdb

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

sklearn/tree/export.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -115,20 +115,22 @@ def recurse(tree, node_id, criterion, parent=None, depth=0):
115115
out_file.write('%d -> %d ;\n' % (parent, node_id))
116116

117117
own_file = False
118-
if isinstance(out_file, six.string_types):
119-
if six.PY3:
120-
out_file = open(out_file, "w", encoding="utf-8")
121-
else:
122-
out_file = open(out_file, "wb")
123-
own_file = True
118+
try:
119+
if isinstance(out_file, six.string_types):
120+
if six.PY3:
121+
out_file = open(out_file, "w", encoding="utf-8")
122+
else:
123+
out_file = open(out_file, "wb")
124+
own_file = True
124125

125-
out_file.write("digraph Tree {\n")
126+
out_file.write("digraph Tree {\n")
126127

127-
if isinstance(decision_tree, _tree.Tree):
128-
recurse(decision_tree, 0, criterion="impurity")
129-
else:
130-
recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion)
131-
out_file.write("}")
128+
if isinstance(decision_tree, _tree.Tree):
129+
recurse(decision_tree, 0, criterion="impurity")
130+
else:
131+
recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion)
132+
out_file.write("}")
132133

133-
if own_file:
134-
out_file.close()
134+
finally:
135+
if own_file:
136+
out_file.close()

0 commit comments

Comments
 (0)