如何绘制多类分类器的精确度和召回率?

27

我正在使用scikit learn,并且想要绘制精确度和召回率曲线。我正在使用的分类器是RandomForestClassifier。在scikit learn文档中的所有资源都使用二元分类。而且,我能否绘制多类别的ROC曲线?

此外,我只发现SVM用于多标签,并且它有一个decision_function,但RandomForest没有。


这里有一个段落和示例:https://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html。这不是你想要的吗? - Yohst
https://scikit-learn.org/0.15/auto_examples/plot_precision_recall.html - secretive
@Yohst,该示例使用具有决策函数的svm,而RandomForest没有决策函数。 - John Sall
1个回答

53

从scikit-learn文档中:

精度-召回曲线通常用于二元分类,以研究分类器的输出。为了将精度-召回曲线和平均精度扩展到多类或多标签分类,需要对输出进行二值化。每个标签可以绘制一条曲线,但也可以通过将标签指示矩阵的每个元素视为二进制预测(微平均)来绘制精度-召回曲线。

ROC曲线通常用于二元分类,以研究分类器的输出。为了将ROC曲线和ROC面积扩展到多类或多标签分类,需要对输出进行二值化。每个标签可以绘制一条ROC曲线,但也可以通过将标签指示矩阵的每个元素视为二进制预测(微平均)来绘制ROC曲线。

因此,您应该将输出二值化,并考虑每个类别的精确率-召回率和ROC曲线。此外,您将使用predict_proba来获取类概率。

我将代码分为三部分:

  1. 一般设置、学习和预测
  2. 精确率-召回率曲线
  3. ROC曲线

1. 一般设置、学习和预测

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import precision_recall_curve, roc_curve
from sklearn.preprocessing import label_binarize

import matplotlib.pyplot as plt
#%matplotlib inline

mnist = fetch_openml("mnist_784")
y = mnist.target
y = y.astype(np.uint8)
n_classes = len(set(y))

Y = label_binarize(mnist.target, classes=[*range(n_classes)])

X_train, X_test, y_train, y_test = train_test_split(mnist.data,
                                                    Y,
                                                    random_state = 42)

clf = OneVsRestClassifier(RandomForestClassifier(n_estimators=50,
                             max_depth=3,
                             random_state=0))
clf.fit(X_train, y_train)

y_score = clf.predict_proba(X_test)

2. 精度-召回率曲线

# precision recall curve
precision = dict()
recall = dict()
for i in range(n_classes):
    precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
                                                        y_score[:, i])
    plt.plot(recall[i], precision[i], lw=2, label='class {}'.format(i))
    
plt.xlabel("recall")
plt.ylabel("precision")
plt.legend(loc="best")
plt.title("precision vs. recall curve")
plt.show()

图片描述文字

3. ROC 曲线

# roc curve
fpr = dict()
tpr = dict()

for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i],
                                  y_score[:, i]))
    plt.plot(fpr[i], tpr[i], lw=2, label='class {}'.format(i))

plt.xlabel("false positive rate")
plt.ylabel("true positive rate")
plt.legend(loc="best")
plt.title("ROC curve")
plt.show()

这里输入图片描述


3
为什么我要使用OneVsRestClassifier?难道RandomForest不支持多类分类吗? - John Sall
当我运行第一部分时,出现了以下错误: 用户警告:所有训练示例中都存在非0标签 用户警告:所有训练示例中都存在非1标签 用户警告:所有训练示例中都存在非2标签 - John Sall
请注意,警告并不是错误。考虑到这一行代码 Y = label_binarize(mnist.target, classes=[*range(n_classes)]),您应该在数据集中提供类别。在我的例子中,类别为 [0,1,2,...,9] - sentence
你如何使用微平均创建PR曲线或ROC曲线?据我所知,如果你有3个类别,你将获得3个概率向量,每个类别一个。然后观察结果被分配给具有最高概率的类别。也就是说,与阈值无关。但是对于ROC和PR曲线,你需要一个阈值,那么你如何进行微平均?你如何根据特定阈值将观察结果分配给类别? - Sole Galli
我刚刚尝试了反向计算精度和召回率,当阈值等于0时,看看它是否与classification_report()函数给出的结果相匹配,但它返回了奇怪不同的结果。我在这里解决这个问题:https://stats.stackexchange.com/questions/559203/why-does-precision-recall-curve-return-similar-but-not-equal-values-than-confu?noredirect=1#comment1028225_559203 - Federico Gentile
@JohnSall 请将 label_binarize 行替换为 Y = label_binarize(y, classes=[0,1,2,3,4,5,6,7,8,9])。 - Naseeb Gill

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接