This is a follow-up question from 如何知道Scikit-learn中predict_proba返回数组中表示哪些类的问题。
在那个问题中,我引用了以下代码:
我发现在这个问题中,这个结果代表了每个类别的点属于该类别的概率,按照model.classes_给出的顺序。
现在,我记得在文档中读到过predict_proba对于小数据集不准确的内容,尽管我似乎找不到它了。这是预期的行为吗,还是我做错了什么?如果这是预期的行为,那么为什么predict和predict_proba函数会在输出上产生分歧?更重要的是,数据集需要多大才能信任predict_proba的结果?
-------- 更新 --------
好的,所以我对此进行了更多的“实验”:predict_proba的行为严重依赖于“n”,但没有任何可预测的方式!
在那个问题中,我引用了以下代码:
>>> import sklearn
>>> sklearn.__version__
'0.13.1'
>>> from sklearn import svm
>>> model = svm.SVC(probability=True)
>>> X = [[1,2,3], [2,3,4]] # feature vectors
>>> Y = ['apple', 'orange'] # classes
>>> model.fit(X, Y)
>>> model.predict_proba([1,2,3])
array([[ 0.39097541, 0.60902459]])
我发现在这个问题中,这个结果代表了每个类别的点属于该类别的概率,按照model.classes_给出的顺序。
>>> zip(model.classes_, model.predict_proba([1,2,3])[0])
[('apple', 0.39097541289393828), ('orange', 0.60902458710606167)]
所以……如果正确解释的话,这个答案表示这个点可能是一个“橙子”(由于数据量非常少,置信度较低)。但直觉上来看,这个结果显然是不正确的,因为给出的点与“苹果”的训练数据完全相同。为了确保,我也进行了反向测试:
>>> zip(model.classes_, model.predict_proba([2,3,4])[0])
[('apple', 0.60705475211840931), ('orange', 0.39294524788159074)]
再次强调,显然是错误的,但方向相反。
最后,我尝试了距离更远的点。
>>> X = [[1,1,1], [20,20,20]] # feature vectors
>>> model.fit(X, Y)
>>> zip(model.classes_, model.predict_proba([1,1,1])[0])
[('apple', 0.33333332048410247), ('orange', 0.66666667951589786)]
模型再次预测错误的概率。但是,model.predict函数得到了正确的结果!
>>> model.predict([1,1,1])[0]
'apple'
现在,我记得在文档中读到过predict_proba对于小数据集不准确的内容,尽管我似乎找不到它了。这是预期的行为吗,还是我做错了什么?如果这是预期的行为,那么为什么predict和predict_proba函数会在输出上产生分歧?更重要的是,数据集需要多大才能信任predict_proba的结果?
-------- 更新 --------
好的,所以我对此进行了更多的“实验”:predict_proba的行为严重依赖于“n”,但没有任何可预测的方式!
>>> def train_test(n):
... X = [[1,2,3], [2,3,4]] * n
... Y = ['apple', 'orange'] * n
... model.fit(X, Y)
... print "n =", n, zip(model.classes_, model.predict_proba([1,2,3])[0])
...
>>> train_test(1)
n = 1 [('apple', 0.39097541289393828), ('orange', 0.60902458710606167)]
>>> for n in range(1,10):
... train_test(n)
...
n = 1 [('apple', 0.39097541289393828), ('orange', 0.60902458710606167)]
n = 2 [('apple', 0.98437355278112448), ('orange', 0.015626447218875527)]
n = 3 [('apple', 0.90235408180319321), ('orange', 0.097645918196806694)]
n = 4 [('apple', 0.83333299908143665), ('orange', 0.16666700091856332)]
n = 5 [('apple', 0.85714254878984497), ('orange', 0.14285745121015511)]
n = 6 [('apple', 0.87499969631893626), ('orange', 0.1250003036810636)]
n = 7 [('apple', 0.88888844127886335), ('orange', 0.11111155872113669)]
n = 8 [('apple', 0.89999988018127364), ('orange', 0.10000011981872642)]
n = 9 [('apple', 0.90909082368682159), ('orange', 0.090909176313178491)]
我该如何在我的代码中安全使用这个函数?至少有哪些n的取值可以保证它与model.predict的结果一致?