谷歌云ML引擎scikit-learn预测概率'predict_proba()'。

12

Google Cloud ML-engine支持部署scikit-learn Pipeline对象。例如,文本分类Pipeline可能如下所示:

classifier = Pipeline([
('vect', CountVectorizer()), 
('clf', naive_bayes.MultinomialNB())])

分类器可以进行训练,
classifier.fit(train_x, train_y)

然后分类器可以上传到Google Cloud Storage,

model = 'model.joblib'
joblib.dump(classifier, model)
model_remote_path = os.path.join('gs://', bucket_name, datetime.datetime.now().strftime('model_%Y%m%d_%H%M%S'), model)
subprocess.check_call(['gsutil', 'cp', model, model_remote_path], stderr=sys.stdout)

然后可以通过Google Cloud控制台或编程方式创建ModelVersion,将'model.joblib'文件链接到Version中。使用已部署的模型predict端点调用分类器以预测新数据。
ml = discovery.build('ml','v1')
project_id = 'projects/{}/models/{}'.format(project_name, model_name)
if version_name is not None:
    project_id += '/versions/{}'.format(version_name)
request_dict = {'instances':['Test data']}
ml_request = ml.projects().predict(name=project_id, body=request_dict).execute()

谷歌云ML引擎调用分类器的predict函数并返回预测类。然而,我想能够返回置信度得分。通常,可以通过调用分类器的predict_proba函数来实现这一点,但似乎没有更改所调用函数的选项。我的问题是:在使用谷歌云ML引擎时,是否可能返回scikit-learn分类器的置信度得分?如果不行,您有什么其他建议吗? 更新: 我找到了一个hacky解决方案。它涉及重写分类器的predict函数,将其替换为自己的predict_proba函数。
nb = naive_bayes.MultinomialNB()
nb.predict = nb.predict_proba
classifier = Pipeline([
('vect', CountVectorizer()), 
('clf', nb)])

令人惊讶的是,这个方法可行。如果有更好的解决方案,请告诉我。

更新:谷歌发布了一个新功能(目前处于测试版),名为自定义预测例程。这允许您在预测请求到来时定义要运行的代码。它增加了更多的代码,但肯定比较不那么hacky。

1个回答

1
您正在使用的 ML Engine API 只有 predict 方法,如 文档 中所示,因此它只会进行预测(除非您使用提到的 hack 强制它执行其他操作)。
如果您想对已训练的模型进行其他操作,则需要加载并正常使用它。如果要使用存储在云存储中的模型,则可以执行以下操作:
from google.cloud import storage
from sklearn.externals import joblib

bucket_name = "<BUCKET_NAME>"
gs_model = "path/to/model.joblib"  # path in your Cloud Storage bucket
local_model = "/path/to/model.joblib"  # path in your local machine

client = storage.Client()
bucket = client.get_bucket(bucket_name)
blob = bucket.blob(gs_model)
blob.download_to_filename(local_model)

model = joblib.load(local_model)
model.predict_proba(test_data)

感谢您的回复,rilla。我已经根据 Google 最新发布的一个新功能更新了原帖,并提供了更简洁的解决方案。 - Alex Morgan

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接