sklearn.svm.svc的predict_proba()函数在内部是如何工作的?

46

我正在使用scikit-learn中的sklearn.svm.svc进行二元分类。我使用它的predict_proba()函数来获取概率估计值。有人可以告诉我predict_proba()如何在内部计算概率吗?

2个回答

78

Scikit-learn使用LibSVM作为内部工具,而这个工具又使用Platt scaling进行校准,以生成概率预测结果和分类预测结果。此LibSVM作者的说明文档详细介绍了Platt scaling的过程。

Platt scaling首先要按照通常的方法训练SVM,然后对参数向量AB进行优化,使其满足以下条件:

P(y|X) = 1 / (1 + exp(A * f(X) + B))

其中f(X)是样本距离超平面的有符号距离(scikit-learn的decision_function方法)。您可能会在定义中看到逻辑sigmoid,这是逻辑回归和神经网络用于将决策函数转换为概率估计的相同函数。

请注意:参数B,即“截距”或“偏差”或任何您喜欢的名称,可能会导致基于该模型的概率估计进行的预测与从SVM决策函数f获得的预测不一致。例如,假设f(X) = 10,那么X的预测结果为正;但是如果B = -9.9A = 1,那么P(y|X) = .475。我是凭空编造这些数字的,但您已经注意到这在实践中可能会发生。

有效地说,Platt缩放是在交叉熵损失函数下,在SVM输出的基础上训练概率模型。为了防止该模型过度拟合,它使用内部五折交叉验证,这意味着使用probability=True训练SVM比普通的非概率性SVM更加昂贵。

2
很棒的回答@larsmans。我在想这些概率是否可以被解释为分类决策的置信度度量?例如,对于一个样本,正负类别的概率非常接近意味着学习器对其分类不太确定? - Moses Xu
2
谢谢@larsmans。我实际上观察到了更加引人注目的情况——预测为1,但概率为0.45。我认为贝叶斯最优截断值恰好是0.5。你认为这样引人注目的情况仍然可以通过LibSVM中的数值不稳定性来解释吗? - Moses Xu
1
@MosesXu:这是值得调查的事情,但我现在没有时间深入研究LibSVM代码。乍一看似乎存在不一致的行为,但我认为predict实际上并没有使用概率,而是使用了SVM超平面。 - Fred Foo
2
@MosesXu:我盯着这个数学问题看了一会儿,意识到通过适当的B值,你可以得到与SVM predictdecision_function方法所得到的预测结果非常不同的预测结果。我担心当你使用Platt缩放时,你必须要选择相信predict还是相信predict_proba,因为这两者可能是不一致的。 - Fred Foo
1
@MosesXu:我对这种行为没有任何理由,除了它是LibSVM的做法,而scikit-learn试图保持与其兼容。可能的原因是,probability=True不会影响decision_function的结果,所以无论如何都会存在不一致性。(我越想越相信Platt缩放只是一个hack,应该使用RVMs代替SVMs进行概率估计。) - Fred Foo
显示剩余5条评论

-1
实际上,我找到了一个稍微不同的答案,他们使用这段代码将决策值转换为概率。
'double fApB = decision_value*A+B;
if (fApB >= 0)
    return Math.exp(-fApB)/(1.0+Math.exp(-fApB));
else
     return 1.0/(1+Math.exp(fApB)) ;'

在模型文件中可以找到A和B的值(probA和probB)。 它提供了一种将概率转换为决策值,从而转换为铰链损失的方法。

使用ln(0) = -200。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接