How to convert a Spark DataFrame to an RDD of MLlib LabeledPoints?


I'm trying to apply PCA to my data and then apply RandomForest to the transformed data. However, PCA.transform(data) gives me a DataFrame, while my RandomForest needs an RDD of mllib LabeledPoints. How can I do this? My code:

    import org.apache.spark.mllib.util.MLUtils
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.model.RandomForestModel
    import org.apache.spark.ml.feature.PCA
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.linalg.Vectors


    val dataset = MLUtils.loadLibSVMFile(sc, "data/mnist/mnist.bz2")

    val splits = dataset.randomSplit(Array(0.7, 0.3))

    val (trainingData, testData) = (splits(0), splits(1))

    val trainingDf = trainingData.toDF()

    val pca = new PCA()
    .setInputCol("features")
    .setOutputCol("pcaFeatures")
    .setK(100)
    .fit(trainingDf)

    val pcaTrainingData = pca.transform(trainingDf)

    val numClasses = 10
    val categoricalFeaturesInfo = Map[Int, Int]()
    val numTrees = 10 // Use more in practice.
    val featureSubsetStrategy = "auto" // Let the algorithm choose.
    val impurity = "gini"
    val maxDepth = 20
    val maxBins = 32

    val model = RandomForest.trainClassifier(pcaTrainingData, numClasses, categoricalFeaturesInfo,
        numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)


     error: type mismatch;
     found   : org.apache.spark.sql.DataFrame
     required: org.apache.spark.rdd.RDD[org.apache.spark.mllib.regression.LabeledPoint]

I tried the following two possible solutions, but neither of them worked:

    scala> val pcaTrainingData = trainingData.map(p => p.copy(features = pca.transform(p.features)))
    <console>:39: error: overloaded method value transform with alternatives:
      (dataset: org.apache.spark.sql.DataFrame)org.apache.spark.sql.DataFrame <and>
      (dataset: org.apache.spark.sql.DataFrame,paramMap: org.apache.spark.ml.param.ParamMap)org.apache.spark.sql.DataFrame <and>
      (dataset: org.apache.spark.sql.DataFrame,firstParamPair: org.apache.spark.ml.param.ParamPair[_],otherParamPairs: org.apache.spark.ml.param.ParamPair[_]*)org.apache.spark.sql.DataFrame
     cannot be applied to (org.apache.spark.mllib.linalg.Vector)

And:

    val labeled = pca
    .transform(trainingDf)
    .map(row => LabeledPoint(row.getDouble(0), row(4).asInstanceOf[Vector[Int]]))

     error: type mismatch;
     found   : scala.collection.immutable.Vector[Int]
     required: org.apache.spark.mllib.linalg.Vector

(In the case above, I had already imported org.apache.spark.mllib.linalg.Vectors.)

Any help is appreciated.


Your code works perfectly for me (as is, without either of the two attempted fixes). My guess is that one of your imports is off? I'm using import org.apache.spark.ml.feature.PCA and import org.apache.spark.mllib.util.MLUtils. I ran it with this file: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.bz2 - Tzach Zohar
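@TzachZohar Oh, I have the same imports as you; I've edited my question to add them. I also used the same data file. Could it be failing because I'm running it in the shell rather than with spark-submit? - Tianyi Wang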
Why so many downvotes? This seems like a reasonable question. - WestCoastProjects
1 Answer


The correct approach here is the second one you tried: map each Row to a LabeledPoint to get an RDD[LabeledPoint]. However, that attempt has two bugs:

  1. The correct Vector class (org.apache.spark.mllib.linalg.Vector) does NOT take type arguments (e.g. Vector[Int]). Your code imports only Vectors (the factory object), not Vector itself, so the compiler resolved Vector[Int] to scala.collection.immutable.Vector, which DOES take a type argument.
  2. The DataFrame returned by pca.transform() (where pca is the PCAModel produced by PCA.fit()) has only 3 columns (indices 0 to 2), yet you tried to extract the column at index 4. Here are its first 4 rows:

    +-----+--------------------+--------------------+
    |label|            features|         pcaFeatures|
    +-----+--------------------+--------------------+
    |  5.0|(780,[152,153,154...|[880.071111851977...|
    |  1.0|(780,[158,159,160...|[-41.473039034112...|
    |  2.0|(780,[155,156,157...|[931.444898405036...|
    |  1.0|(780,[124,125,126...|[25.5114585648411...|
    +-----+--------------------+--------------------+
    

    To make this easier and less error-prone, I prefer looking up columns by name rather than by index.
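Point 1 can be reproduced without Spark at all. The sketch below is illustrative only (the object VectorNameDemo and the method castToVector are made-up names): when no type literally named Vector is imported, Vector[Int] means scala.collection.immutable.Vector[Int], which is exactly the "found" type in the error message from the second attempt.

```scala
// Illustrative only: VectorNameDemo and castToVector are hypothetical names.
// With no import of a type literally named `Vector`, the name resolves to
// scala.Vector, an alias for scala.collection.immutable.Vector.
object VectorNameDemo {
  // Same cast as in the question's second attempt: the static result type is
  // scala.collection.immutable.Vector[Int], not an MLlib vector, so passing it
  // to LabeledPoint(...) cannot type-check.
  def castToVector(x: Any): Vector[Int] = x.asInstanceOf[Vector[Int]]

  def main(args: Array[String]): Unit = {
    val v = castToVector(Vector(1, 2, 3))
    // This assignment compiles because both names denote the same type.
    val same: scala.collection.immutable.Vector[Int] = v
    println(same.sum) // 6
  }
}
```

Importing org.apache.spark.mllib.linalg.Vector explicitly would shadow the default alias, after which Vector[Int] would be rejected because the MLlib type takes no type parameter.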

So here is the conversion you need:

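    import org.apache.spark.mllib.regression.LabeledPoint

    // Map each Row to a LabeledPoint, looking up columns by name rather than index
    val labeled = pca.transform(trainingDf).rdd.map(row => LabeledPoint(
      row.getAs[Double]("label"),
      row.getAs[org.apache.spark.mllib.linalg.Vector]("pcaFeatures")
    ))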
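To put the conversion in context, here is a sketch of the whole pipeline from the question. It assumes Spark 1.5/1.6 (as in the question), where the spark.ml PCA reads and writes org.apache.spark.mllib.linalg vectors. The names PcaThenRandomForest, toLabeledPoints and run are hypothetical, the four-row dataset is made up in place of the MNIST file, and the RandomForest parameters are toy values. The helper reads the output column name from the fitted model instead of hard-coding it.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.feature.{PCA, PCAModel}
import org.apache.spark.mllib.linalg.{Vector => MLlibVector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}

// Sketch for Spark 1.5/1.6; all names here are hypothetical.
object PcaThenRandomForest {
  // Applies a fitted PCAModel and turns every Row into a LabeledPoint,
  // looking columns up by name (the "label" default is an assumption).
  def toLabeledPoints(model: PCAModel, df: DataFrame,
                      labelCol: String = "label"): RDD[LabeledPoint] = {
    val featuresCol = model.getOutputCol
    model.transform(df).rdd.map { row =>
      LabeledPoint(row.getAs[Double](labelCol), row.getAs[MLlibVector](featuresCol))
    }
  }

  // Runs a tiny local pipeline and returns the number of trees trained.
  def run(): Int = {
    val sc = new SparkContext(
      new SparkConf().setAppName("pca-then-rf").setMaster("local[2]"))
    try {
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits._

      // Made-up toy data standing in for the MNIST file in the question
      val trainingDf = sc.parallelize(Seq(
        LabeledPoint(0.0, Vectors.dense(0.0, 1.0, 0.1)),
        LabeledPoint(0.0, Vectors.dense(0.1, 0.9, 0.0)),
        LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 0.9)),
        LabeledPoint(1.0, Vectors.dense(0.9, 0.1, 1.0))
      )).toDF()

      // Fit PCA on the DataFrame, as in the question (k = 2 instead of 100)
      val pcaModel = new PCA()
        .setInputCol("features")
        .setOutputCol("pcaFeatures")
        .setK(2)
        .fit(trainingDf)

      // numClasses = 2, no categorical features, 3 trees, toy depth/bins
      val labeled = toLabeledPoints(pcaModel, trainingDf)
      val forest = RandomForest.trainClassifier(labeled, 2, Map[Int, Int](),
        3, "auto", "gini", 4, 32)
      forest.numTrees
    } finally {
      sc.stop()
    }
  }

  def main(args: Array[String]): Unit =
    println(s"Trained ${run()} trees")
}
```

Because toLabeledPoints takes the fitted model as a parameter, the question's test split can be projected with the same PCA model, e.g. toLabeledPoints(pca, testData.toDF()), so training and test features end up in the same 100-dimensional space. On Spark 2.0 and later, spark.ml stores org.apache.spark.ml.linalg vectors instead, so the getAs type and the input column would need converting (for example with MLUtils.convertVectorColumnsToML and Vectors.fromML); this sketch does not cover that case.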
