将StandardScaler()模型保存以便在新数据集上使用。

26

我该如何在Sklearn中保存StandardScaler()模型? 我需要将模型应用到新数据上进行预测,而不希望反复加载训练数据以便StandardScaler学习并应用。

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

#standardizing after splitting
X_train, X_test, y_train, y_test = train_test_split(data, target)
sc = StandardScaler()
X_train_std = sc.fit_transform(X_train)
X_test_std = sc.transform(X_test)

相关问题:https://dev59.com/81gQ5IYBdhLWcg3w7obY - Stephen
3个回答

38

您可以使用joblib的dump函数来保存标准化缩放模型。以下是一个完整的示例供参考。

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

data, target = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(data, target)

sc = StandardScaler()
X_train_std = sc.fit_transform(X_train)

如果你想保存Sc标准调用者,请使用以下方法。
from sklearn.externals.joblib import dump, load
dump(sc, 'std_scaler.bin', compress=True)

这将创建文件std_scaler.bin并保存sklearn模型。

稍后读取模型时,请使用load

sc=load('std_scaler.bin')

注意: sklearn.externals.joblib 已弃用。请安装并使用纯粹的 joblib


1
如何在加载模型数据后进行预测 - Amarnath Reddy Surapureddy
使用sc.transform(X)将缩放应用于新数据。 - Frank_Coumans

20

或者如果您喜欢使用pickle:

import pickle
with open('file/path/scaler.pkl','wb') as f:
    pickle.dump(sc, f)
with open('file/path/scaler.pkl','rb') as f:
    sc = pickle.load(f)

2
这应该是被接受的答案。虽然,我更喜欢使用 with open().. 而不是依赖 gc 来关闭文件。 - Niko Pasanen

0

你可以简单地记住 mean_ 和 scale_。

  1. 因此,在你拟合(计算均值和标准差)你的 StandardScaler 之后,打印出 mean 和 scale。
scaler = StandardScaler()
X = scaler.fit_transform(X)
print("Scaler mean: ", scaler.mean_)
print("Scaler scale: ", scaler.scale_)

在我的例子中,输出看起来像这样: 缩放器平均值:[9.52058421e-01 -6.98286619e-03 -4.14269899e-01 -1.40126971e-01 -8.17856250e+00 5.50322867e+01] 缩放器比例:[0.6635306 0.29163553 0.65517668 23.05331473 36.66616542 43.53057184]

  1. 当您需要再次使用缩放器进行预测时(scaler1是新的缩放器,以确保不使用旧的缩放器):
scaler1 = StandardScaler()
scaler1.mean_ = np.array([ 9.52058421e-01, -6.98286619e-03, -4.14269899e-01, -1.40126971e-01, -8.17856250e+00, 5.50322867e+01])
scaler1.scale_ = np.array([ 0.6635306, 0.29163553, 0.65517668, 23.05331473, 36.66616542, 43.53057184]) 

# then use it to transform your data
X = scaler1.transform(X)

在我的测试中,结果是相同的。注意:不要忘记在np.array([ ... , ...])中设置逗号。
干杯!

1
请不要这样编码,并向他人推荐。这是一个容易出错的hacky解决方案。 - AlexK
在Python中,当成员变量包含下划线时,表示"不要触碰它"。 - interoception

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接