I built a Keras model that uses a custom layer and saved it to an .h5 file through the ModelCheckpoint callback.
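After training finished, I tried to load the model and got the following error message:
__init__() missing 1 required positional argument: 'pool_size'
Here is the definition of the custom layer and its __init__ method: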
class MyMeanPooling(Layer):
    def __init__(self, pool_size, axis=1, **kwargs):
        self.supports_masking = True
        self.pool_size = pool_size
        self.axis = axis
        self.y_shape = None
        self.y_mask = None
        super(MyMeanPooling, self).__init__(**kwargs)
This is how I add the layer to my model:
x = MyMeanPooling(globalvars.pool_size)(x)
This is how I load the model:
from keras.models import load_model
model = load_model(model_path, custom_objects={'MyMeanPooling': MyMeanPooling})
Here is the full error message:
Traceback (most recent call last):
  File "D:/My Projects/Attention_BLSTM/script3.py", line 9, in <module>
    model = load_model(model_path, custom_objects={'MyMeanPooling': MyMeanPooling})
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 419, in load_model
    model = _deserialize_model(f, custom_objects, compile)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 225, in _deserialize_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 458, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\layers\__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\utils\generic_utils.py", line 145, in deserialize_keras_object
    list(custom_objects.items())))
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\network.py", line 1022, in from_config
    process_layer(layer_data)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\network.py", line 1008, in process_layer
    custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\layers\__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\utils\generic_utils.py", line 147, in deserialize_keras_object
    return cls.from_config(config['config'])
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\base_layer.py", line 1109, in from_config
    return cls(**config)
TypeError: __init__() missing 1 required positional argument: 'pool_size'
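The last frame of the traceback, `return cls(**config)`, suggests the layer is rebuilt from its saved config, and that the saved config has no `pool_size` entry. A plausible reading is that the layer never reports its constructor arguments because it does not override `get_config`. The sketch below reproduces that mechanism in plain Python so it runs without Keras. The `Layer` class here is a simplified stand-in written for illustration, not the real Keras class, and `MyMeanPoolingWithConfig` and `round_trip` are hypothetical names.

```python
# Simplified stand-in for the Keras base layer (an assumption for
# illustration only; the real class does much more).
class Layer:
    def __init__(self, name=None, **kwargs):
        self.name = name or self.__class__.__name__.lower()

    def get_config(self):
        # The base class only knows about its own arguments.
        return {'name': self.name}

    @classmethod
    def from_config(cls, config):
        # Same call as the last frame of the traceback.
        return cls(**config)


class MyMeanPooling(Layer):
    # Mirrors the layer from the question: no get_config override.
    def __init__(self, pool_size, axis=1, **kwargs):
        self.pool_size = pool_size
        self.axis = axis
        super(MyMeanPooling, self).__init__(**kwargs)


class MyMeanPoolingWithConfig(MyMeanPooling):
    # Hypothetical variant that reports its constructor arguments.
    def get_config(self):
        config = super(MyMeanPoolingWithConfig, self).get_config()
        config.update({'pool_size': self.pool_size, 'axis': self.axis})
        return config


def round_trip(layer):
    # Save the config and rebuild the layer from it, as loading would.
    return type(layer).from_config(layer.get_config())


# Without get_config, the saved config lacks pool_size and rebuilding fails.
try:
    round_trip(MyMeanPooling(3))
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# With get_config, the constructor arguments survive the round trip.
restored = round_trip(MyMeanPoolingWithConfig(3, axis=2))
```

If the same reasoning holds for the real layer, adding an equivalent `get_config` method to `MyMeanPooling` should let `load_model` rebuild it, though the model would probably have to be saved again so that the .h5 file contains the new config entries.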