Keras回调函数出现AttributeError: 'ModelCheckpoint'对象没有属性'_implements_train_batch_hooks'。

6
我正在使用Keras(带有TensorFlow后端)实现神经网络,并且想要在训练期间仅保存在验证集上损失最小的模型。为此,我实例化了一个ModelCheckpoint对象,并在调用fit方法时将其传递给模型。然而,当我这样做时,我会收到以下错误:"AttributeError: 'ModelCheckpoint' object has no attribute '_implements_train_batch_hooks'"。我在网上找到的与我的问题最接近的是这个帖子,其中出现了类似的错误,原因是混合使用了kerastf.keras模块,但这不是我的情况,因为我所有的模块都是从keras导入的。我已经在网上和Keras文档中寻找了一段时间,但找不到任何可以解释这个错误的东西。以下是代码中似乎与问题最相关的部分:
导入的模块:
from keras.models import Sequential
from keras.layers import Embedding, Conv1D, Dense, Dropout, GlobalMaxPool1D, Concatenate
from keras.callbacks import ModelCheckpoint

ModelCheckpoint 实例化,模型编译和调用fit方法:

checkpoint = ModelCheckpoint('../model_best.h5', monitor='val_loss', verbose=1, save_best_only=True, mode='min')

model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

history = model.fit(x_train, y_train, 
                    epochs = 10, batch_size = 64,
                    validation_data = (x_val, y_val),
                    callbacks = [checkpoint])

...这是完整的回溯信息:

Traceback (most recent call last):

  File "/Users/thisuser/thisrepo/classifier.py", line 39, in <module>
    callbacks = [checkpoint])

  File "/Users/thisuser/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 66, in _method_wrapper
    return method(self, *args, **kwargs)

  File "/Users/thisuser/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 826, in fit
    steps=data_handler.inferred_steps)

  File "/Users/thisuser/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 231, in __init__
    cb._implements_train_batch_hooks() for cb in self.callbacks)

  File "/Users/thisuser/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 231, in <genexpr>
    cb._implements_train_batch_hooks() for cb in self.callbacks)

AttributeError: 'ModelCheckpoint' object has no attribute '_implements_train_batch_hooks'

我使用的版本是:
  • Python: 3.7.7
  • Keras: 2.3.0-tf
请问有人知道可能是什么原因导致了这个问题吗?如果需要,我可以稍微修改一下我的代码,以便在此提供所有内容,以便能够再现。非常感谢您的帮助!
2个回答

7

最近我也遇到了这个问题。

我的发现是:最近Keras或TensorFlow版本被开发人员 更新 了,导致了这个问题。

解决方法:由于Keras的开发者要求每个人都切换到tf.keras版本,你需要更改你的代码 import 部分

从:

import keras

To:

import tensorflow.keras as keras

之后一切都对我起作用了。


0

替换: from keras.callbacks import ModelCheckpoint 为: from tensorflow.keras.callbacks import ModelCheckpoint


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接