Scikit-learn混淆矩阵中的阈值如何更改

7

我需要在二元分类器中使用不同阈值的多个混淆矩阵。

我已经查阅了各种资料,但未找到简单易行的实现方法。

有没有人可以提供一种设置scikit-learn混淆矩阵阈值的方法?

我知道scikit-learn的confusion_matrix使用0.5作为阈值。

model = LogisticRegression(random_state=0).fit(X_train, y_train)
y_pred = model.predict(X_test)
confusion_matrix(y_test, y_pred)
Output: array([[24705,     8],
              [  718,     0]])

谢谢!


1
什么是混淆矩阵的阈值?请给出您期望的输入和输出示例。 - Mathias Müller
1
@MathiasMüller 我已经添加了代码片段,基本上我正在使用scikitlearn的混淆矩阵方法,但是如何更改它的阈值呢? - Mel
1个回答

18

我简单地想通了:

threshold = 0.2
y_pred = (model.predict_proba(X_test)[:, 1] > threshold).astype('float')
confusion_matrix(y_test, y_pred)
希望这对于其他寻求简单更改阈值方法的人有所帮助!

1
我知道这已经很老了,但只是想说谢谢你分享你自己的答案。以不同阈值获取结果的方法简洁明了! - DataMonkey

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接