在谷歌自动机器学习中创建模型之后,我们可以使用提供的 Python 代码进行预测。以下是代码:
import sys
from google.cloud import automl_v1beta1
from google.cloud.automl_v1beta1.proto import service_pb2
def get_prediction(content, project_id, model_id):
prediction_client = automl_v1beta1.PredictionServiceClient()
name = 'projects/{}/locations/us-central1/models/{}'.format(project_id, model_id)
payload = {'image': {'image_bytes': content }}
params = {}
request = prediction_client.predict(name, payload, params)
return request # waits till request is returned
if __name__ == '__main__':
file_path = sys.argv[1]
project_id = sys.argv[2]
model_id = sys.argv[3]
with open(file_path, 'rb') as ff:
content = ff.read()
print get_prediction(content, project_id, model_id)
我意识到它只会打印出得分高于阈值value = 0.5
的检测结果。 示例输出:
payload {
classification {
score: 0.562688529491
}
display_name: "dog"
}
如何打印出得分低于0.5的其他检测结果(例如将阈值改为0.3)?
params = {"score_threshold": score_threshold}
- West