确实,关于rfmodel$pred
的内容,文档并不清晰——我敢打赌,包含的预测值是用作测试集的折叠,但我无法在文档中找到任何证据;尽管如此,无论如何,在您尝试获取ROC时,仍然存在一些问题。
首先,让我们将rfmodel$pred
从数据框中分离出来,以便更容易处理:
dd <- rfmodel$pred
nrow(dd)
为什么是450行?因为您尝试了3个不同的参数设置(在您的情况下只是mtry
的3个不同值):
rfmodel$results
mtry Accuracy Kappa AccuracySD KappaSD
1 2 0.96 0.94 0.04346135 0.06519202
2 3 0.96 0.94 0.04346135 0.06519202
3 4 0.96 0.94 0.04346135 0.06519202
150行 X 3设置 = 450。
让我们更仔细地查看rfmodel$pred
的内容:
head(dd)
pred obs setosa versicolor virginica rowIndex mtry Resample
1 setosa setosa 1.000 0.000 0 2 2 Fold1
2 setosa setosa 1.000 0.000 0 3 2 Fold1
3 setosa setosa 1.000 0.000 0 6 2 Fold1
4 setosa setosa 0.998 0.002 0 24 2 Fold1
5 setosa setosa 1.000 0.000 0 33 2 Fold1
6 setosa setosa 1.000 0.000 0 38 2 Fold1
- 列
obs
包含真实值
- 三列
setosa
、versicolor
和virginica
分别包含计算得到的每个类别的概率,它们对于每一行加起来总和为1
- 列
pred
包含最终的预测结果,即从上述三列中具有最大概率的类别
如果这就是全部内容,那么你绘制ROC曲线的方式就是可以的,即:
selectedIndices <- rfmodel$pred$Resample == "Fold1"
plot.roc(rfmodel$pred$obs[selectedIndices],rfmodel$pred$setosa[selectedIndices])
但这并不是整个故事(仅存在450行而不是150行就应该已经提示了):请注意存在一个名为mtry
的列;实际上,rfmodel$pred
包括交叉验证的所有运行结果(即针对所有参数设置的结果):
tail(dd)
pred obs setosa versicolor virginica rowIndex mtry Resample
445 virginica virginica 0 0.004 0.996 112 4 Fold5
446 virginica virginica 0 0.000 1.000 113 4 Fold5
447 virginica virginica 0 0.020 0.980 115 4 Fold5
448 virginica virginica 0 0.000 1.000 118 4 Fold5
449 virginica virginica 0 0.394 0.606 135 4 Fold5
450 virginica virginica 0 0.000 1.000 140 4 Fold5
这就是为什么你的
selectedIndices
计算不正确的终极原因;它还应该包括一个特定的
mtry
选择,否则ROC没有任何意义,因为它“聚合”了多个模型:
selectedIndices <- rfmodel$pred$Resample == "Fold1" & rfmodel$pred$mtry == 2
--
正如我在开头所说的那样,我打赌rfmodel$pred
中的预测是针对文件夹作为测试集的;实际上,如果我们手动计算准确率,它们将与上面显示的rfmodel$results
报告的准确率相一致(所有3个设置的准确率均为0.96),我们知道这些准确率是用作测试的文件夹所用的(可以说,相应的训练准确率为1.0):
for (i in 2:4) {
acc = (length(which(dd$pred == dd$obs & dd$mtry==i & dd$Resample=='Fold1'))/30 +
length(which(dd$pred == dd$obs & dd$mtry==i & dd$Resample=='Fold2'))/30 +
length(which(dd$pred == dd$obs & dd$mtry==i & dd$Resample=='Fold3'))/30 +
length(which(dd$pred == dd$obs & dd$mtry==i & dd$Resample=='Fold4'))/30 +
length(which(dd$pred == dd$obs & dd$mtry==i & dd$Resample=='Fold5'))/30
)/5
print(acc)
}
[1] 0.96
[1] 0.96
[1] 0.96