TypeError: __init__() got an unexpected keyword argument 'name' when loading a model with a custom layer


I made a custom layer in Keras to reshape the output of a CNN before feeding it to a ConvLSTM2D layer:

import tensorflow as tf
from tensorflow.keras.layers import Layer

class TemporalReshape(Layer):
    def __init__(self, batch_size, num_patches):
        super(TemporalReshape, self).__init__()
        self.batch_size = batch_size
        self.num_patches = num_patches

    def call(self, inputs):
        # Regroup the CNN output into (batch_size, num_patches, ...) so ConvLSTM2D sees a time axis
        nshape = (self.batch_size, self.num_patches) + inputs.shape[1:]
        return tf.reshape(inputs, nshape)

    def get_config(self):
        config = super().get_config().copy()
        config.update({'batch_size': self.batch_size, 'num_patches': self.num_patches})
        return config

When I try to load the best model with

model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape})

I get the following error:

TypeError                                 Traceback (most recent call last)
<ipython-input-83-40b46da33e91> in <module>()
----> 1 model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape})


/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile, options)
    180     if (h5py is not None and (
    181         isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
--> 182       return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
    183 
    184     filepath = path_to_string(filepath)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/hdf5_format.py in load_model_from_hdf5(filepath, custom_objects, compile)
    176     model_config = json.loads(model_config.decode('utf-8'))
    177     model = model_config_lib.model_from_config(model_config,
--> 178                                                custom_objects=custom_objects)
    179 
    180     # set weights

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/model_config.py in model_from_config(config, custom_objects)
     53                     '`Sequential.from_config(config)`?')
     54   from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
---> 55   return deserialize(config, custom_objects=custom_objects)
     56 
     57 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
    173       module_objects=LOCAL.ALL_OBJECTS,
    174       custom_objects=custom_objects,
--> 175       printable_module_name='layer')

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    356             custom_objects=dict(
    357                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 358                 list(custom_objects.items())))
    359       with CustomObjectScope(custom_objects):
    360         return cls.from_config(cls_config)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in from_config(cls, config, custom_objects)
    615     """
    616     input_tensors, output_tensors, created_layers = reconstruct_from_config(
--> 617         config, custom_objects)
    618     model = cls(inputs=input_tensors, outputs=output_tensors,
    619                 name=config.get('name'))

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
   1202   # First, we create all layers and enqueue nodes to be processed
   1203   for layer_data in config['layers']:
-> 1204     process_layer(layer_data)
   1205   # Then we process nodes in order of layer depth.
   1206   # Nodes that cannot yet be processed (if the inbound node

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in process_layer(layer_data)
   1184       from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
   1185 
-> 1186       layer = deserialize_layer(layer_data, custom_objects=custom_objects)
   1187       created_layers[layer_name] = layer
   1188 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
    173       module_objects=LOCAL.ALL_OBJECTS,
    174       custom_objects=custom_objects,
--> 175       printable_module_name='layer')

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    358                 list(custom_objects.items())))
    359       with CustomObjectScope(custom_objects):
--> 360         return cls.from_config(cls_config)
    361     else:
    362       # Then `cls` may be a function returning a class.

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in from_config(cls, config)
    695         A layer instance.
    696     """
--> 697     return cls(**config)
    698 
    699   def compute_output_shape(self, input_shape):

TypeError: __init__() got an unexpected keyword argument 'name'

When building the model, I used the custom layer like this:

x = TemporalReshape(batch_size = 8, num_patches = 16)(x)

What is causing the error, and how can I load the model without it?
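To see why name is the culprit, here is a minimal pure-Python sketch of the mechanism. FakeLayer is a hypothetical, simplified stand-in for keras.layers.Layer, not the real class: like the Keras base layer, its get_config() always includes generic keys such as name, and its from_config() rebuilds the layer with cls(**config), the same call shown in the base_layer.py frame of the traceback above. With the question's __init__ signature, the rebuild fails with the same TypeError.

```python
# Hypothetical, simplified stand-in for keras.layers.Layer (NOT the real Keras class).
# It mimics two behaviours: the base config carries generic keys such as 'name',
# and from_config() rebuilds the layer with cls(**config).
class FakeLayer:
    def __init__(self, name=None, trainable=True, dtype="float32"):
        self.name = name or type(self).__name__.lower()
        self.trainable = trainable
        self.dtype = dtype

    def get_config(self):
        return {"name": self.name, "trainable": self.trainable, "dtype": self.dtype}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class TemporalReshape(FakeLayer):
    # Same signature as in the question: no **kwargs
    def __init__(self, batch_size, num_patches):
        super().__init__()
        self.batch_size = batch_size
        self.num_patches = num_patches

    def get_config(self):
        config = super().get_config()
        config.update({"batch_size": self.batch_size, "num_patches": self.num_patches})
        return config


layer = TemporalReshape(batch_size=8, num_patches=16)
saved_config = layer.get_config()  # in Keras, this dict is what gets stored in the saved model
try:
    TemporalReshape.from_config(saved_config)
    error = None
except TypeError as exc:
    error = str(exc)

print(error)
```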


What happens if you add **kwargs to __init__? - Nicolas Gervais
@NicolasGervais Post your comment as an answer and I'll accept it. What you suggested in the comment above worked. Thanks a lot! def __init__(self,batch_size,num_patches,**kwargs): - Siladittya
That's interesting. - Nicolas Gervais
2 Answers


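Based solely on the error message, I'd suggest adding **kwargs to __init__. It collects any extra keyword arguments you did not list explicitly, such as the name that Keras passes back when it rebuilds the layer from its saved config.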

def __init__(self, batch_size, num_patches, **kwargs):
    super(TemporalReshape, self).__init__(**kwargs)  # <--- required: forward kwargs to the base Layer, thanks https://stackoverflow.com/users/349130/dr-snoopy
    self.batch_size = batch_size
    self.num_patches = num_patches
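The fix from this answer can be checked against the same kind of mock. Again, FakeLayer is a hypothetical, simplified stand-in for the Keras base layer, not the real API; only the TemporalReshape constructor mirrors the answer. Because the generic keys land in **kwargs and are forwarded to the base class, the saved config round-trips cleanly, including the layer name.

```python
# Hypothetical, simplified stand-in for keras.layers.Layer (NOT the real Keras class).
class FakeLayer:
    def __init__(self, name=None, trainable=True, dtype="float32"):
        self.name = name or type(self).__name__.lower()
        self.trainable = trainable
        self.dtype = dtype

    def get_config(self):
        return {"name": self.name, "trainable": self.trainable, "dtype": self.dtype}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class TemporalReshape(FakeLayer):
    def __init__(self, batch_size, num_patches, **kwargs):
        # Accept the generic keys (name, trainable, dtype) and hand them to the base class
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.num_patches = num_patches

    def get_config(self):
        config = super().get_config()
        config.update({"batch_size": self.batch_size, "num_patches": self.num_patches})
        return config


layer = TemporalReshape(batch_size=8, num_patches=16, name="temporal_reshape_1")
restored = TemporalReshape.from_config(layer.get_config())
print(restored.name, restored.batch_size, restored.num_patches)
```

Running it prints temporal_reshape_1 8 16.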

This is correct, but you're missing one key thing: the kwargs need to be passed on to the parent class's __init__. - Dr. Snoopy
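Like this? super(TemporalReshape,self).__init__(**kwargs) - Nicolas Gervais
Yes, that's what I meant. - Dr. Snoopy
But I'm not getting any error even without that. Thanks for the suggestion anyway. - Siladittya
That's because the missing kwargs have default values. So you run the risk of not reconstructing exactly what was serialized. - wessel
After adding this, I get an error: missing 1 required positional argument. - stic-lab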
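wessel's point about default values can be illustrated with the same hypothetical FakeLayer stand-in (not the real Keras class). Both classes below accept **kwargs, so loading never raises, but DropsKwargs never forwards them to the base class. A layer frozen with trainable = False after construction is saved with trainable False, yet comes back trainable, because the base class silently falls back to its default.

```python
# Hypothetical, simplified stand-in for keras.layers.Layer (NOT the real Keras class).
class FakeLayer:
    def __init__(self, name=None, trainable=True, dtype="float32"):
        self.name = name or type(self).__name__.lower()
        self.trainable = trainable
        self.dtype = dtype

    def get_config(self):
        return {"name": self.name, "trainable": self.trainable, "dtype": self.dtype}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class ReshapeConfigBase(FakeLayer):
    # Shared get_config() that also saves the custom constructor arguments
    def get_config(self):
        config = super().get_config()
        config.update({"batch_size": self.batch_size, "num_patches": self.num_patches})
        return config


class ForwardsKwargs(ReshapeConfigBase):
    def __init__(self, batch_size, num_patches, **kwargs):
        super().__init__(**kwargs)  # name/trainable/dtype reach the base class
        self.batch_size = batch_size
        self.num_patches = num_patches


class DropsKwargs(ReshapeConfigBase):
    def __init__(self, batch_size, num_patches, **kwargs):
        super().__init__()  # kwargs accepted but silently discarded
        self.batch_size = batch_size
        self.num_patches = num_patches


def round_trip(layer_cls):
    layer = layer_cls(batch_size=8, num_patches=16)
    layer.trainable = False  # e.g. freezing the layer after building it
    saved = layer.get_config()
    return saved, layer_cls.from_config(saved)


saved_f, restored_f = round_trip(ForwardsKwargs)
saved_d, restored_d = round_trip(DropsKwargs)
print(restored_f.trainable, restored_d.trainable)
```

This prints False True: only the version that forwards kwargs restores what was serialized.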


Add a **kwargs parameter to the __init__() function.

Error message: "TypeError: __init__() missing 3 required positional arguments: 'batch_size', 'num_patches'"


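This is not an answer, and the most recent edit doesn't match what the original answerer said. Moreover, the edit is an exact copy of the other answer. - Nicolas Gervais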
