在caret中获取保留折叠的预测结果

3

我想知道如何恢复交叉验证的预测结果。我有兴趣手动构建一个堆叠模型(像这里的3.2.1点), 我需要每个保留折叠的模型预测结果。这里附上一个简短的示例。

# load the library
library(caret)
# load the iris dataset
data(cars)
# define folds
cv_folds <- createFolds(cars$Price, k = 5, list = TRUE)
# define training control
train_control <- trainControl(method="cv", index = cv_folds, savePredictions = 'final')
# fix the parameters of the algorithm
# train the model
model <- caret::train(Price~., data=cars, trControl=train_control, method="gbm", verbose = F)
# looking at predictions
model$pred

# verifying the number of observations
nrow(model$pred[model$pred$Resample == "Fold1",])
nrow(cars)

我想知道在对1-4折叠的模型进行估计并在第5折上进行评估时,会有哪些预测结果。看起来,查看model$pred并不能给我所需的信息。

1
你可能想看看cross-validated上的这个答案 - phiver
我看到了这个答案,但是还有一些疑问。当我运行 subset(model$pred, Resample == "Fold01") 时,我得到的数据框大小等于用于估计的9个折叠。然而,我想要得到的是相反的,也就是未用于估计的10%数据的预测结果。 - abu
@abu 这是不正确的:根据您的示例:nrow(model$pred[model$pred$Resample == "Fold01",]) 是 80,而 nrow(cars) 是 804。 - missuse
@missuse 你说得对,我确实没有在可重现的示例中验证过这一点,而是在我的更复杂的示例中。在那种情况下,我遇到了我在先前评论中提到的问题。通过将你的代码应用于我的5折交叉验证,我获得了与训练数据80%相等的nrow - abu
@missuse 我修改了问题,使它更接近我的情况,现在它显示了我80%的数据所表示的意思。 - abu
1个回答

3

在使用caret进行CV时,如果使用createFolds函数创建的折叠,训练索引将会被默认使用。因此,当您执行以下操作时:

cv_folds <- createFolds(cars$Price, k = 5, list = TRUE)

你收到了训练集折叠。
lengths(cv_folds)
#output
Fold1 Fold2 Fold3 Fold4 Fold5 
  161   160   161   160   162

每个包含您数据的20%

然后您可以在trainControl中指定这些折叠:

train_control <- trainControl(method="cv", index = cv_folds, savePredictions = 'final')

来自trainControl的帮助:

index - 一个列表,每个重抽样迭代都有一个元素。每个列表元素是一个整数向量,对应于该迭代中用于训练的行。

indexOut - 一个列表(与index长度相同),它指定了每个重抽样保留哪些数据(作为整数)。如果为NULL,则使用不包含在index中的唯一样本集。

因此,每次模型都是在160行上构建并在其余部分上进行验证。这就是为什么

nrow(model$pred[model$pred$Resample == "Fold1",])

返回 643

您需要做的是:

cv_folds <- createFolds(cars$Price, k = 5, list = TRUE, returnTrain = TRUE)

现在:

lengths(cv_folds)
#output
Fold1 Fold2 Fold3 Fold4 Fold5 
  644   643   642   644   643 

在训练模型之后:

nrow(model$pred[model$pred$Resample == "Fold1",])
#output
160

顺便问一下,你知道为什么当我筛选“Fold05”时,会得到在我生成的“Fold01”中的索引吗?我的意思是为什么会有某种偏移。 - abu
@abu,我可能没有理解你的意思,或者你说的不正确。当我运行以下代码时:lapply(unique(model$pred$Resample), function(x){ sum(model$pred$rowIndex[model$pred$Resample == x] %in% model$pred$rowIndex[model$pred$Resample != x]) }) 我得到了全部为0的结果,因此如果您能发布另一个带有问题示例的问题,我很乐意尝试回答。 - missuse
我觉得我们彼此之间存在一些误解。我已经发布了另一个问题: 链接。如果您能再给我提供一些帮助,我将非常感激 :) - abu
@敏捷Bean的预测结果存储在model$pred中。$finalModel包含使用最佳调整参数拟合所有训练数据的实际拟合模型。 - missuse
@missuse 谢谢。所以,为了澄清,您上面的计算与 model$pred 是相同的吗? - Agile Bean
显示剩余3条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接