交叉验证 CART 模型

3

在一个任务中,我们需要对CART模型进行交叉验证。我尝试使用来自cvToolscvFit函数,但得到了一个奇怪的错误信息。下面是一个最简示例:

library(rpart)
library(cvTools)
data(iris)
cvFit(rpart(formula=Species~., data=iris))

我看到的错误是:
Error in nobs(y) : argument "y" is missing, with no default

接下来是 traceback()

5: nobs(y)
4: cvFit.call(call, data = data, x = x, y = y, cost = cost, K = K, 
       R = R, foldType = foldType, folds = folds, names = names, 
       predictArgs = predictArgs, costArgs = costArgs, envir = envir, 
       seed = seed)
3: cvFit(call, data = data, x = x, y = y, cost = cost, K = K, R = R, 
       foldType = foldType, folds = folds, names = names, predictArgs = predictArgs, 
       costArgs = costArgs, envir = envir, seed = seed)
2: cvFit.default(rpart(formula = Species ~ ., data = iris))
1: cvFit(rpart(formula = Species ~ ., data = iris))

看起来cvFit.default需要强制使用y。但是:

> cvFit(rpart(formula=Species~., data=iris), y=iris$Species)
Error in cvFit.call(call, data = data, x = x, y = y, cost = cost, K = K,  : 
  'x' must have 0 observations

我做错了什么?哪个软件包可以让我使用CART树进行交叉验证,而无需自己编写代码呢?(我真的很懒...)

3
如果你深入研究 cvTools 的文档,会发现其中大部分工具都是针对连续响应变量而非离散的构建的。尽管可以尝试让它运行,但看起来你需要自己提供一个函数给 cost 以计算分类误差。 - joran
@joran:没错——谢谢!请参见我的回答 - krlmlr
2个回答

17

使用Caret包可以轻松进行交叉验证:

> library(caret)
> data(iris)
> tc <- trainControl("cv",10)
> rpart.grid <- expand.grid(.cp=0.2)
> 
> (train.rpart <- train(Species ~., data=iris, method="rpart",trControl=tc,tuneGrid=rpart.grid))
150 samples
  4 predictors
  3 classes: 'setosa', 'versicolor', 'virginica' 

No pre-processing
Resampling: Cross-Validation (10 fold) 

Summary of sample sizes: 135, 135, 135, 135, 135, 135, ... 

Resampling results

  Accuracy  Kappa  Accuracy SD  Kappa SD
  0.94      0.91   0.0798       0.12    

Tuning parameter 'cp' was held constant at a value of 0.2

1
哇,看看“train”中支持的方法列表。这就是我所说的全面......这里有很多“神奇”的事情发生。是否可能仅访问交叉验证例程,而无需实际优化模型参数呢? - krlmlr
我不这么认为,但你可以定义自己的参数网格。如果你不想测试多个模型,那么它们可以设置为静态值。我将通过编辑上面的示例来说明这一点。 - David
什么是插入符号?我没有看到它在你的回答中被使用。 - stackoverflowuser2010
我忘记在代码中包含一个库,已经进行了编辑,现在应该没问题了。 - David

4

最终,我成功使其运行。如Joran所指出的那样,cost参数需要进行调整。在我的情况下,我使用0/1损失函数,这意味着我使用一个简单的函数来评估yyHat之间的!=而不是-。此外,predictArgs必须包括c(type='class'),否则内部使用的predict调用将返回一个概率向量而不是最可能的分类。总之:

library(rpart)
library(cvTools)
data(iris)
cvFit(rpart, formula=Species~., data=iris,
      cost=function(y, yHat) (y != yHat) + 0, predictArgs=c(type='class'))

这里使用了另一种cvFit的变体。可以通过设置args=参数来传递给rpart的其他参数。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接