在Keras中,“model.trainable = False”是什么意思?

6
我希望在Keras中冻结一个预训练的网络。我在文档中找到了base.trainable = False,但我不明白它是如何工作的。通过使用len(model.trainable_weights),我发现我有30个可训练的权重。这怎么可能呢?该网络显示总共可训练参数:16,812,353。在冻结之后,我只剩下4个可训练的权重。也许我不理解参数和权重之间的区别。不幸的是,我是深度学习的初学者。或许有人可以帮帮我。
1个回答

11
一个Keras模型默认是可训练的 - 你有两种方法来冻结所有权重:
  1. model.trainable = False在编译模型之前
  2. for layer in model.layers: layer.trainable = False - 可以在编译前后使用
(1) 必须在编译之前完成,因为Keras将model.trainable视为布尔标志,并在编译时执行(2)。完成上述任一操作后,您应该会看到:
print(model.trainable_weights)
# [] 

关于文档,可能已过时 - 请参考上面链接的源代码,以获取最新信息。


如果一个Model在其layers中包含嵌套的Model,那么我认为选项(2)中的代码实际上会导致这些嵌套模型的选项(1)行为,对吗?在这种情况下,最好有一个函数set_trainable,它通过递归迭代layers,设置层的trainable并再次调用自身以处理嵌套模型。 - Kilian Batzner
@KilianBatzner 的确,递归方法更加安全:if isinstance(layer, Model): set_trainable(layer) - OverLordGoldDragon

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接