Keras LSTM 保存后如何继续训练

3
我正在开发一个LSTM模型,希望能够保存模型并在以后随着数据的积累继续训练。我的问题是,当我保存模型并在下一次运行脚本时重新加载它时,预测结果完全错误,只是模仿了我输入的数据。
以下是模型的初始化代码:
# create and fit the LSTM network
if retrain == 1:
    print "Creating a newly retrained network."
    model = Sequential()
    model.add(LSTM(inputDimension, input_shape=(1, inputDimension)))
    model.add(Dense(inputDimension, activation='relu'))
    model.compile(loss='mean_squared_error', optimizer='adam')
    model.fit(trainX, trainY, epochs=epochs, batch_size=batch_size, verbose=2)
    model.save("model.{}.h5".format(interval))
else:
    print "Using an existing network."
    model = load_model("model.{}.h5".format(interval))
    model.compile(loss='mean_squared_error', optimizer='adam')
    model.fit(trainX, trainY, epochs=epochs, batch_size=batch_size, verbose=2)
    model.save("model.{}.h5".format(interval))
    del model
    model = load_model("model.{}.h5".format(interval))
    model.compile(loss='mean_squared_error', optimizer='adam')

当retrain设置为1时,第一个数据集约有10,000个条目,大约有3k个epoch和5%的批量大小。第二个数据集是单个条目数据,即一行,同样具有3k个epochs和batch_size = 1。

已解决

我错误地重新加载了缩放器:

scaler = joblib.load('scaler.{}.data'.format(interval))
dataset = scaler.fit_transform(dataset)

正确:

scaler = joblib.load('scaler.{}.data'.format(interval))
dataset = scaler.transform(dataset)

fit_transform重新计算了经缩放后数值的乘数,这意味着与原始数据存在偏差。

1个回答

1

来自keras模型API的功能model.fit():

initial_epoch:整数。训练开始的时期(对于恢复以前的训练运行很有用)。

设置此参数可能会解决您的问题。

我认为问题的根源是Adam的自适应学习率。在训练过程中,学习率自然下降以进行更好地微调模型。当您仅使用一个样本重新训练模型时,由于重置了学习率,权重更新可能会太大,这可能会完全破坏以前的权重。

如果initial_epoch不好,则尝试以较低的学习率开始第二次训练。


这是由于缩放器引起的可能吗? 我对数据集执行 scaler.fit_transform,在第一个数据集中,数据的变化范围从 6000-8000 不等,但是由于第二个数据集只有单个值,因此缩放器会专门为该值缩放数据,而不是为第一个数据集的相同范围缩放。 - Toomas-Siim Teresk
这也可能会有问题,但似乎与此无关。通常情况下,即使数据被错误缩放,如果您只进行一次权重更新,它也不应破坏您的准确性。对于缩放器问题:您能否将缩放器对象进行pickle并在第二次运行时调用而无需进行fit操作? - dennis-w
没有使用fit函数,预测看起来还不错,也尝试了initial_epochs和0.001的学习率,但都没有帮助。奇怪的是,当我加载模型并显示学习率时,它自动变成了0。 - Toomas-Siim Teresk
如果我改变数据集的大小,那么训练就没有问题,预测会正常工作,无论是否进行训练。 - Toomas-Siim Teresk
我刚刚看到另一件事:为什么在load_model之后还要编译模型?我不认为这是必要的。也许这就是你问题的原因。 - dennis-w
嗯,可能是的。但问题已经解决了,感谢你的想法! - Toomas-Siim Teresk

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接