云ML引擎和Scikit-Learn:“LatentDirichletAllocation”对象没有“predict”属性。

3
我正在实现简单的Scikit-Learn Pipeline来在Google Cloud ML Engine上执行LatentDirichletAllocation。目标是从新数据中预测主题。以下是生成Pipeline的代码:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.datasets import fetch_20newsgroups

dataset = fetch_20newsgroups(shuffle=True, random_state=1,
                             remove=('headers', 'footers', 'quotes'))
train, test = train_test_split(dataset.data[:2000])

pipeline = Pipeline([
    ('CountVectorizer', CountVectorizer(
        max_df          = 0.95,
        min_df          = 2,
        stop_words      = 'english')),
    ('LatentDirichletAllocation', LatentDirichletAllocation(
        n_components    = 10,
        learning_method ='online'))
])

pipeline.fit(train)

现在(如果我理解正确),为了预测测试数据的主题,我可以运行以下命令:
pipeline.transform(test)

然而,当将管道上传到Google Cloud Storage并尝试使用它在Google Cloud ML Engine上生成本地预测时,我会收到错误提示,指出LatentDirichletAllocation没有predict属性。
gcloud ml-engine local predict \
    --model-dir=$MODEL_DIR \
    --json-instances $INPUT_FILE \
    --framework SCIKIT_LEARN
...
"Exception during sklearn prediction: " + str(e)) cloud.ml.prediction.prediction_utils.PredictionError: Failed to run the provided model: Exception during sklearn prediction: 'LatentDirichletAllocation' object has no attribute 'predict' (Error code: 2)

缺乏预测方法也可以从文档中看出来,所以我猜这不是解决这个问题的方法。 http://scikit-learn.org/stable/modules/generated/sklearn.decomposition.LatentDirichletAllocation.html 现在的问题是:该怎么办?如何在Google Cloud ML Engine的Scikit-Learn管道中使用LatentDirichletAllocation(或类似方法)?

有趣的情况...事实是,CountVectorizer也没有predict方法(它有一个transform方法),但它不会产生错误... - desertnaut
2
Pipeline文档中,我了解到predict仅适用于最后一个估算器。这就是为什么CountVectorizer不会产生错误的原因。http://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html#sklearn.pipeline.Pipeline.predict - pipo
免责声明:我不是 Python 专家...我深入研究了源代码,发现 BaseEstimator 实际上没有 predict() 方法(LatentDirichletAllocation 本身也没有)。但是,BaseEstimatormixins 中确实提到了 predict() 方法。因此,很难看出 predict() 是如何/在哪里实现的。那么,Google appEngine 返回的错误是否有效呢? - WestCoastProjects
2
@pipo。根据我下面的回答,目前还不支持此功能,但我们很快会推出一些可能的解决方法。您是否愿意通过电子邮件讨论您的用例?如果是,请发送电子邮件至cloudml-feedback@并引用此帖子。 - rhaertel80
1个回答

3

目前,pipeline 中的最后一个评估器必须实现 predict 方法。


有关解决方法方面是否有任何更新?我也很感兴趣。 - Daniel
2
我们目前正在进行Alpha测试的解决方案。如果您愿意尝试并提供反馈,我们将不胜感激。请通过cloudml-feedback@google.com联系我们,以获取有关如何入门的信息。 - rhaertel80

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接