如何解释Spark逻辑回归预测中的概率列?

5
我正在使用 spark.ml.classification.LogisticRegressionModel.predict 进行预测。其中一些行的 prediction 列为 1.0,而 probability 列为 0.04。由于 model.getThreshold0.5,因此我认为该模型将概率阈值大于 0.5 的所有内容都分类为 1.0
那么,当结果显示 prediction1.0,而 probability0.04 时,应该如何解释呢?
1个回答

7

LogisticRegression的概率列应该包含一个与类别数目相同长度的列表,其中每个索引给出了该类别对应的概率。我举了一个包含两个类别的小例子作为说明:

case class Person(label: Double, age: Double, height: Double, weight: Double)
val df = List(Person(0.0, 15, 175, 67), 
      Person(0.0, 30, 190, 100), 
      Person(1.0, 40, 155, 57), 
      Person(1.0, 50, 160, 56), 
      Person(0.0, 15, 170, 56), 
      Person(1.0, 80, 180, 88)).toDF()

val assembler = new VectorAssembler().setInputCols(Array("age", "height", "weight"))
  .setOutputCol("features")
  .select("label", "features")
val df2 = assembler.transform(df)
df2.show

+-----+------------------+
|label|          features|
+-----+------------------+
|  0.0| [15.0,175.0,67.0]|
|  0.0|[30.0,190.0,100.0]|
|  1.0| [40.0,155.0,57.0]|
|  1.0| [50.0,160.0,56.0]|
|  0.0| [15.0,170.0,56.0]|
|  1.0| [80.0,180.0,88.0]|
+-----+------------------+

val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
val Array(testing, training) = df2.randomSplit(Array(0.7, 0.3))

val model = lr.fit(training)
val predictions = model.transform(testing)
predictions.select("probability", "prediction").show(false)


+----------------------------------------+----------+
|probability                             |prediction|
+----------------------------------------+----------+
|[0.7487950501224138,0.2512049498775863] |0.0       |
|[0.6458452667523259,0.35415473324767416]|0.0       |
|[0.3888393314864866,0.6111606685135134] |1.0       |
+----------------------------------------+----------+

以下是算法给出的概率和最终预测结果。最终预测结果是概率最高的那一类。

3
嗨,Shaido。我们如何将每个概率与其对应的类相关联?正如我们所看到的,这里的prob[0]实际上来自类0,而prob[1]来自类1。哪里说prob[0]不对应于类1? - Kenny
好问题。在二元分类任务中,似乎有些算法默认将高于某个阈值的概率与类“0”相关联。这基本上是无用的... - guiotan
在二元分类的情况下,预测类别将始终是具有较高概率的类别。也就是说,在上面的代码中,“prediction” 0.0 和 1.0 分别对应于标签 0.0 和 1.0,而每个预测的相关概率将是“probability”列中最高的概率。 - Shaido
有没有一种方法可以在概率列中给出0到1之间的1个值?就像另一列(“combined_probability”)会导致前两个预测的概率低于0.5,而最后一行的概率为0.5或更高? - sAguinaga
1
@sAguinaga:你可以简单地取数组的第一个元素(如果你想要另一种顺序,可以取第二个)来得到你想要的内容。 - Shaido

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接