"Sequential"对象没有属性"_in_multi_worker_mode"。

9

我尝试使用谷歌Colab资源来保存我的CNN模型权重,但是我遇到了这个错误。我尝试在谷歌上搜索,但是没有什么帮助。

'Sequential'对象没有属性'_in_multi_worker_mode'

我的代码:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, save_weights_only=True, verbose=1)


cnn_model = Sequential()
cnn_model.add(Conv2D(filters = 64, kernel_size = (3,3), activation = "relu", input_shape = Input_shape ))
cnn_model.add(Conv2D(filters = 64, kernel_size = (3,3), activation = "relu"))
cnn_model.add(MaxPooling2D(2,2))
cnn_model.add(Dropout(0.4))

cnn_model = Sequential()
cnn_model.add(Conv2D(filters = 128, kernel_size = (3,3), activation = "relu"))
cnn_model.add(Conv2D(filters = 128, kernel_size = (3,3), activation = "relu"))
cnn_model.add(MaxPooling2D(2,2))
cnn_model.add(Dropout(0.3))


cnn_model.add(Flatten())

cnn_model.add(Dense(units = 512, activation = "relu"))
cnn_model.add(Dense(units = 512, activation = "relu"))

cnn_model.add(Dense(units = 10, activation = "softmax"))

history = cnn_model.fit(X_train, y_train, batch_size = 32,epochs = 1, 
shuffle = True, callbacks = [cp_callback])

堆栈跟踪:

AttributeError                            Traceback (most recent call last)
<ipython-input-19-35c1db9636b7> in <module>()
----> 1 history = cnn_model.fit(X_train, y_train, batch_size = 32,epochs = 1, shuffle = True, callbacks = [cp_callback])

4 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/callbacks.py in on_train_begin(self, logs)
    903   def on_train_begin(self, logs=None):
    904     # pylint: disable=protected-access
--> 905     if self.model._in_multi_worker_mode():
    906       # MultiWorkerTrainingState is used to manage the training state needed
    907       # for preemption-recovery of a worker in multi-worker training.

AttributeError: 'Sequential' object has no attribute '_in_multi_worker_mode'

欢迎来到stackoverflow。请查看如何创建一个最小可复现示例 - o-90
4个回答

17

我最近遇到了同样的问题

而不是,

from tensorflow.keras.callbacks import ModelCheckpoint
from keras.callbacks import ModelCheckpoint

3

检查您的tensorflow版本。您实际上只需要同步它。检查所有导入是否正确。

from keras import ...

或者
from tensorflow.keras import ...

只使用上述其中一种来导入你的Keras。同时使用不同的(两者)可能会导致库之间的冲突。


2

改为

tf.keras.callbacks.ModelCheckpoint

在您的模型构建过程中,您可以使用:
from keras.callbacks import ModelCheckpoint

为了导入ModelCheckpoint,然后在后续代码中只需使用ModelCheckpoint

0
请检查您的tensorflow版本是否与最新版本匹配。在我的情况下,当我将其更新到2.1.0时,错误得到了解决。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接