使用Rpart软件生成的测试规则

7
我希望以编程的方式测试从树中生成的一条规则。在树中,从根到叶子节点(终端节点)的路径可以被解释为一条规则。
在R中,我们可以使用“rpart”包并执行以下操作: (在本文中,我将仅使用“鸢尾花”数据集作为示例目的)
library(rpart)
model <- rpart(Species ~ ., data=iris)

用这两行代码,我得到了一个名为model的树形结构,它的类是rpart.objectrpart文档第21页)。这个对象有很多信息,并支持各种方法。特别是,这个对象有一个frame变量(可以通过标准方式访问:model$frame)(同上),以及方法path.rpathrpart文档第7页),它可以给出从根节点到感兴趣的节点(函数中的node参数)的路径。 frame变量的row.names包含树的节点编号。var列给出了节点中的分割变量,yval是拟合值,yval2是类概率和其他信息。
> model$frame
           var   n  wt dev yval complexity ncompete nsurrogate     yval2.1     yval2.2     yval2.3     yval2.4     yval2.5     yval2.6     yval2.7
1 Petal.Length 150 150 100    1       0.50        3          3  1.00000000 50.00000000 50.00000000 50.00000000  0.33333333  0.33333333  0.33333333
2       <leaf>  50  50   0    1       0.01        0          0  1.00000000 50.00000000  0.00000000  0.00000000  1.00000000  0.00000000  0.00000000
3  Petal.Width 100 100  50    2       0.44        3          3  2.00000000  0.00000000 50.00000000 50.00000000  0.00000000  0.50000000  0.50000000
6       <leaf>  54  54   5    2       0.00        0          0  2.00000000  0.00000000 49.00000000  5.00000000  0.00000000  0.90740741  0.09259259
7       <leaf>  46  46   1    3       0.01        0          0  3.00000000  0.00000000  1.00000000 45.00000000  0.00000000  0.02173913  0.97826087

但是,仅在“var”列中标记为<leaf>的节点是终端节点(叶子)。在这种情况下,节点是2、6和7。
如上所述,您可以使用path.rpart方法提取规则(此技术用于rattle包和文章Sharma Credit Score中),如下所示:
此外,该模型将预测值的值保存在中。
predicted.levels <- attr(model, "ylevels")

这个值对应于 model$frame 数据集中的列 yval

对于节点编号为7的叶子(行号为5),预测值为

> ylevels[model$frame[5, ]$yval]
[1] "virginica"

规则是:

> rule <- path.rpart(model, nodes = 7)

 node number: 7 
   root
   Petal.Length>=2.45
   Petal.Width>=1.75

因此,规则可被解读为:
If Petal.Length >= 2.45 AND Petal.Width >= 1.75 THEN Species = Virginica

我知道我可以在测试数据集中(我将再次使用鸢尾花数据集)测试此规则的真正阳性数量,将新数据集子集化如下:

> hits <- subset(iris, Petal.Length >= 2.45 & Petal.Width >= 1.75)

然后计算混淆矩阵

> table(hits$Species, hits$Species == "virginica")

             FALSE TRUE
  setosa         0    0
  versicolor     1    0
  virginica      0   45

(注:我使用相同的鸢尾花数据集进行测试)
如何以编程方式评估规则?可以从规则中提取条件,如下所示。
> unlist(rule, use.names = FALSE)[-1]
[1] "Petal.Length>=2.45" "Petal.Width>=1.75" 

但是,我该如何从这里继续呢?我不能使用subset函数。

提前感谢。

注意:为了更好的清晰度,此问题已进行了大量编辑。


这个问题很快就会被关闭,因为你并没有构建一个符合这些指南的问题,或者至少没有构建一个清晰的问题。现在进行一些快速编辑来整理它还不算太晚。 - Tyler Rinker
谢谢您的评论,我编辑了问题,现在可能更清晰了吗? - nanounanue
上面代码中的关键部分是 rule <- path.rpart(model, nodes=node.number, print.it=FALSE),它返回一个列表,其中包含 [1] checking < 2.5 [2] afford< 54 等内容。因此,我想要的是类似于 true.positives <- length(test.data[rule]) 的东西,显然,这段代码不起作用。但是,思路已经在那里了... 有什么想法吗? - nanounanue
对我来说,这个问题无法重现。很容易找到德国信用评分数据。实际上太容易了,大约有6个不同的版本。当我在最受支持的版本上使用rpart时,我得到的结构与您似乎得到的结构不同。例如,没有model$frame$yval2值。因此,您必须除了链接代码中的内容之外,还做了其他事情。 - IRTFM
嗯,这很奇怪...阅读rpart文档时,rpart.object内的一个变量是frame,而在frame内部则有yval2变量。我认为我需要重新阐述问题并提供更清晰的示例... - nanounanue
3个回答

