XGBOOST多类预测。预测矩阵是一组类别概率。如何执行混淆矩阵。

4

我已经使用XGBOOST进行多类别预测。

这是一个多标签预测问题。我的目标值包含8个类别,我使用了6个高度相关的特征。

我创建了我的预测数据集,并使用as.data.frame将其从矩阵转换为数据框。

我想检查我的预测准确性,但是由于列名更改且我的数据集中没有级别,所以不确定如何检查。我使用的所有数据类型都是整数和数值型的。

 Response <- train$Response
 label <- as.integer(train$Response)-1
 train$Response <- NULL

 train.index = sample(n,floor(0.75*n))
 train.data = as.matrix(train[train.index,])
 train.label = label[train.index]`
 test.data = as.matrix(train[-train.index,])
 test.label = label[-train.index]

 View(train.label)

 # Transform the two data sets into xgb.Matrix
 xgb.train = xgb.DMatrix(data=train.data,label=train.label)
 xgb.test = xgb.DMatrix(data=test.data,label=test.label)




  params = list(
          booster="gbtree",
          eta=0.001,
          max_depth=5,
          gamma=3,
          subsample=0.75,
          colsample_bytree=1,
          objective="multi:softprob",
          eval_metric="mlogloss",
          num_class=8)

    xgb.fit <-xgb.train(
    params=params,
    data=xgb.train,
    nrounds=10000,
    nthreads=1,
    early_stopping_rounds=10,
    watchlist=list(val1=xgb.train,val2=xgb.test),
    verbose=0
      )

   xgb.fit



  xgb.pred = predict(xgb.fit,test.data,reshape = T)
  class(xgb.pred)
  xgb.pred = as.data.frame(xgb.pred)

   """

现在我得到了以下形式的预测概率,由于有8类,我有8个概率值。我不知道哪个概率值属于哪个变量。

1   0.12233257  0.07373134  0.044682350 0.0810693502    0.06272415  0.134308174 0.066143863 0.415008187

我希望您能够将它们转换为有意义的标签,但我无法做到。执行混淆矩阵时需要这样做。

class(xgb.pred)是什么,为什么在train.data中包含响应变量? - StupidWolf
我问这个问题是因为我无法使用一些模拟数据重现你的错误。你能否也执行dput(head(train,20))并粘贴输出? - StupidWolf
dput(head(train,5)) 结构(list(Medical_History_23 = c(3L, 3L, 3L, 3L, 3L), Medical_Keyword_3 = c(0L, 0L, 0L, 0L, 0L), Medical_Keyword_15 = c(0L, 0L, 0L, 0L, 0L), BMI = c(0.323007976, 0.272287744, 0.428780429, 0.352437744, 0.424045645), Wt = c(0.148535565, 0.131799163, 0.288702929, 0.205020921, 0.234309623), Medical_History_4 = c(1L, 1L, 2L, 2L, 2L), Ins_Age = c(0.641791045, 0.059701493, 0.029850746, 0.164179104, 0.417910448), Response = c(8L, 4L, 8L, 8L, 8L )), row.names = c(NA, 5L), class = "data.frame") - Ardy
请编辑您的问题并包含数据样本(1行即可)和输出。在评论部分中导航很困难。 - MTT
@StupidWolf,我刚刚更新了问题。我在混淆矩阵方面遇到了麻烦。我有8个类别,所以xgb.pred给了我8个概率。我不确定哪个概率属于哪个类别。如果我能解码并将最大概率分配给单个类别,我就可以继续做混淆矩阵了。也许我错过了一两个步骤。你能帮忙吗?先谢谢了。 - Ardy
@MTT 更新了这个问题 - Ardy
2个回答

4
假设你的数据长这样:
train = data.frame(
  Medical_History_23 = sample(1:5,2000,replace=TRUE), 
  Medical_Keyword_3 = sample(1:5,2000,replace=TRUE), 
  Medical_Keyword_15 = sample(1:5,2000,replace=TRUE), 
  BMI = rnorm(2000), 
  Wt = rnorm(2000), 
  Medical_History_4 = sample(1:5,2000,replace=TRUE), 
  Ins_Age = rnorm(2000), 
  Response = sample(1:8,2000,replace=TRUE)) 

我们进行训练和测试:

library(xgboost)
label <- as.integer(train$Response)-1
train$Response <- NULL
n = nrow(train)
train.index = sample(n,floor(0.75*n))
train.data = as.matrix(train[train.index,])
train.label = label[train.index]
test.data = as.matrix(train[-train.index,])
test.label = label[-train.index]
xgb.train = xgb.DMatrix(data=train.data,label=train.label)
xgb.test = xgb.DMatrix(data=test.data,label=test.label)

params = list(booster="gbtree",eta=0.001,
          max_depth=5,gamma=3,subsample=0.75,
          colsample_bytree=1,objective="multi:softprob",
          eval_metric="mlogloss",num_class=8)

xgb.fit <-xgb.train(params=params,data=xgb.train,
    nrounds=10000,nthreads=1,early_stopping_rounds=10,
    watchlist=list(val1=xgb.train,val2=xgb.test),
    verbose=0
      )

xgb.pred = predict(xgb.fit,test.data,reshape = T)

你的预测如下,每列是1、2...8的概率。

> head(xgb.pred)
         V1        V2        V3        V4        V5        V6        V7        V8
1 0.1254475 0.1252269 0.1249843 0.1247929 0.1246919 0.1248430 0.1248226 0.1251909
2 0.1255558 0.1249674 0.1250741 0.1250397 0.1249939 0.1247931 0.1248649 0.1247111
3 0.1249737 0.1250508 0.1249501 0.1250445 0.1250142 0.1249630 0.1249194 0.1250844

为了获得预测标签,我们需要:
predicted_labels= factor(max.col(xgb.pred),levels=1:8)
obs_labels = factor(test.label,levels=1:8)

获取混淆矩阵的方法如下:
caret::confusionMatrix(obs_labels,predicted_labels)

当然,我这个例子的准确度会很低,因为变量中没有有用的信息,但是代码应该对你有用。


1
与您的标签相同的顺序。例如:

0.415008187

这是发生第八类等事件的概率。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接