使用export graphviz创建决策树图后如何更改颜色

11

我正在使用scikit的回归树函数和graphviz生成一些决策树的好看、易于解释的可视化图形:

dot_data = tree.export_graphviz(Run.reg, out_file=None, 
                         feature_names=Xvar,  
                         filled=True, rounded=True,  
                         special_characters=True) 
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_png('CART.png')
graph.write_svg("CART.svg")

输入图像描述

这个代码运行得很好,但我想如果可能的话改变颜色方案?该图表示CO2通量,因此我想将负值设为绿色,正值设为棕色。我可以导出svg并手动修改一切,但是这样做时,文本与框不完全对齐,因此手动更改颜色并修复所有文本会为我的工作流程添加非常繁琐的步骤,我真的很想避免!输入图像描述

另外,我看到过一些树,其中连接节点的线条长度与分裂解释的%成比例。如果可能的话,我也很想做到这一点?

1个回答

19
  • 你可以通过graph.get_edge_list()获取所有边的列表
  • 每个源节点应该有两个目标节点,编号较小的为True,编号较大的为False
  • 可以通过set_fillcolor()来设置颜色

输入图像描述

import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree
import collections

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf,
                                feature_names=iris.feature_names,
                                out_file=None,
                                filled=True,
                                rounded=True)
graph = pydotplus.graph_from_dot_data(dot_data)

colors = ('brown', 'forestgreen')
edges = collections.defaultdict(list)

for edge in graph.get_edge_list():
    edges[edge.get_source()].append(int(edge.get_destination()))

for edge in edges:
    edges[edge].sort()    
    for i in range(2):
        dest = graph.get_node(str(edges[edge][i]))[0]
        dest.set_fillcolor(colors[i])

graph.write_png('tree.png')

此外,我也看到过有些树的节点连接线长度与分割所解释的变异程度成比例。如果可能的话,我也很想这样做!

你可以尝试使用set_weight()set_len()函数进行调整,但这需要一些技巧和调整才能正确实现。以下是一些代码供参考。

for edge in edges:
    edges[edge].sort()
    src = graph.get_node(edge)[0]
    total_weight = int(src.get_attributes()['label'].split('samples = ')[1].split('<br/>')[0])
    for i in range(2):
        dest = graph.get_node(str(edges[edge][i]))[0]
        weight = int(dest.get_attributes()['label'].split('samples = ')[1].split('<br/>')[0])
        graph.get_edge(edge, str(edges[edge][0]))[0].set_weight((1 - weight / total_weight) * 100)
        graph.get_edge(edge, str(edges[edge][0]))[0].set_len(weight / total_weight)
        graph.get_edge(edge, str(edges[edge][0]))[0].set_minlen(weight / total_weight)

1
你会如何根据鸢尾花的类别来进行颜色适配? - MyopicVisage

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接