Scikit - 改变阈值以创建多个混淆矩阵

11

我正在构建一个分类器,通过Lending Club数据选择最佳的X笔贷款。我已经训练了一个随机森林,并创建了常见的ROC曲线,混淆矩阵等。

混淆矩阵以分类器的预测结果(森林中树的大多数预测)作为参数。但是,我希望在不同的阈值下打印多个混淆矩阵,以了解如果我选择最好的10%贷款、20%贷款等会发生什么。

我从阅读其他问题中知道更改阈值通常不是一个好主意,但是否有其他方法可以查看这些情况下的混淆矩阵?(问题A)

如果我决定更改阈值,我应该假设最好的方法是预测概率,然后手动设置阈值,将其传递给混淆矩阵吗?(问题B)


是的,我认为唯一的方法是使用predict_proba,并手动更改阈值(或编写一个根据某些指标选择最佳阈值的函数)。在其他情况下可能不是一个好主意,但在这种情况下绝对是有意义的。 - amanbirs
1个回答

10

A. 针对您的情况,更改阈值是可行的,甚至可能是必要的。默认阈值为50%,但从业务角度来看,即使15%的违约概率可能足以拒绝这样的申请。

事实上,在信用评分中,通常会在使用共同模型(例如见Naeem Siddiqi的“Credit Risk Scorecards”第9章)预测违约概率后为不同的产品期限或客户段设置不同的截断点。

B. 有两种方便的方法可以将阈值设定为任意的alpha而不是50%:

  1. 确实,手动将predict_proba的输出结果设定为alpha,或使用包装器类进行设定(见下面的代码)。如果您想尝试多个阈值而不必重新拟合模型,则可以使用此方法。
  2. 在拟合模型之前将class_weights更改为(alpha,1-alpha)

现在,以下是一个包装器的示例代码:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.base import BaseEstimator, ClassifierMixin
X, y = make_classification(random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)

class CustomThreshold(BaseEstimator, ClassifierMixin):
    """ Custom threshold wrapper for binary classification"""
    def __init__(self, base, threshold=0.5):
        self.base = base
        self.threshold = threshold
    def fit(self, *args, **kwargs):
        self.base.fit(*args, **kwargs)
        return self
    def predict(self, X):
        return (self.base.predict_proba(X)[:, 1] > self.threshold).astype(int)

rf = RandomForestClassifier(random_state=1).fit(X_train, y_train)
clf = [CustomThreshold(rf, threshold) for threshold in [0.3, 0.5, 0.7]]

for model in clf:
    print(confusion_matrix(y_test, model.predict(X_test)))

assert((clf[1].predict(X_test) == clf[1].base.predict(X_test)).all())
assert(sum(clf[0].predict(X_test)) > sum(clf[0].base.predict(X_test)))
assert(sum(clf[2].predict(X_test)) < sum(clf[2].base.predict(X_test)))

它将针对不同的阈值输出3个混淆矩阵:

[[13  1]
 [ 2  9]]
[[14  0]
 [ 3  8]]
[[14  0]
 [ 4  7]]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接