我相信有很多方法可以做到这一点,但我自己尝试了一下并想出了自己的版本。
首先,一个定制的回调函数使得在每个时期结束时能够获取和更新历史记录。在这里我还有一个回调来保存模型。这两个都很方便,因为如果你崩溃或关闭,你可以从最后完成的时期重新开始训练。
class LossHistory(Callback):
def on_epoch_end(self, epoch, logs = None):
new_history = {}
for k, v in logs.items():
new_history[k] = [v]
current_history = loadHist(history_filename)
current_history = appendHist(current_history, new_history)
saveHist(history_filename, current_history)
model_checkpoint = ModelCheckpoint(model_filename, verbose = 0, period = 1)
history_checkpoint = LossHistory()
callbacks_list = [model_checkpoint, history_checkpoint]
其次,以下是一些“辅助”函数,它们能够确切地执行它们所声明的任务。这些函数都从LossHistory()
回调函数中调用。
import json, codecs
def saveHist(path, history):
with codecs.open(path, 'w', encoding='utf-8') as f:
json.dump(history, f, separators=(',', ':'), sort_keys=True, indent=4)
def loadHist(path):
n = {}
if os.path.exists(path):
with codecs.open(path, 'r', encoding='utf-8') as f:
n = json.loads(f.read())
return n
def appendHist(h1, h2):
if h1 == {}:
return h2
else:
dest = {}
for key, value in h1.items():
dest[key] = value + h2[key]
return dest
接下来,您只需要将history_filename
设置为类似于data/model-history.json
的内容,同时将model_filename
设置为类似于data/model.h5
的内容。在训练结束时,为了确保不会破坏您的历史记录(假设您停止并重新开始),并将回调函数插入其中,可以进行最后的微调:
new_history = model.fit(X_train, y_train,
batch_size = batch_size,
nb_epoch = nb_epoch,
validation_data=(X_test, y_test),
callbacks=callbacks_list)
history = appendHist(history, new_history.history)
每当您需要时,history = loadHist(history_filename)
就可以将您的历史记录加载回来。
这个方法的巧妙之处在于使用了JSON和列表,但是我没有办法在不进行迭代转换的情况下使其正常工作。不管怎样,我知道这个方法可行,因为我已经在上面工作了数天。可能在 https://dev59.com/hlgR5IYBdhLWcg3wp-q5#44674337 中使用 pickle.dump
更好,但我不知道那是什么。如果我遗漏了什么或者您无法使其正常工作,请告诉我。
CSVLogger()
回调函数,如此处所述:https://keras.io/callbacks/#csvlogger - swiss_knightfit
返回的 history 对象吗?它包含了.params
属性中的有用信息,我也想保留它。是的,我可以分别保存params
和history
属性,或者将它们组合成一个字典,但我对一种简单的保存整个history
对象的方法感兴趣。 - user3731622