Caret包自定义度量

13

我在我的一个项目中使用了caret函数的"train()",我想添加一个"自定义指标"F1-score。我查看了这个网址caret包,但我不明白如何使用可用的参数来构建这个评分。

这里有一个自定义指标的例子:

## Example with a custom metric
madSummary <- function (data,
lev = NULL,
model = NULL) {
out <- mad(data$obs - data$pred,
na.rm = TRUE)
names(out) <- "MAD"
out
}
robustControl <- trainControl(summaryFunction = madSummary)
marsGrid <- expand.grid(degree = 1, nprune = (1:10) * 2)
earthFit <- train(medv ~ .,
data = BostonHousing,
method = "earth",
tuneGrid = marsGrid,
metric = "MAD",
maximize = FALSE,
trControl = robustControl)
更新:
我尝试了你的代码,但问题是它不能处理多个类别,比如下面的代码(显示F1得分,但是它很奇怪)。我不确定,但我认为F1_score函数只能用于二进制类别。
library(caret)
library(MLmetrics)

set.seed(346)
dat <- iris

## See http://topepo.github.io/caret/training.html#metrics
f1 <- function(data, lev = NULL, model = NULL) {

print(data)
  f1_val <- F1_Score(y_pred = data$pred, y_true = data$obs)
  c(F1 = f1_val)
}

# Split the Data into .75 input
in_train <- createDataPartition(dat$Species, p = .70, list = FALSE)

trainClass <- dat[in_train,]
testClass <- dat[-in_train,]



set.seed(35)
mod <- train(Species ~ ., data = trainClass ,
             method = "rpart",
             metric = "F1",
             trControl = trainControl(summaryFunction = f1, 
                                  classProbs = TRUE))

print(mod)

我也编写了一个手动F1分数的函数,只需要输入混淆矩阵即可。(我不确定在 "summaryFunction" 中是否可以使用混淆矩阵)

F1_score <- function(mat, algoName){

##
## Compute F1-score
##


# Remark: left column = prediction // top = real values
recall <- matrix(1:nrow(mat), ncol = nrow(mat))
precision <- matrix(1:nrow(mat), ncol = nrow(mat))
F1_score <- matrix(1:nrow(mat), ncol = nrow(mat))


for(i in 1:nrow(mat)){
  recall[i] <- mat[i,i]/rowSums(mat)[i]
  precision[i] <- mat[i,i]/colSums(mat)[i]
}

for(i in 1:ncol(recall)){
   F1_score[i] <- 2 * ( precision[i] * recall[i] ) / ( precision[i] + recall[i])
 }

 # We display the matrix labels
 colnames(F1_score) <- colnames(mat)
 rownames(F1_score) <- algoName

 # Display the F1_score for each class
 F1_score

 # Display the average F1_score
 mean(F1_score[1,])
}

不清楚您在过程的哪个部分遇到了问题。是关于编写自定义的summaryFunction还是在train输出中使用其结果。您能详细说明一下吗? - pbahr
它正在使用其在训练输出中的结果。在启动train()之后,我希望显示F1分数(仅准确度和cohen's kappa直接编码)。 - MarcelRitos
所以,如果您已经编写完成自定义函数,使用该函数编辑您的问题并发布可复现的示例将会有所帮助。 - pbahr
2个回答

22

你应该查看The caret Package - Alternate Performance Metrics以获取详细信息。一个可行的示例:

library(caret)
library(MLmetrics)

set.seed(346)
dat <- twoClassSim(200)

## See https://topepo.github.io/caret/model-training-and-tuning.html#metrics
f1 <- function(data, lev = NULL, model = NULL) {
  f1_val <- F1_Score(y_pred = data$pred, y_true = data$obs, positive = lev[1])
  c(F1 = f1_val)
}

set.seed(35)
mod <- train(Class ~ ., data = dat,
             method = "rpart",
             tuneLength = 5,
             metric = "F1",
             trControl = trainControl(summaryFunction = f1, 
                                      classProbs = TRUE))

这很好,因为我认为应该明确说明prSummary仅适用于二分类问题。 - NelsonGon
正类是否总是lev[1]?我在这里找不到它(https://topepo.github.io/caret/model-training-and-tuning.html#metrics) - user3226167
如果精度为NA,则使用自定义函数可能更好,因为F1 = NA...并且如果模型从未预测任何样本的“正”类(例如,0/0 = NA),则精度= NA。在这种情况下,您可以特殊处理您的F1函数以报告0,以避免在train()中出现错误。 - Brian D

1
对于两类情况,您可以尝试以下方法:
mod <- train(Class ~ ., 
             data = dat,
             method = "rpart",
             tuneLength = 5,
             metric = "F",
             trControl = trainControl(summaryFunction = prSummary, 
                                      classProbs = TRUE))

或者定义一个自定义摘要函数,结合两个类别的总结和prSummary当前最喜欢的函数,提供以下可能的评估指标 - AUROC、Spec、Sens、AUPRC、Precision、Recall、F。其中任何一个可以用作metric参数。这也包括我在接受答案评论中提到的特殊情况(F为NA)。
comboSummary <- function(data, lev = NULL, model = NULL) {
  out <- c(twoClassSummary(data, lev, model), prSummary(data, lev, model))

  # special case missing value for F
  out$F <- ifelse(is.na(out$F), 0, out$F)  
  names(out) <- gsub("AUC", "AUPRC", names(out))
  names(out) <- gsub("ROC", "AUROC", names(out))
  return(out)
}

mod <- train(Class ~ ., 
             data = dat,
             method = "rpart",
             tuneLength = 5,
             metric = "F",
             trControl = trainControl(summaryFunction = comboSummary, 
                                      classProbs = TRUE))



网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接