在使用TensorFlow 2.3.0和Keras 2.4.3时,无法加载在Keras 2.1.0(使用TensorFlow 1.3.0)中保存的Keras模型。

8
我正在实现一个带有自定义批次归一化层的Keras模型,该层具有4个权重(beta、gamma、running_mean和running_std)和3个状态变量(r_max、d_max和t):
    self.gamma = self.add_weight(shape = shape, #NK - shape = shape
                                 initializer=self.gamma_init,
                                 regularizer=self.gamma_regularizer,
                                 name='{}_gamma'.format(self.name))
    self.beta = self.add_weight(shape = shape, #NK - shape = shape
                                initializer=self.beta_init,
                                regularizer=self.beta_regularizer,
                                name='{}_beta'.format(self.name))
    self.running_mean = self.add_weight(shape = shape, #NK - shape = shape
                                        initializer='zero',
                                        name='{}_running_mean'.format(self.name),
                                        trainable=False)
    # Note: running_std actually holds the running variance, not the running std.
    self.running_std = self.add_weight(shape = shape, initializer='one',
                                       name='{}_running_std'.format(self.name),
                                       trainable=False)
    self.r_max = K.variable(np.ones((1,)), name='{}_r_max'.format(self.name))

    self.d_max = K.variable(np.zeros((1,)), name='{}_d_max'.format(self.name))

    self.t = K.variable(np.zeros((1,)), name='{}_t'.format(self.name))

当我对模型进行检查点时,只保存了gamma、beta、running_mean和running_std(预期的结果),但是当我尝试加载模型时,却出现了以下错误:

Layer #1 (named "batch_renormalization_1" in the current model) was found to correspond to layer batch_renormalization_1 in the save file. However the new layer batch_renormalization_1 expects 7 weights, but the saved weights have 4 elements. 

看起来模型期望保存的文件中包含所有 7 个权重,即使其中一些是状态变量。

有什么想法可以解决这个问题吗?

编辑:我意识到问题在于该模型是在 Keras 2.1.0(使用 Tensorflow 1.3.0 后端)上训练和保存的,而只有在使用 Keras 2.4.3(使用 Tensorflow 2.3.0 后端)加载模型时才会出现错误。我能够使用 Keras 加载到 2.1.0 的模型。

所以真正的问题是 - Keras/Tensorflow 中发生了什么变化,有没有办法在不收到此错误的情况下加载旧模型?

2个回答

0

你不能用这种方式加载模型,因为keras.models.load_model将会加载已经定义好的配置,而不是自定义的内容。

为了解决这个问题,你应该重新加载模型架构,并尝试从中加载权重:

model = YourModelDeclaration()
model.load_weights("checkpoint/h5file")

当我自定义BatchNormalize时,我遇到了同样的问题,所以我很确定这是加载它的唯一方法。


感谢您的回复,但是load_weights也不起作用。经过一番搜索,我发现错误实际上是由于在不同版本的Keras/Tensorflow上保存并尝试加载而引起的。因此,真正的问题是是否有一种方法可以加载在旧版本的Keras中保存的模型,而不会遇到这个错误。 - Nick Koprowicz
1
显然,你不应该仅在不同版本中保存权重,然后在更新的版本中加载它。但奇怪的是,即使保存整个模型也不能让你来回移动 :D - dtlam26
1
这似乎是一个常见的问题。在这种情况下,我试图使用别人创建的模型,但他们没有提供有关使用哪个版本的Tensorflow/Keras的详细信息。因此,需要进行一些猜测和尝试才能使事情正常运行。对我来说,模型加载应该跨版本工作,否则共享模型将非常困难。 - Nick Koprowicz

0
在Keras中,有两种方法可以保存模型的状态。
您可以调用model.save()model.save_weights()函数。 model.save()保存整个模型,包括权重和梯度。在您的情况下,这种方法将保存4个权重和3个状态变量。您只需使用load_model(“path.h5”)方法即可获取您的模型。 model.save_weights()函数仅保存模型的权重,并且根本不保存结构。这里需要注意的重要事项是Keras检查点回调在幕后使用model.save_weights()方法。如果您希望使用检查点权重,则必须实例化您的模型结构model = customModel(),然后将权重加载到其中model.load_weights("checkpoint.h5")

感谢您的回复,但是load_weights也不起作用。经过一番搜索,我发现错误实际上是由于在不同版本的Keras/Tensorflow上保存并尝试加载而引起的。因此,真正的问题是是否有一种方法可以加载在旧版本的Keras中保存的模型,而不会遇到这个错误。 - Nick Koprowicz
据我所知,在保存和加载模型时,您不能在Keras/tf版本之间切换。 - Vaibhav Mehrotra

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接