如何在Python中正确加载CatBoost的预训练模型

13

我已经训练好了CatBoostClassifier用于解决我的分类任务。现在,我需要保存模型,并在另一个应用程序中使用它进行预测。为此,我通过save_model方法保存了模型,并通过load_model方法恢复了它。

然而,每次我在恢复的模型中调用predict时都会出现错误:

CatboostError: There is no trained model to use predict(). Use fit() to train model. Then use predict().

看起来我需要重新训练我的模型,而我需要恢复预训练的模型并仅用于预测。

我在这里做错了什么?加载模型进行预测是否有特殊的方法?

我的训练过程如下:

model = CatBoostClassifier(
    custom_loss=['Accuracy'],
    random_seed=42,
    logging_level='Silent',
    loss_function='MultiClass')

model.fit(
    x_train, 
    y_train,
    cat_features=None,
    eval_set=(x_validation, y_validation),
    plot=True)

...

model.save("model.cbm")

我使用以下代码恢复模型:

model = CatBoostClassifier(
    custom_loss=['Accuracy'],
    random_seed=42,
    logging_level='Silent',
    loss_function='MultiClass')
model.load_model("model.cbm")

...


predict = self.model.predict(inputs)
2个回答

28
# After you train the model using fit(), save like this - 
model.save_model('model_name')    # extension not required.

# And then, later load - 
from catboost import CatBoostClassifier
model = CatBoostClassifier()      # parameters not required.
model.load_model('model_name')

# Now, try predict().

0
几个小时后,我意外地找到了解决方案。模型加载是在外部Python模块中实现的,然后导入到Jupyter Notebook中。结果发现我只需要重新启动Jupyter内核即可。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接