Serialize a custom transformer using Python to be used within a Pyspark ML pipeline

28
6 Answers

40
Starting with Spark 2.3.0 there is a much better way to do this. Simply extend DefaultParamsWritable and DefaultParamsReadable and your class will automatically get write and read methods that save your params and that are used by the PipelineModel serialization system. The docs are not really clear on this, and I had to read the source to understand how deserialization works:
  • PipelineModel.read instantiates a PipelineModelReader
  • PipelineModelReader loads the metadata and checks whether the language is 'Python'. If it is not, the typical JavaMLReader is used (which is what most of these answers are designed for)
  • Otherwise, PipelineSharedReadWrite is used, which calls DefaultParamsReader.loadParamsInstance

loadParamsInstance will find the class from the saved metadata, instantiate it, and call .load(path) on it. You can extend DefaultParamsReadable and get that load method automatically. If you do need to implement specialized deserialization logic, I would look at DefaultParamsReader's load method as a starting point.
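
As a rough sketch of that read path (my illustration, not the actual Spark source; stage_path and sc are placeholder names for a saved stage directory and the SparkContext):

from pyspark.ml.util import DefaultParamsReader

# Roughly what happens for a Python-only stage: loadParamsInstance reads the JSON
# metadata under stage_path, resolves its "class" entry to a Python class, and
# calls that class's .load(stage_path) -- which is why the class must be importable
# in the loading process (see the __main__ comment further down).
stage = DefaultParamsReader.loadParamsInstance(stage_path, sc)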

Going the other direction:

  • PipelineModel.write will check whether all stages are Java (i.e. implement JavaMLWritable). If so, the typical JavaMLWriter is used (which is what most of these answers are designed for)
  • Otherwise, PipelineWriter is used, which checks that all stages implement MLWritable and calls PipelineSharedReadWrite.saveImpl
  • PipelineSharedReadWrite.saveImpl will call .write().save(path) on each stage

You can extend DefaultParamsWritable to get the write method, which saves the metadata for your class and its params in the correct format. If you need to implement custom serialization logic, I would look at DefaultParamsWriter as a starting point.
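
If you do go down the custom route, one possible shape (a sketch only, assuming Spark >= 2.3; the writer/reader subclass names are mine) is to subclass DefaultParamsWriter and DefaultParamsReader so the param metadata is still handled for you and you only add your extra payload:

from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter

class MyTransformerWriter(DefaultParamsWriter):
    def saveImpl(self, path):
        # Write the usual class/param metadata first...
        super(MyTransformerWriter, self).saveImpl(path)
        # ...then persist any non-Param state of self.instance under `path`.

class MyTransformerReader(DefaultParamsReader):
    def load(self, path):
        # Restore the Params the default way...
        instance = super(MyTransformerReader, self).load(path)
        # ...then read your extra state back onto `instance` before returning it.
        return instance

# On the transformer itself you would then override:
#   def write(self): return MyTransformerWriter(self)
#   @classmethod
#   def read(cls): return MyTransformerReader(cls)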

OK, so finally, here is an extremely simple transformer that extends Params and stores all of its parameters in the typical Params fashion:

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasOutputCols, Param, Params
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import lit # for the dummy _transform

class SetValueTransformer(
    Transformer, HasOutputCols, DefaultParamsReadable, DefaultParamsWritable,
):
    value = Param(
        Params._dummy(),
        "value",
        "value to fill",
    )

    @keyword_only
    def __init__(self, outputCols=None, value=0.0):
        super(SetValueTransformer, self).__init__()
        self._setDefault(value=0.0)
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @keyword_only
    def setParams(self, outputCols=None, value=0.0):
        """
        setParams(self, outputCols=None, value=0.0)
        Sets params for this SetValueTransformer.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setValue(self, value):
        """
        Sets the value of :py:attr:`value`.
        """
        return self._set(value=value)

    def getValue(self):
        """
        Gets the value of :py:attr:`value` or its default value.
        """
        return self.getOrDefault(self.value)

    def _transform(self, dataset):
        for col in self.getOutputCols():
            dataset = dataset.withColumn(col, lit(self.getValue()))
        return dataset

Now we can use it:

from pyspark.ml import Pipeline, PipelineModel

svt = SetValueTransformer(outputCols=["a", "b"], value=123.0)

p = Pipeline(stages=[svt])
df = sc.parallelize([(1, None), (2, 1.0), (3, 0.5)]).toDF(["key", "value"])
pm = p.fit(df)
pm.transform(df).show()
pm.write().overwrite().save('/tmp/example_pyspark_pipeline')
pm2 = PipelineModel.load('/tmp/example_pyspark_pipeline')
print('matches?', pm2.stages[0].extractParamMap() == pm.stages[0].extractParamMap())
pm2.transform(df).show()

Result:

+---+-----+-----+-----+
|key|value|    a|    b|
+---+-----+-----+-----+
|  1| null|123.0|123.0|
|  2|  1.0|123.0|123.0|
|  3|  0.5|123.0|123.0|
+---+-----+-----+-----+

matches? True
+---+-----+-----+-----+
|key|value|    a|    b|
+---+-----+-----+-----+
|  1| null|123.0|123.0|
|  2|  1.0|123.0|123.0|
|  3|  0.5|123.0|123.0|
+---+-----+-----+-----+

7
When using Pipeline.load(filename), if you get the error AttributeError: module '__main__' has no attribute 'YourTransformerClass', try registering YourTransformerClass in the current __main__ module by doing m = __import__("__main__"); setattr(m, 'YourTransformerClass', YourTransformerClass). See the source code of DefaultParamsReader. - mvillegas
I ran into the same error. The save completes, but the error appears on load. - yogesh agrawal
Question: why do you put default values in the setParams() method? - George C
@GeorgeC The defaults in setParams are input for the keyword_only decorator; the decorator needs the defaults, and the keyword arguments get saved as the _input_kwargs attribute. - hwrd
Is there a way to serialize the transformer so it is "self-contained" in the pipeline, avoiding having to load the class code before loading a stored pipeline? I also ran into the error @mvillegas describes, and it would be nice not to have to run the class code every time I load a stored pipeline. - Nic Scozzaro
@NicScozzaro I haven't found anything like that, but I'd be happy to learn otherwise. - Benjamin Manns

14

I am not sure this is the best approach, but I also needed to be able to save custom Estimators, Transformers, and Models that I had created in PySpark, and to support their persistent use in the Pipeline API. Custom PySpark Estimators, Transformers, and Models can be created and used in the Pipeline API, but cannot be saved. This poses a problem when the model training run takes longer than the event prediction cycle.

PySpark Estimators, Transformers, and Models are generally just wrappers around their Java or Scala equivalents, and the PySpark wrappers just marshal the parameters back and forth between Python and Java via py4j. Any persistence of the model is then done on the Java side. Because of this current structure, custom PySpark Estimators, Transformers, and Models are limited to living only in the Python world.

In a previous attempt, I was able to save a single PySpark model by using pickling/dill serialization. That worked well, but still did not allow saving or loading from within the Pipeline API. Pointed there by another SO post, however, I was directed to the OneVsRest classifier, and I inspected its _to_java and _from_java methods. They do all the heavy lifting on the PySpark side. After looking at them I thought: if there were a way to save the pickle dump into an already-existing, savable Java object, then it should be possible to save custom PySpark Estimators, Transformers, and Models with the Pipeline API.

To that end, I found StopWordsRemover to be an ideal object to hijack, because it has an attribute, stopwords, that is a list of strings. The dill.dumps method returns the pickled representation of an object as a string. The plan was to turn that string into a list and then set the stopwords parameter of a StopWordsRemover to this list. Although it is a list of strings, I found that some of the characters would not marshal over to the Java object, so the characters get converted to integers and the integers to strings. This works great for saving a single instance, and also when saving within a Pipeline, because the Pipeline dutifully calls the _to_java method of my Python class (we are still on the PySpark side, so this works). Coming back from Java to PySpark within the Pipeline API, however, did not.

Because I am hiding my Python object inside a StopWordsRemover instance, the Pipeline, when coming back to PySpark, does not know anything about my hidden class object; it only knows it has a StopWordsRemover instance. Ideally it would be great to subclass Pipeline and PipelineModel, but alas that brings us back to trying to serialize a Python object. To combat this, I created a PysparkPipelineWrapper that takes a Pipeline or PipelineModel and simply scans the stages, looking for a coded ID in the stopwords list (remember, that list is just the pickled bytes of my Python object), to unwrap the list back into my instance and store it back in the stage it came from. The code below shows how all of this works.
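
To see the encoding trick on its own before the full code, here is a standalone round trip of the dill-dump-as-stop-words idea, written in Python 3 terms (where dill.dumps returns bytes and iterating it yields ints; the answer's own code below is Python 2, hence its use of ord/chr and xrange):

import dill

# Stand-in for the custom transformer instance we want to smuggle through Java:
py_obj = {'dataset_count': 999}

# Encode: pickled bytes -> list of decimal strings (legal StopWordsRemover stop words).
dmp = dill.dumps(py_obj)
stopwords = [str(b) for b in dmp]   # one decimal string per byte

# ...these strings get stored via StopWordsRemover.setStopWords() and saved with it...

# Decode: decimal strings -> bytes -> the original Python object.
recovered = dill.loads(bytes(int(s) for s in stopwords))
print(recovered)                    # {'dataset_count': 999}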

For any custom PySpark Estimator, Transformer, or Model, just inherit from Identifiable, PysparkReaderWriter, MLReadable, and MLWritable. Then, when loading a Pipeline or PipelineModel, pass it through PysparkPipelineWrapper.unwrap(pipeline).

This method does not address using the PySpark code from Java or Scala, but at least we can save and load custom PySpark Estimators, Transformers, and Models and work with the Pipeline API.

import dill
from pyspark.ml import Transformer, Pipeline, PipelineModel
from pyspark.ml.param import Param, Params
from pyspark.ml.util import Identifiable, MLReadable, MLWritable, JavaMLReader, JavaMLWriter
from pyspark.ml.feature import StopWordsRemover
from pyspark.ml.wrapper import JavaParams
from pyspark.context import SparkContext
from pyspark.sql import Row

class PysparkObjId(object):
    """
    A class to specify constants used to identify and set up python
    Estimators, Transformers and Models so they can be serialized on their
    own and from within a Pipeline or PipelineModel.
    """
    def __init__(self):
        super(PysparkObjId, self).__init__()

    @staticmethod
    def _getPyObjId():
        return '4c1740b00d3c4ff6806a1402321572cb'

    @staticmethod
    def _getCarrierClass(javaName=False):
        return 'org.apache.spark.ml.feature.StopWordsRemover' if javaName else StopWordsRemover

class PysparkPipelineWrapper(object):
    """
    A class to facilitate converting the stages of a Pipeline or PipelineModel
    that were saved from PysparkReaderWriter.
    """
    def __init__(self):
        super(PysparkPipelineWrapper, self).__init__()

    @staticmethod
    def unwrap(pipeline):
        if not (isinstance(pipeline, Pipeline) or isinstance(pipeline, PipelineModel)):
            raise TypeError("Cannot recognize a pipeline of type %s." % type(pipeline))

        stages = pipeline.getStages() if isinstance(pipeline, Pipeline) else pipeline.stages
        for i, stage in enumerate(stages):
            if (isinstance(stage, Pipeline) or isinstance(stage, PipelineModel)):
                stages[i] = PysparkPipelineWrapper.unwrap(stage)
            if isinstance(stage, PysparkObjId._getCarrierClass()) and stage.getStopWords()[-1] == PysparkObjId._getPyObjId():
                swords = stage.getStopWords()[:-1] # strip the id
                lst = [chr(int(d)) for d in swords]
                dmp = ''.join(lst)
                py_obj = dill.loads(dmp)
                stages[i] = py_obj

        if isinstance(pipeline, Pipeline):
            pipeline.setStages(stages)
        else:
            pipeline.stages = stages
        return pipeline

class PysparkReaderWriter(object):
    """
    A mixin class so custom pyspark Estimators, Transformers and Models may
    support saving and loading directly or be saved within a Pipeline or PipelineModel.
    """
    def __init__(self):
        super(PysparkReaderWriter, self).__init__()

    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)

    @classmethod
    def read(cls):
        """Returns an MLReader instance for our carrier class."""
        return JavaMLReader(PysparkObjId._getCarrierClass())

    @classmethod
    def load(cls, path):
        """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
        swr_java_obj = cls.read().load(path)
        return cls._from_java(swr_java_obj)

    @classmethod
    def _from_java(cls, java_obj):
        """
        Get the dummy stopwords, which are the characters of the dill dump plus our guid,
        and convert, via dill, back to our python instance.
        """
        swords = java_obj.getStopWords()[:-1] # strip the id
        lst = [chr(int(d)) for d in swords] # convert from string integer list to bytes
        dmp = ''.join(lst)
        py_obj = dill.loads(dmp)
        return py_obj

    def _to_java(self):
        """
        Convert this instance to a dill dump, then to a list of strings with the unicode integer values of each character.
        Use this list as a set of dummy stopwords and store in a StopWordsRemover instance.
        :return: Java object equivalent to this instance.
        """
        dmp = dill.dumps(self)
        pylist = [str(ord(d)) for d in dmp] # convert bytes to string integer list
        pylist.append(PysparkObjId._getPyObjId()) # add our id so PysparkPipelineWrapper can id us.
        sc = SparkContext._active_spark_context
        java_class = sc._gateway.jvm.java.lang.String
        java_array = sc._gateway.new_array(java_class, len(pylist))
        for i in xrange(len(pylist)):
            java_array[i] = pylist[i]
        _java_obj = JavaParams._new_java_obj(PysparkObjId._getCarrierClass(javaName=True), self.uid)
        _java_obj.setStopWords(java_array)
        return _java_obj

class HasFake(Params):
    def __init__(self):
        super(HasFake, self).__init__()
        self.fake = Param(self, "fake", "fake param")

    def getFake(self):
        return self.getOrDefault(self.fake)

class MockTransformer(Transformer, HasFake, Identifiable):
    def __init__(self):
        super(MockTransformer, self).__init__()
        self.dataset_count = 0

    def _transform(self, dataset):
        self.dataset_count = dataset.count()
        return dataset

class MyTransformer(MockTransformer, Identifiable, PysparkReaderWriter, MLReadable, MLWritable):
    def __init__(self):
        super(MyTransformer, self).__init__()

def make_a_dataframe(sc):
    df = sc.parallelize([Row(name='Alice', age=5, height=80), Row(name='Alice', age=5, height=80), Row(name='Alice', age=10, height=80)]).toDF()
    return df

def test1():
    trA = MyTransformer()
    trA.dataset_count = 999
    print trA.dataset_count
    trA.save('test.trans')
    trB = MyTransformer.load('test.trans')
    print trB.dataset_count

def test2():
    trA = MyTransformer()
    pipeA = Pipeline(stages=[trA])
    print type(pipeA)
    pipeA.save('testA.pipe')
    pipeAA = PysparkPipelineWrapper.unwrap(Pipeline.load('testA.pipe'))
    stagesAA = pipeAA.getStages()
    trAA = stagesAA[0]
    print trAA.dataset_count

def test3():
    dfA = make_a_dataframe(sc)
    trA = MyTransformer()
    pipeA = Pipeline(stages=[trA]).fit(dfA)
    print type(pipeA)
    pipeA.save('testB.pipe')
    pipeAA = PysparkPipelineWrapper.unwrap(PipelineModel.load('testB.pipe'))
    stagesAA = pipeAA.stages
    trAA = stagesAA[0]
    print trAA.dataset_count
    dfB = pipeAA.transform(dfA)
    dfB.show()

1
Very clever! Thanks for sharing. I hope Spark adds the ability to save and load custom all-Python MLlib pipeline stages soon. Many data scientists use Python rather than Scala/Java, and re-fitting/re-training pipeline models is often simply not practical. - snark
I can't get this to work with Python 2.7.12 and Spark 2.2.0 (and Zeppelin 0.7.3). test1() fails with: PicklingError: Can't pickle <class 'MyTransformer'>: it's not found as __builtin__.MyTransformer - snark
@snark I'm guessing this worked for you in an earlier version? - dmbaker
I admit I'd never actually tried it until now :). Despite some promising signs at https://databricks.com/blog/2017/08/30/developing-custom-machine-learning-algorithms-in-pyspark.html, it looks like https://issues.apache.org/jira/browse/SPARK-17025 isn't ready yet (and I'd have to wait for the right Spark version anyway). - snark
1
@AdrienForbu No. These techniques are for saving custom estimators and transformers within PySpark pipelines, since support for that isn't as good from PySpark as it is in Scala. Likewise, these techniques are strictly for PySpark and can't be used from Scala. - dmbaker

4

I could not get @dmbaker's ingenious solution to work using Python 2 on Spark 2.2.0; I kept getting pickling errors. After several blind alleys, I got a working solution by modifying his (her?) idea to write and read the parameter values as strings directly into the stop words of a StopWordsRemover.

Here is the base class you need if you want to save and load your own estimators or transformers:

from pyspark import SparkContext
from pyspark.ml.feature import StopWordsRemover
from pyspark.ml.util import Identifiable, MLWritable, JavaMLWriter, MLReadable, JavaMLReader
from pyspark.ml.wrapper import JavaWrapper, JavaParams

class PysparkReaderWriter(Identifiable, MLReadable, MLWritable):
    """
    A base class for custom pyspark Estimators and Models to support saving and loading directly
    or within a Pipeline or PipelineModel.
    """
    def __init__(self):
        super(PysparkReaderWriter, self).__init__()

    @staticmethod
    def _getPyObjIdPrefix():
        return "_ThisIsReallyA_"

    @classmethod
    def _getPyObjId(cls):
        return PysparkReaderWriter._getPyObjIdPrefix() + cls.__name__

    def getParamsAsListOfStrings(self):
        raise NotImplementedError("PysparkReaderWriter.getParamsAsListOfStrings() not implemented for instance: %r" % self)

    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)

    def _to_java(self):
        # Convert all our parameters to strings:
        paramValuesAsStrings = self.getParamsAsListOfStrings()

        # Append our own type-specific id so PysparkPipelineLoader can detect this algorithm when unwrapping us.
        paramValuesAsStrings.append(self._getPyObjId())

        # Convert the parameter values to a Java array:
        sc = SparkContext._active_spark_context
        java_array = JavaWrapper._new_java_array(paramValuesAsStrings, sc._gateway.jvm.java.lang.String)

        # Create a Java (Scala) StopWordsRemover and give it the parameters as its stop words.
        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", self.uid)
        _java_obj.setStopWords(java_array)
        return _java_obj

    @classmethod
    def _from_java(cls, java_obj):
        # Get the stop words, ignoring the id at the end:
        stopWords = java_obj.getStopWords()[:-1]
        return cls.createAndInitialisePyObj(stopWords)

    @classmethod
    def createAndInitialisePyObj(cls, paramsAsListOfStrings):
        raise NotImplementedError("PysparkReaderWriter.createAndInitialisePyObj() not implemented for type: %r" % cls)

    @classmethod
    def read(cls):
        """Returns an MLReader instance for our carrier class."""
        return JavaMLReader(StopWordsRemover)

    @classmethod
    def load(cls, path):
        """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
        swr_java_obj = cls.read().load(path)
        return cls._from_java(swr_java_obj)

Your own PySpark algorithm must then inherit from PysparkReaderWriter and override the getParamsAsListOfStrings() method, which saves your parameters to a list of strings. Your algorithm must also override the createAndInitialisePyObj() method, which converts a list of strings back into your parameters. Behind the scenes, the parameters are converted to and from the stop words used by StopWordsRemover.

An example estimator with 3 parameters, deliberately of different types:

from pyspark.ml.param.shared import Param, Params, TypeConverters
from pyspark.ml.base import Estimator

class MyEstimator(Estimator, PysparkReaderWriter):

    def __init__(self):
        super(MyEstimator, self).__init__()

    # 3 sample parameters, deliberately of different types:
    stringParam = Param(Params._dummy(), "stringParam", "A dummy string parameter", typeConverter=TypeConverters.toString)

    def setStringParam(self, value):
        return self._set(stringParam=value)

    def getStringParam(self):
        return self.getOrDefault(self.stringParam)

    listOfStringsParam = Param(Params._dummy(), "listOfStringsParam", "A dummy list of strings.", typeConverter=TypeConverters.toListString)

    def setListOfStringsParam(self, value):
        return self._set(listOfStringsParam=value)

    def getListOfStringsParam(self):
        return self.getOrDefault(self.listOfStringsParam)

    intParam = Param(Params._dummy(), "intParam", "A dummy int parameter.", typeConverter=TypeConverters.toInt)

    def setIntParam(self, value):
        return self._set(intParam=value)

    def getIntParam(self):
        return self.getOrDefault(self.intParam)

    def _fit(self, dataset):
        model = MyModel()
        # Just some changes to verify we can modify the model (and also it's something we can expect to see when restoring it later):
        model.setAnotherStringParam(self.getStringParam() + " World!")
        model.setAnotherListOfStringsParam(self.getListOfStringsParam() + ["E", "F"])
        model.setAnotherIntParam(self.getIntParam() + 10)
        return model

    def getParamsAsListOfStrings(self):
        paramValuesAsStrings = []
        paramValuesAsStrings.append(self.getStringParam()) # Parameter is already a string
        paramValuesAsStrings.append(','.join(self.getListOfStringsParam())) # ...convert from a list of strings
        paramValuesAsStrings.append(str(self.getIntParam())) # ...convert from an int
        return paramValuesAsStrings

    @classmethod
    def createAndInitialisePyObj(cls, paramsAsListOfStrings):
        # Convert back into our parameters. Make sure you do this in the same order you saved them!
        py_obj = cls()
        py_obj.setStringParam(paramsAsListOfStrings[0])
        py_obj.setListOfStringsParam(paramsAsListOfStrings[1].split(","))
        py_obj.setIntParam(int(paramsAsListOfStrings[2]))
        return py_obj

An example model (which is also a Transformer), again with 3 parameters of different types:

from pyspark.ml.base import Model

class MyModel(Model, PysparkReaderWriter):

    def __init__(self):
        super(MyModel, self).__init__()

    # 3 sample parameters, deliberately of different types:
    anotherStringParam = Param(Params._dummy(), "anotherStringParam", "A dummy string parameter", typeConverter=TypeConverters.toString)

    def setAnotherStringParam(self, value):
        return self._set(anotherStringParam=value)

    def getAnotherStringParam(self):
        return self.getOrDefault(self.anotherStringParam)

    anotherListOfStringsParam = Param(Params._dummy(), "anotherListOfStringsParam", "A dummy list of strings.", typeConverter=TypeConverters.toListString)

    def setAnotherListOfStringsParam(self, value):
        return self._set(anotherListOfStringsParam=value)

    def getAnotherListOfStringsParam(self):
        return self.getOrDefault(self.anotherListOfStringsParam)

    anotherIntParam = Param(Params._dummy(), "anotherIntParam", "A dummy int parameter.", typeConverter=TypeConverters.toInt)

    def setAnotherIntParam(self, value):
        return self._set(anotherIntParam=value)

    def getAnotherIntParam(self):
        return self.getOrDefault(self.anotherIntParam)

    def _transform(self, dataset):
        # Dummy transform code:
        return dataset.withColumn('age2', dataset.age + self.getAnotherIntParam())

    def getParamsAsListOfStrings(self):
        paramValuesAsStrings = []
        paramValuesAsStrings.append(self.getAnotherStringParam()) # Parameter is already a string
        paramValuesAsStrings.append(','.join(self.getAnotherListOfStringsParam())) # ...convert from a list of strings
        paramValuesAsStrings.append(str(self.getAnotherIntParam())) # ...convert from an int
        return paramValuesAsStrings

    @classmethod
    def createAndInitialisePyObj(cls, paramsAsListOfStrings):
        # Convert back into our parameters. Make sure you do this in the same order you saved them!
        py_obj = cls()
        py_obj.setAnotherStringParam(paramsAsListOfStrings[0])
        py_obj.setAnotherListOfStringsParam(paramsAsListOfStrings[1].split(","))
        py_obj.setAnotherIntParam(int(paramsAsListOfStrings[2]))
        return py_obj

Here is a sample test case showing how you can save and load your model. It's similar for the estimator, which I've omitted for brevity.

def createAModel():
    m = MyModel()
    m.setAnotherStringParam("Boo!")
    m.setAnotherListOfStringsParam(["P", "Q", "R"])
    m.setAnotherIntParam(77)
    return m

def testSaveLoadModel():
    modA = createAModel()
    print(modA.explainParams())

    savePath = "/whatever/path/you/want"
    #modA.save(savePath) # Can't overwrite, so...
    modA.write().overwrite().save(savePath)

    modB = MyModel.load(savePath)
    print(modB.explainParams())

testSaveLoadModel()

Output:

anotherIntParam: A dummy int parameter. (current: 77)
anotherListOfStringsParam: A dummy list of strings. (current: ['P', 'Q', 'R'])
anotherStringParam: A dummy string parameter (current: Boo!)
anotherIntParam: A dummy int parameter. (current: 77)
anotherListOfStringsParam: A dummy list of strings. (current: [u'P', u'Q', u'R'])
anotherStringParam: A dummy string parameter (current: Boo!)

Note that the parameters come back as unicode strings. This may or may not matter for the underlying algorithm you implement in _transform() (or in _fit() for the estimator), so be aware of it.

Finally, because behind the scenes the Scala algorithm really is a StopWordsRemover, you need to unwrap it back into your own class when loading a Pipeline or PipelineModel from disk. Here is the utility class that does the unwrapping:

from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import StopWordsRemover

class PysparkPipelineLoader(object):
    """
    A class to facilitate converting the stages of a Pipeline or PipelineModel
    that were saved from PysparkReaderWriter.
    """
    def __init__(self):
        super(PysparkPipelineLoader, self).__init__()

    @staticmethod
    def unwrap(thingToUnwrap, customClassList):
        if not (isinstance(thingToUnwrap, Pipeline) or isinstance(thingToUnwrap, PipelineModel)):
            raise TypeError("Cannot recognize an object of type %s." % type(thingToUnwrap))

        stages = thingToUnwrap.getStages() if isinstance(thingToUnwrap, Pipeline) else thingToUnwrap.stages
        for i, stage in enumerate(stages):
            if (isinstance(stage, Pipeline) or isinstance(stage, PipelineModel)):
                stages[i] = PysparkPipelineLoader.unwrap(stage, customClassList)

            if isinstance(stage, StopWordsRemover) and stage.getStopWords()[-1].startswith(PysparkReaderWriter._getPyObjIdPrefix()):

                lastWord = stage.getStopWords()[-1] 
                className = lastWord[len(PysparkReaderWriter._getPyObjIdPrefix()):]

                stopWords = stage.getStopWords()[:-1] # Strip the id

                # Create and initialise the appropriate class:
                py_obj = None
                for clazz in customClassList:
                    if clazz.__name__ == className:
                        py_obj = clazz.createAndInitialisePyObj(stopWords)

                if py_obj is None:
                    raise TypeError("I don't know how to create an instance of type: %s" % className)

                stages[i] = py_obj

        if isinstance(thingToUnwrap, Pipeline):
            thingToUnwrap.setStages(stages)
        else:
            # PipelineModel
            thingToUnwrap.stages = stages
        return thingToUnwrap
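
The pipeline tests below call a createAnEstimator() helper that isn't shown in the answer; a version consistent with the outputs printed further down (stringParam "Hello", listOfStringsParam ["A", "B", "C", "D"], intParam 42) would be (my reconstruction, not the author's exact code):

def createAnEstimator():
    est = MyEstimator()
    est.setStringParam("Hello")
    est.setListOfStringsParam(["A", "B", "C", "D"])
    est.setIntParam(42)
    return est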

Test saving and loading a pipeline:

def testSaveAndLoadUnfittedPipeline():
    estA = createAnEstimator()
    #print(estA.explainParams())
    pipelineA = Pipeline(stages=[estA])
    savePath = "/whatever/path/you/want"
    #pipelineA.save(savePath) # Can't overwrite, so...
    pipelineA.write().overwrite().save(savePath)

    pipelineReloaded = PysparkPipelineLoader.unwrap(Pipeline.load(savePath), [MyEstimator])
    estB = pipelineReloaded.getStages()[0]
    print(estB.explainParams())

testSaveAndLoadUnfittedPipeline()

Output:

intParam: A dummy int parameter. (current: 42)
listOfStringsParam: A dummy list of strings. (current: [u'A', u'B', u'C', u'D'])
stringParam: A dummy string parameter (current: Hello)

Test saving and loading a pipeline model:

from pyspark.sql import Row

def make_a_dataframe(sc):
    df = sc.parallelize([Row(name='Alice', age=5, height=80), Row(name='Bob', age=7, height=85), Row(name='Chris', age=10, height=90)]).toDF()
    return df

def testSaveAndLoadPipelineModel():
    dfA = make_a_dataframe(sc)
    estA = createAnEstimator()
    #print(estA.explainParams())
    pipelineModelA = Pipeline(stages=[estA]).fit(dfA)
    savePath = "/whatever/path/you/want"
    #pipelineModelA.save(savePath) # Can't overwrite, so...
    pipelineModelA.write().overwrite().save(savePath)

    pipelineModelReloaded = PysparkPipelineLoader.unwrap(PipelineModel.load(savePath), [MyModel])
    modB = pipelineModelReloaded.stages[0]
    print(modB.explainParams())

    dfB = pipelineModelReloaded.transform(dfA)
    dfB.show()

testSaveAndLoadPipelineModel()

Output:

anotherIntParam: A dummy int parameter. (current: 52)
anotherListOfStringsParam: A dummy list of strings. (current: [u'A', u'B', u'C', u'D', u'E', u'F'])
anotherStringParam: A dummy string parameter (current: Hello World!)
+---+------+-----+----+
|age|height| name|age2|
+---+------+-----+----+
|  5|    80|Alice|  57|
|  7|    85|  Bob|  59|
| 10|    90|Chris|  62|
+---+------+-----+----+

When unwrapping a Pipeline or PipelineModel, you have to pass in the list of classes that correspond to your own PySpark algorithms masquerading as StopWordsRemover objects. The last stop word in the saved object is used to identify your own class name, and then createAndInitialisePyObj() is called to create an instance of your class and initialise its parameters from the remaining stop words.

Various refinements could be made, but hopefully this lets you save and load custom estimators and transformers, both inside and outside pipelines, until SPARK-17025 is resolved and available to you.


2

@dmbaker's solution did not work for me. I believe that is because of the Python version (2.x vs. 3.x). I made some updates to his solution and now it works on Python 3. My setup is listed below:

  • python: 3.6.3
  • spark: 2.2.1
  • dill: 0.2.7.1

# Imports (the same ones as in the answer above, plus keyword_only and the
# shared-param mixins used by CleanText):
import dill
from pyspark import keyword_only
from pyspark.context import SparkContext
from pyspark.ml import Transformer, Pipeline, PipelineModel
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.ml.util import Identifiable, MLReadable, MLWritable, JavaMLReader, JavaMLWriter
from pyspark.ml.feature import StopWordsRemover
from pyspark.ml.wrapper import JavaParams

class PysparkObjId(object):
    """
    A class to specify constants used to identify and set up python
    Estimators, Transformers and Models so they can be serialized on their
    own and from within a Pipeline or PipelineModel.
    """
    def __init__(self):
        super(PysparkObjId, self).__init__()

    @staticmethod
    def _getPyObjId():
        return '4c1740b00d3c4ff6806a1402321572cb'

    @staticmethod
    def _getCarrierClass(javaName=False):
        return 'org.apache.spark.ml.feature.StopWordsRemover' if javaName else StopWordsRemover


class PysparkPipelineWrapper(object):
    """
    A class to facilitate converting the stages of a Pipeline or PipelineModel
    that were saved from PysparkReaderWriter.
    """
    def __init__(self):
        super(PysparkPipelineWrapper, self).__init__()

    @staticmethod
    def unwrap(pipeline):
        if not (isinstance(pipeline, Pipeline) or isinstance(pipeline, PipelineModel)):
            raise TypeError("Cannot recognize a pipeline of type %s." % type(pipeline))

        stages = pipeline.getStages() if isinstance(pipeline, Pipeline) else pipeline.stages
        for i, stage in enumerate(stages):
            if (isinstance(stage, Pipeline) or isinstance(stage, PipelineModel)):
                stages[i] = PysparkPipelineWrapper.unwrap(stage)
            if isinstance(stage, PysparkObjId._getCarrierClass()) and stage.getStopWords()[-1] == PysparkObjId._getPyObjId():
                swords = stage.getStopWords()[:-1] # strip the id
                # convert the stop words to ints
                swords = [int(d) for d in swords]
                # get the byte value of each int
                lst = [x.to_bytes(length=1, byteorder='big') for x in swords]
                # take the first byte and concatenate all the others
                dmp = lst[0]
                for byte_counter in range(1, len(lst)):
                    dmp = dmp + lst[byte_counter]

                py_obj = dill.loads(dmp)
                stages[i] = py_obj

        if isinstance(pipeline, Pipeline):
            pipeline.setStages(stages)
        else:
            pipeline.stages = stages
        return pipeline


class PysparkReaderWriter(object):
    """
    A mixin class so custom pyspark Estimators, Transformers and Models may
    support saving and loading directly or be saved within a Pipeline or PipelineModel.
    """
    def __init__(self):
        super(PysparkReaderWriter, self).__init__()

    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)

    @classmethod
    def read(cls):
        """Returns an MLReader instance for our carrier class."""
        return JavaMLReader(PysparkObjId._getCarrierClass())

    @classmethod
    def load(cls, path):
        """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
        swr_java_obj = cls.read().load(path)
        return cls._from_java(swr_java_obj)

    @classmethod
    def _from_java(cls, java_obj):
        """
        Get the dummy stopwords, which are the byte values of the dill dump plus our guid,
        and convert, via dill, back to our python instance.
        """
        swords = java_obj.getStopWords()[:-1] # strip the id
        lst = [int(x).to_bytes(length=1, byteorder='big') for x in swords] # convert from string integer list to bytes
        dmp = lst[0]
        for i in range(1, len(lst)):
            dmp = dmp + lst[i]
        py_obj = dill.loads(dmp)
        return py_obj

    def _to_java(self):
        """
        Convert this instance to a dill dump, then to a list of strings with the integer value of each byte.
        Use this list as a set of dummy stopwords and store in a StopWordsRemover instance.
        :return: Java object equivalent to this instance.
        """
        dmp = dill.dumps(self)

        pylist = [str(int(d)) for d in dmp] # convert bytes to string integer list
        pylist.append(PysparkObjId._getPyObjId()) # add our id so PysparkPipelineWrapper can id us.
        sc = SparkContext._active_spark_context
        java_class = sc._gateway.jvm.java.lang.String
        java_array = sc._gateway.new_array(java_class, len(pylist))
        for i in range(len(pylist)):
            java_array[i] = pylist[i]
        _java_obj = JavaParams._new_java_obj(PysparkObjId._getCarrierClass(javaName=True), self.uid)
        _java_obj.setStopWords(java_array)

        return _java_obj


class HasFake(Params):
    def __init__(self):
        super(HasFake, self).__init__()
        self.fake = Param(self, "fake", "fake param")

    def getFake(self):
        return self.getOrDefault(self.fake)


class CleanText(Transformer, HasInputCol, HasOutputCol, Identifiable, PysparkReaderWriter, MLReadable, MLWritable):
    @keyword_only
    def __init__(self, inputCol=None, outputCol=None):
        super(CleanText, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

2

Similar to the working answer by @dmbaker, I wrapped my custom transformer, called Aggregator, by inheriting from a built-in Spark transformer, in this example Binarizer, though I'm sure you can inherit from other transformers too. That allowed my custom transformer to inherit the methods necessary for serialization.

from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler, Binarizer
from pyspark.ml.regression import LinearRegression    

class Aggregator(Binarizer):
    """A huge hack to allow serialization of custom transformer."""

    def transform(self, input_df):
        agg_df = input_df\
            .groupBy('channel_id')\
            .agg({
                'foo': 'avg',
                'bar': 'avg',
            })\
            .withColumnRenamed('avg(foo)', 'avg_foo')\
            .withColumnRenamed('avg(bar)', 'avg_bar') 
        return agg_df

# Create pipeline stages.
aggregator = Aggregator()
vector_assembler = VectorAssembler(...)
linear_regression = LinearRegression()

# Create pipeline.
pipeline = Pipeline(stages=[aggregator, vector_assembler, linear_regression])

# Train.
pipeline_model = pipeline.fit(input_df)

# Save model file to S3.
pipeline_model.save('s3n://example')

I might be wrong, but I suspect this shortcut won't work if your custom algorithm has one or more parameters that don't exist on the algorithm you inherit from... - snark
I tried your approach, and while saving works, I get an error when I read the model back in and use transform. Did you ever run into that? - Scratch'N'Purr
I forget. Honestly, we ended up going with scikit-learn. Spark ML had too much overhead to be worth it, and with a few tricks you can make "big data" fit into scikit-learn, which has less overhead, faster prediction and startup, and serializes to a single file. - Steve Tjoa

0

I wrote some base classes to make this process easier. Basically, I abstract all the complexity of the code and initialization into a few base classes that expose a much simpler API for building custom classes. This includes solving the serialization/deserialization problem as well as saving and loading the SparkML objects. You can then focus your attention on the __init__ and transform/fit functions. You can find the full write-up with examples here.
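
The linked write-up is not reproduced here, but purely as an illustration of the idea (my own sketch, not the author's code, and assuming the Spark >= 2.3 DefaultParams mixins from the first answer), such a base class can bundle the persistence boilerplate so subclasses only declare their Params and implement _transform():

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable

class SerializableTransformer(Transformer, DefaultParamsReadable, DefaultParamsWritable):
    """Sketch of a reusable base class: subclasses declare Params and implement _transform()."""
    @keyword_only
    def __init__(self, **kwargs):
        super(SerializableTransformer, self).__init__()
        # Set whatever keyword Params the subclass was constructed with.
        self._set(**self._input_kwargs)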

