更改随机森林分类器的阈值值

5

我需要开发一个模型,该模型将免费(或接近免费)地消除假负值。为此,我绘制了召回率-精度曲线,并确定阈值应设置为0.11。

我的问题是,如何在模型训练时定义阈值?如果稍后在评估时定义它就没有意义,因为它不会反映新数据。

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=101)

rfc_model = RandomForestClassifier(random_state=101)
rfc_model.fit(X_train, y_train)
rfc_preds = rfc_model.predict(X_test)


recall_precision_vals = []

for val in np.linspace(0, 1, 101):
    predicted_proba = rfc_model.predict_proba(X_test)
    predicted = (predicted_proba[:, 1] >= val).astype('int')
    
    recall_sc = recall_score(y_test, predicted)
    precis_sc = precision_score(y_test, predicted)

    recall_precision_vals.append({
        'Threshold': val,
        'Recall val': recall_sc,
        'Precis val': precis_sc
    })


recall_prec_df = pd.DataFrame(recall_precision_vals)

有什么想法吗?
1个回答

12
如何在模型训练期间定义阈值?
在模型训练期间没有阈值;随机森林是一种概率分类器,它只输出类别概率。"硬"类别(即0/1),确实需要阈值,但在模型训练的任何阶段都不会产生或使用 - 只在预测时使用,并且仅在确实需要硬分类的情况下使用(并非总是如此)。有关详细信息,请参见预测类别或类别概率?
实际上,即使对于硬类别预测,scikit-learn实现的RF也根本不使用阈值;仔细阅读predict方法的文档
“预测的类别是树中平均概率估计最高的类别”
简单地说,这意味着实际的RF输出是[p0, p1](假设是二元分类),其中predict方法只返回具有最高值的类别,即如果p0 > p1则返回0,否则返回1。
假设您实际想要做的是,如果p1大于某个小于0.5的阈值,则返回1。您需要放弃predict,改用predict_proba,然后操纵这些返回的概率以获得所需结果。以下是一个使用虚拟数据的示例:
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=1000, n_features=4,
                          n_informative=2, n_redundant=0,
                           n_classes=2, random_state=0, shuffle=False)

clf = RandomForestClassifier(n_estimators=100, max_depth=2,
                            random_state=0)

clf.fit(X, y)

在这里,仅仅使用predict来预测X的第一个元素将会得到0的结果:
clf.predict(X)[0] 
# 0

因为

clf.predict_proba(X)[0]
# array([0.85266881, 0.14733119])

p0 > p1

为了得到您想要的结果(即返回类别1,因为对于阈值为0.11,p1 > threshold),您需要执行以下操作:

prob_preds = clf.predict_proba(X)
threshold = 0.11 # define threshold here
preds = [1 if prob_preds[i][1]> threshold else 0 for i in range(len(prob_preds))]

接下来,很容易看出现在我们对于第一个预测样本有:

preds[0]
# 1

由上面的示例可知,对于此样本,我们有p1 = 0.14733119 > 阈值


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接