Scikit-learn分类:二项式对数回归?

3
我有一些文本,其评分在-100到+100的连续刻度上。我正在尝试将它们分类为积极或消极。
如何执行二项式logistic回归以获取测试数据为-100或+100的概率?
我最接近的是SGDClassifier(penalty ='l2',alpha = 1e-05,n_iter = 10),但当我使用二项式logistic回归来预测-100和+100的概率时,它与SPSS的结果不同。所以我猜这不是正确的函数?

也许在http://datascience.stackexchange.com上询问会更好? - Ashalynd
1
我认为sklearn的逻辑回归实现无法处理概率目标,它们只能处理标签。不过如果能够实现这一点就很棒了。 - eickenberg
2个回答

2
SGDClassifier提供了许多线性分类器,都是用随机梯度下降法训练的。除非您使用不同的损失函数,否则它将默认为线性支持向量机。loss = 'log'将提供概率逻辑回归。
请参阅文档: http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html#sklearn.linear_model.SGDClassifier 或者,您可以使用sklearn.linear_model.LogisticRegression使用逻辑回归对您的文本进行分类。
由于实现的差异,我不确定您是否会像使用SPSS一样获得完全相同的结果。但是,我不希望看到显着的差异。
编辑以添加:
我的怀疑是,您在SPSS逻辑回归中获得的99%准确度是训练集准确度,而您在scikits-learn逻辑回归中看到的87%是测试集准确度。我在datascience stack exchange上找到了这个问题,其中一个不同的人正在尝试极其相似的问题,并在训练集上获得了约99%的准确度和90%的测试集准确度。

https://datascience.stackexchange.com/questions/987/text-categorization-combining-different-kind-of-features

我的建议是:尝试使用scikits-learn中的几种基本分类器,包括标准的逻辑回归和线性支持向量机,并且多次使用不同的训练/测试数据子集重新运行SPSS逻辑回归,并比较结果。如果你发现各个分类器之间存在大的差异,而这些差异不能通过确保相似的训练/测试数据拆分来解释,那么请把你看到的结果发布在你的问题中,我们可以从那里开始进一步探讨。
祝你好运!

我相信我可以回答你的问题。如果可能的话,晚饭后;如果不行,那么等孩子们睡觉后再回复。 - brentlance
我认为我需要在这里进行一些澄清。这个数据集上的特征和标签是什么,你想要找到什么?我的理解是,这个数据集是文本数据,特征是一个+100到-100的范围,你想要使用逻辑回归,不仅仅是分类+1/-1,而且你还想知道每个文本属于每个类别的概率。这正确吗? - brentlance
是的,我正在尝试将数据进行二值分类,负类为-1,正类为+1。这些数据是已经根据从-100到+100的评分对推文进行了评级。推文的特征是基于TF-IDF算法的。我已经尝试了Scikit-learn中的所有分类器,但是准确率都没有在SPSS中得到的高。在Scikit-learn中,精度大约为0.87,在相同的数据和拆分上,而在SPSS中,我得到了0.99。我使用的SPSS方法尝试计算一条推文是-100或+100的概率,并提供一个概率值。 - Zach
你是如何将数据分割为训练/测试/验证集的?在SPSS下你是怎么做的?你是否比较训练集表现和测试集表现? - brentlance
另外,在SPSS中你使用了哪种类型的正则化? - brentlance
显示剩余4条评论

0
如果正/负或正概率是您所需的唯一输出,则可以将二进制标签y推导为:
y = score > 0

假设您已经将分数存储在一个NumPy数组score中。
然后,您可以将其提供给LogisticRegression实例,使用连续的分数来推导样本的相对权重:
clf = LogisticRegression()
sample_weight = np.abs(score)
sample_weight /= sample_weight.sum()
clf.fit(X, y, sample_weight)

这将最大权重赋予得分为±100的推文,并对标记为中性的推文赋予零权重,在两者之间线性变化。

如果数据集非常大,则可以像@brentlance所示那样使用SGDClassifier,但如果您想要逻辑回归模型,则必须给它loss="log";否则,您将获得一个线性SVM。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接