我的目标是构建一个多分类器。
我已经建立了一个特征提取的流程,并且第一步包括使用StringIndexer转换器来将每个类别名称映射到一个标签,该标签将在分类器训练阶段使用。
这个流程被用于对训练集进行拟合。
测试集必须经过拟合后的流程处理,以便提取相同的特征向量。
考虑到我的测试集文件与训练集具有相同的结构。可能的情况是在测试集中遇到未知的类别名称,在这种情况下,StringIndexer将无法找到标签,并引发异常。
是否有解决方案?或者我们如何避免这种情况发生?
我的目标是构建一个多分类器。
我已经建立了一个特征提取的流程,并且第一步包括使用StringIndexer转换器来将每个类别名称映射到一个标签,该标签将在分类器训练阶段使用。
这个流程被用于对训练集进行拟合。
测试集必须经过拟合后的流程处理,以便提取相同的特征向量。
考虑到我的测试集文件与训练集具有相同的结构。可能的情况是在测试集中遇到未知的类别名称,在这种情况下,StringIndexer将无法找到标签,并引发异常。
是否有解决方案?或者我们如何避免这种情况发生?
在Spark 2.2(于2017年发布)中,您可以在创建索引器时使用.setHandleInvalid("keep")
选项。使用此选项,当索引器看到新的标签时,它会添加新的索引。
val categoryIndexerModel = new StringIndexer()
.setInputCol("category")
.setOutputCol("indexedCategory")
.setHandleInvalid("keep") // options are "keep", "error" or "skip"
根据文档,在将StringIndexer拟合于一个数据集并用于转换另一个数据集时,有三种处理未知标签的策略:
请参见链接的文档,了解StringIndexer不同选项的输出示例。
保留
,错误
,跳过
选项之间确切差异的更多详细信息。 跳过
是删除数据点吗? 错误
会中断模型吗?而保留
则会添加新列?这是否适用于两种情况,即测试集中未在训练集中看到的值和训练集中未在测试集中看到的值? - Chuck在Spark 1.6中可以绕过此问题。
这是jira: https://issues.apache.org/jira/browse/SPARK-8764
以下是一个示例:
val categoryIndexerModel = new StringIndexer()
.setInputCol("category")
.setOutputCol("indexedCategory")
.setHandleInvalid("skip") // new method. values are "error" or "skip"
我开始使用这个,但最终回到了KrisP的第二个要点,即将特定的估算器拟合到完整数据集上。
在将IndexToString转换时,您稍后需要使用它作为一部分管道。
这是修改后的示例:
val categoryIndexerModel = new StringIndexer()
.setInputCol("category")
.setOutputCol("indexedCategory")
.fit(itemsDF) // Fit the Estimator and create a Model (Transformer)
... do some kind of classification ...
val categoryReverseIndexer = new IndexToString()
.setInputCol(classifier.getPredictionCol)
.setOutputCol("predictedCategory")
.setLabels(categoryIndexerModel.labels) // Use the labels from the Model
很抱歉,没有更好的方法。你可以选择:
StringIndexer
之前过滤掉具有未知标签的测试示例。StringIndexer
拟合于训练和测试数据框的联合体中,以确保所有标签都存在。这里是一些执行上述操作的示例代码:
// get training labels from original train dataframe
val trainlabels = traindf.select(colname).distinct.map(_.getString(0)).collect //Array[String]
// or get labels from a trained StringIndexer model
val trainlabels = simodel.labels
// define an UDF on your dataframe that will be used for filtering
val filterudf = udf { label:String => trainlabels.contains(label)}
// filter out the bad examples
val filteredTestdf = testdf.filter( filterudf(testdf(colname)))
// transform unknown value to some value, say "a"
val mapudf = udf { label:String => if (trainlabels.contains(label)) label else "a"}
// add a new column to testdf:
val transformedTestdf = testdf.withColumn( "newcol", mapudf(testdf(colname)))