sklearn
估计器实现了一些方法,使您可以轻松保存估计器的相关训练属性。一些估计器自己实现了__getstate__
方法,但其他估计器(例如GMM
)只是使用基本实现,它仅仅保存对象的内部字典:
def __getstate__(self):
try:
state = super(BaseEstimator, self).__getstate__()
except AttributeError:
state = self.__dict__.copy()
if type(self).__module__.startswith('sklearn.'):
return dict(state.items(), _sklearn_version=__version__)
else:
return state
推荐的将模型保存到磁盘的方法是使用
pickle
模块:
from sklearn import datasets
from sklearn.svm import SVC
iris = datasets.load_iris()
X = iris.data[:100, :2]
y = iris.target[:100]
model = SVC()
model.fit(X,y)
import pickle
with open('mymodel','wb') as f:
pickle.dump(model,f)
然而,您应该保存额外的数据,以便将来可以重新训练模型,否则会遭受严重后果(例如被锁定在旧版本的sklearn中)。
从
文档中可知:
为了能够使用未来版本的scikit-learn重建类似的模型,必须保存一些额外的元数据和已捕获的模型:
- 训练数据,例如指向不可变快照的引用
- 用于生成模型的Python源代码
- scikit-learn的版本及其依赖项
- 在培训数据上获得的交叉验证分数
这对于依赖于Cython编写的
tree.pyx
模块(例如
IsolationForest
)的集成估计器尤其如此,因为它会创建与实现的耦合,而在sklearn版本之间不保证稳定。它曾经出现过向后不兼容的变化。
如果您的模型变得非常大,并且加载变得麻烦,您也可以使用更高效的
joblib
。从文档中可以看到:
引用:
特别是对于scikit的情况,使用joblib的
pickle
替代方案(
joblib.dump
和
joblib.load
)可能更有趣,它对于内部带有大型numpy数组的对象更有效,通常是适配的scikit-learn估计器的情况,但只能将其pickle到磁盘而不能pickle到字符串中。