3
我可以用以下方法解决这个问题。
免责声明:显然有更好的解决方法,但这种hack方法有效并且实现了我想要的功能...(我不是很自豪...它是hackish的,但有效)
好的,让我们开始。基本上的想法是使用sqldf包。
如果你查看问题,最后一段代码将树的每个路径部分放入列表中。所以,我将从那里开始。
        library(sqldf)
        library(stringr)

        # Transform to a character vector
        rule.v <- unlist(rule, use.names=FALSE)[-1]
        # Remove all the dots, sqldf doesn't handles dots in names 
        rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z])\\.([a-zA-Z])", replacement="\\1_\\2")
        # We have to remove all the equal signs to 'in ('
        rule.v <- str_replace_all(rule.v, pattern="([a-zA-Z0-9])=", replacement="\\1 in ('")
        # Embrace all the elements in the lists of values with " ' " 
        # The last element couldn't be modified in this way (Any ideas?) 
        rule.v <- str_replace_all(rule.v, pattern=",", replacement="','")

        # Close the last element with apostrophe and a ")" 
        for (i in which(!is.na(str_extract(pattern="in", string=rule.v)))) {
          rule.v[i] <- paste(append(rule.v[i], "')"), collapse="")
        }

        # Collapse all the list in one string joined by " AND "
        rule.v <- paste(rule.v, collapse = " AND ")

        # Generate the query
        # Use any metric that you can get from the data frame
        query <- paste("SELECT Species, count(Species) FROM iris WHERE ", rule.v, " group by Species", sep="")

        # For debug only...
        print(query)

        # Execute and print the results
        print(sqldf(query))

就是这样了!

我提醒过你,这很粗糙...

希望这能帮到其他人...

感谢所有的帮助和建议!


在将此问题标记为已回答之前,我会等待有更好的答案(或更优雅的答案)的人。 - nanounanue
由于没有人提供更好或更优雅的解决方案,我将把这个答案标记为我的问题的答案。显然,如果有更好的解决方案,我会更改它...再次感谢! - nanounanue

2
一般来说,我不建议使用eval(parse(...)),但在这种情况下似乎可以使用:
提取规则:
rule <- unname(unlist(path.rpart(model, nodes=7)))[-1]

 node number: 7 
   root
   Petal.Length>=2.45
   Petal.Width>=1.75
rule
[1] "Petal.Length>=2.45" "Petal.Width>=1.75" 

使用规则提取数据:

node_data <- with(iris, iris[eval(parse(text=paste(rule, collapse=" & "))), ])
head(node_data)

    Sepal.Length Sepal.Width Petal.Length Petal.Width    Species
71           5.9         3.2          4.8         1.8 versicolor
101          6.3         3.3          6.0         2.5  virginica
102          5.8         2.7          5.1         1.9  virginica
103          7.1         3.0          5.9         2.1  virginica
104          6.3         2.9          5.6         1.8  virginica
105          6.5         3.0          5.8         2.2  virginica

1

从以下开始

Rule number: 16 [yval=bad cover=220 N=121 Y=99 (37%) prob=0.04]
checking< 2.5
afford< 54
history< 3.5
coapp< 2.5

你会有一个名为“prob”的向量,最初全部为零,你可以使用rule16进行更新:

prob <- ifelse( dat[['checking']] < 2.5 &
                dat[['afford']]  < 54
                dat[['history']] < 3.5
                dat[['coapp']]  < 2.5) , 0.04, prob )

然后,您需要运行所有其他规则(对于此情况,应该不会改变任何概率,因为树应该是不相交的估计)。构建预测的方法可能比这更有效。例如...使用predict.rpart函数。


谢谢你的帮助@DWin,但有两点:第一,我想在测试数据集中只测试一个规则,因此我认为predict.rpart在这里没有用处。第二,我想以编程方式完成它。我会编辑问题以反映这一点。 - nanounanue

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接