如何正确从.h5文件中加载带有自定义层的Keras模型?

11

我使用自定义层建立了一个Keras模型,并通过回调函数ModelCheckPoint将其保存到.h5文件中。 训练结束后,当我尝试加载该模型时,出现以下错误消息:

__init__() missing 1 required positional argument: 'pool_size'
这是自定义层及其__init__方法的定义:
class MyMeanPooling(Layer):
    def __init__(self, pool_size, axis=1, **kwargs):
        self.supports_masking = True
        self.pool_size = pool_size
        self.axis = axis
        self.y_shape = None
        self.y_mask = None
        super(MyMeanPooling, self).__init__(**kwargs)

这是我如何将这个层添加到我的模型中:

x = MyMeanPooling(globalvars.pool_size)(x)

这是我加载模型的方式:

from keras.models import load_model

model = load_model(model_path, custom_objects={'MyMeanPooling': MyMeanPooling})

以下是完整的错误信息:

Traceback (most recent call last):
  File "D:/My Projects/Attention_BLSTM/script3.py", line 9, in <module>
    model = load_model(model_path, custom_objects={'MyMeanPooling': MyMeanPooling})
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 419, in load_model
    model = _deserialize_model(f, custom_objects, compile)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 225, in _deserialize_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\saving.py", line 458, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\layers\__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\utils\generic_utils.py", line 145, in deserialize_keras_object
    list(custom_objects.items())))
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\network.py", line 1022, in from_config
    process_layer(layer_data)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\network.py", line 1008, in process_layer
    custom_objects=custom_objects)
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\layers\__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\utils\generic_utils.py", line 147, in deserialize_keras_object
    return cls.from_config(config['config'])
  File "D:\ProgramData\Anaconda3\envs\tf\lib\site-packages\keras\engine\base_layer.py", line 1109, in from_config
    return cls(**config)
TypeError: __init__() missing 1 required positional argument: 'pool_size'

你在你的Layer子类中实现了哪些方法? - Dr. Snoopy
这是因为Keras调用了您的层的构造函数,但它需要一个位置参数,即“pool_size”。(Keras不提供此参数) - Theophile Champion
3个回答

10

实际上,我认为您无法加载此模型。

最可能的问题是您没有在层中实现get_config()方法。此方法返回一个应该被保存的配置值字典:

def get_config(self):
    config = {'pool_size': self.pool_size,
              'axis': self.axis}
    base_config = super(MyMeanPooling, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

在将此方法添加到您的层后,您必须重新训练模型,因为先前保存的模型没有将此层的配置保存其中。这就是为什么您无法加载它,需要在进行此更改后重新训练。


@waleema 当然,不用谢。但是你应该根据问题是否解决或对你有用来投票和/或接受问题。 - Dr. Snoopy
我已经为你投票了,但似乎我的投票并没有改变公开显示的帖子分数,因为我的声望不到15,但系统告诉我我的投票已记录。再次感谢! - liang
@waleema,你所做的是接受答案,这与赞/踩投票是分开的,但没关系,我们只是想让你知道SO系统的工作方式 :) - Dr. Snoopy
是的,我接受了你的答案,并且也投了票给你。 - liang

2
以下是“LiamHe在2017年9月27日评论”以下问题的答案的翻译:https://github.com/keras-team/keras/issues/4871
我今天遇到了同样的问题:** TypeError:init()缺少1个必需的位置参数**。这里是我解决问题的方法:(Keras 2.0.2)
  1. 给该层的位置参数一些默认值
  2. 覆盖具有某些内容的get_config函数的层
将原始答案翻译为“最初的回答”。
def get_config(self):
    config = super().get_config()
    config['pool_size'] = # say self._pool_size  if you store the argument in __init__
    return config

在加载模型时,为custom_objects添加层类。"Original Answer"翻译成"最初的回答"。

非常感谢!你的回答很有帮助。 - liang

0

如果您没有足够的时间按照Matias Valdenegro的解决方案重新训练模型,您可以像以下代码一样在MyMeanPooling类中设置pool_size的默认值。请注意,在训练模型时,pool_size的值应与此处保持一致。然后您就可以加载模型了。

class MyMeanPooling(Layer):
    def __init__(self, pool_size, axis=1, **kwargs):
        self.supports_masking = True
        self.pool_size = 2  # The value should be consistent with the value while training the model
        self.axis = axis
        self.y_shape = None
        self.y_mask = None
        super(MyMeanPooling, self).__init__(**kwargs)

参考:https://www.jianshu.com/p/e97112c34e43


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接