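Recently I have been tuning the hyperparameters of a Keras model (TensorFlow backend) with grid-search cross-validation (sklearn GridSearchCV). Once tuning finished, I tried to save the GridSearchCV object for later use, but the save failed.
The hyperparameter tuning step is as follows: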
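import numpy as np
from keras.callbacks import History
from keras.wrappers.scikit_learn import KerasRegressor
from sklearn.model_selection import GridSearchCV, train_test_split

# Hold out 15% of the data as a validation set for Keras's validation_data
x_train, x_val, y_train, y_val = train_test_split(NN_input, NN_target, train_size=0.85, random_state=4)
history = History()
kfold = 10  # number of cross-validation folds
regressor = KerasRegressor(build_fn=create_keras_model, epochs=100, batch_size=1000, verbose=1)

# Hyperparameter grid; each key must be an argument of create_keras_model
neurons = np.arange(10, 101, 10)
hidden_layers = [1, 2]
optimizer = ['adam', 'sgd']
activation = ['relu']
dropout = [0.1]
parameters = dict(neurons=neurons,
                  hidden_layers=hidden_layers,
                  optimizer=optimizer,
                  activation=activation,
                  dropout=dropout)
gs = GridSearchCV(estimator=regressor,
                  param_grid=parameters,
                  scoring='neg_mean_squared_error',  # 'mean_squared_error' is the old, removed name
                  n_jobs=1,
                  cv=kfold,
                  verbose=3,
                  return_train_score=True)
# Extra keyword arguments are forwarded to the Keras model's fit()
grid_result = gs.fit(NN_input,
                     NN_target,
                     callbacks=[history],
                     verbose=1,
                     validation_data=(x_val, y_val))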
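For a sense of how much work this search does: GridSearchCV tries every combination in the grid (scikit-learn builds them with ParameterGrid), so the grid above yields 10 × 2 × 2 × 1 × 1 = 40 candidates, each trained once per fold, plus one final refit with the best parameters when refit=True (the default). The stdlib-only sketch below mirrors that Cartesian product with itertools.product rather than calling ParameterGrid itself; expand_grid is a hypothetical helper name.

```python
from itertools import product

# Same grid as above; neurons mirrors np.arange(10, 101, 10).
parameters = dict(
    neurons=list(range(10, 101, 10)),
    hidden_layers=[1, 2],
    optimizer=["adam", "sgd"],
    activation=["relu"],
    dropout=[0.1],
)


def expand_grid(grid):
    """Hypothetical helper: return every parameter combination as a dict."""
    keys = sorted(grid)  # fixed key order so the output is reproducible
    return [dict(zip(keys, values)) for values in product(*(grid[k] for k in keys))]


candidates = expand_grid(parameters)
kfold = 10
n_fits = len(candidates) * kfold + 1  # +1 for the refit with the best params (refit=True)
print(len(candidates), n_fits)  # prints: 40 401
```

With epochs=100, each of those fits is a full 100-epoch Keras training run, which is why this search takes a while.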
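Note: the create_keras_model function builds and compiles a Keras Sequential model.
After cross-validation finished, I tried to save the grid search object (gs) with the following code: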
import joblib  # sklearn.externals.joblib is deprecated and was removed in scikit-learn 0.23
joblib.dump(gs, 'GS_obj.pkl')
This is the error I get:
TypeError: can't pickle _thread.RLock objects
What could be causing this error?
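A plausible cause, not verified against this exact setup: with the default refit=True, gs keeps the refitted KerasRegressor as gs.best_estimator_, and that wrapper holds the live Keras model, whose TensorFlow graph/session contains threading locks. joblib.dump is built on pickle, and pickle cannot serialize a lock, whereas a fitted MLPRegressor is plain Python and NumPy data, which would be consistent with the MLPRegressor case in the P.S. The standard-library sketch below reproduces the same TypeError; the class ModelWrapper is hypothetical and only stands in for such an object.

```python
import pickle
import threading


class ModelWrapper:
    """Hypothetical stand-in for an estimator holding a live framework object."""

    def __init__(self):
        self.weights = [0.1, 0.2, 0.3]  # plain data: picklable on its own
        self._lock = threading.RLock()  # like the locks inside a TF graph/session


def try_pickle(obj):
    """Return None if obj pickles cleanly, otherwise the TypeError message."""
    try:
        pickle.dumps(obj)
        return None
    except TypeError as exc:
        return str(exc)


# The exact wording differs by Python version, but it names _thread.RLock.
print(try_pickle(ModelWrapper()))
```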
Thanks!
P.S. joblib.dump works fine for saving GridSearchCV objects that were used to train sklearn MLPRegressors.
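One common workaround, sketched here under assumptions: persist only the plain-data results of the search (best_params_, best_score_, cv_results_) with pickle or joblib, and save the Keras model separately with Keras's own saving API. The helper names save_search_summary and load_search_summary are hypothetical; the sketch uses the standard-library pickle in place of joblib.dump and a SimpleNamespace stand-in for the fitted gs (with dummy values, not real results) so that it runs without Keras.

```python
import os
import pickle
import tempfile
import threading
from types import SimpleNamespace

# Attributes of a fitted GridSearchCV that hold plain Python/NumPy data.
SUMMARY_ATTRS = ("best_params_", "best_score_", "cv_results_")


def save_search_summary(search, path):
    """Hypothetical helper: pickle only the plain-data results of a search."""
    summary = {name: getattr(search, name) for name in SUMMARY_ATTRS}
    with open(path, "wb") as fh:
        pickle.dump(summary, fh)


def load_search_summary(path):
    """Hypothetical helper: read back the dict written by save_search_summary."""
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Stand-in for a fitted gs; all values are dummies, not real results.
fake_gs = SimpleNamespace(
    best_params_={"neurons": 50, "hidden_layers": 1, "optimizer": "adam"},
    best_score_=-1.0,
    cv_results_={"mean_test_score": [-2.0, -1.0]},
    best_estimator_=threading.RLock(),  # plays the role of the unpicklable Keras model
)

path = os.path.join(tempfile.mkdtemp(), "gs_summary.pkl")
save_search_summary(fake_gs, path)  # succeeds: the lock is never pickled
print(load_search_summary(path)["best_params_"])
```

With the real objects, the equivalent would be save_search_summary(gs, 'GS_summary.pkl') plus gs.best_estimator_.model.save('best_model.h5') (KerasRegressor typically stores the trained model in its .model attribute after fit), with the model reloaded later via keras.models.load_model('best_model.h5').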