使用ggparty绘制决策树时,边缘的小数位数是多少?

3
我想使用强大的ggparty包绘制决策树(由partykit包估计),一切都很好,除了数值分割变量的小数位数。我该如何在geom_edge_label()中格式化breaks_label,例如在下面的图中将">75.33333"更改为">75.3"?round()无效。我可以通过通用的options(digits = 3)来使用解决方法,但我想知道是否有更直接的方法。
library("ggparty") 
data("WeatherPlay", package = "partykit")

sp_o <- partysplit(1L, index = 1:3)
sp_h <- partysplit(3L, breaks = 75 + 1/3)
sp_w <- partysplit(4L, index = 1:2)
pn <- partynode(1L, split = sp_o, kids = list(
    partynode(2L, split = sp_h, kids = list(
        partynode(3L, info = "yes"),
        partynode(4L, info = "no"))),
    partynode(5L, info = "yes"),
    partynode(6L, split = sp_w, kids = list(
        partynode(7L, info = "yes"),
        partynode(8L, info = "no")))))
py <- party(pn, WeatherPlay)

ggparty(py) +
    geom_edge() +
    # geom_edge_label() +
    geom_edge_label(mapping = aes(label = paste(breaks_label))) +
    geom_node_splitvar() +
    geom_node_info()

这个示例是由reprex包(v0.3.0)在2020年03月05日创建的。


1
geom_edge_label(mapping = aes(label = paste(round(breaks_label,1))) 这个无法工作吗? - TTS
不行,它会出现“数学函数的非数值参数”的错误。问题似乎在于breaks_label是某种特殊对象,可以解析为能够添加(不)相等符号的对象。也请查看https://github.com/martin-borkovec/ggparty/wiki/5-geom_edge_label。 - hplieninger
这真的很棘手,你是正确的,它不起作用是因为pyplot$data$breaks_label,你得到了一个文本。我试图在party对象内部更改标签,但更改后无效。 - StupidWolf
1
最可能更普遍的解决方案是使用正则表达式并修改pyplot$data$breaks_label,然后重新绘制它..但我不知道这对您是否可行。 - StupidWolf
我明白了,修改pyplot$data$breaks_label中的字符串可以让我以一种非常间接的方式达到所需的解决方案。如果没有直接的解决方案出现,如果您愿意花时间并发布答案,我将接受您的答案。 - hplieninger
1个回答

3
感谢您使用ggparty!
所以我认为,对于当前版本来说,这是没有直接解决方案的事情。但我会确保在未来实现它!
通常,通过仅在节点的子集上使用geoms,可以解决很多问题。正如您已经注意到的那样,breaks_label并非以数字形式存储,而是作为字符存储,并带有一些可分析文本以表示不等号。因此,您需要使用类似substr()的函数。
ggparty(py) +
  geom_edge() +
  geom_edge_label(id = -c(3, 4)) +
    geom_edge_label(mapping = aes(label = paste(substr(breaks_label, start = 1, stop = 15))),
                    id = c(3, 4)) +
  geom_node_splitvar() +
  geom_node_info() 

我还修改了其中一个内部函数,加入舍入功能,您可以从github上获取并使用。但是我没有进行真正的测试,所以请自行承担风险 ;)

library(devtools)
source_url("https://raw.githubusercontent.com/martin-borkovec/ggparty/martin/R/add_splitvar_breaks_index_new.R")

rounded_labels <- add_splitvar_breaks_index_new(party_object = py,
                                                plot_data = ggparty:::get_plot_data(py), 
                                                round_digits = 2)

ggparty(py) +
  geom_edge() +
  geom_edge_label(mapping = aes(label = unlist(rounded_labels)),
                  data = rounded_labels) +
  geom_node_splitvar() +
  geom_node_info()

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接