Keras Sequential API 中的 training = False

7

当使用Keras Functional API调用预训练模型时,我们可以传递training = False参数,如教程所示。

如何在Keras Sequential API中实现相同的效果?

以下是我正试图使用Sequential API复制的代码:

inputs = tf.keras.Input( shape = ( TARGET_SIZE[0], TARGET_SIZE[1], 3 ) )
base_model = Xception( include_top = False, pooling = 'avg' )
base_model.trainable = False
x = base_model( inputs, training = False ) 
x = Dense( 512, activation = 'relu' )( x )
x = Dense( 256, activation = 'relu' )( x )
x = Dense( 128, activation = 'relu' )( x )
outputs = Dense( 6, activation = 'softmax' )( x )

以下是使用顺序API实现整个模型的代码,training = False没有使用,如下所示:
model = Sequential()
model.add( Xception( include_top = False, pooling = 'avg', input_shape = ( TARGET_SIZE[0], TARGET_SIZE[1], 3 ) ) )
model.add( Dense( units = 512, activation = 'relu' ) )
model.add( Dense( units = 256, activation = 'relu' ) )
model.add( Dense( units = 128, activation = 'relu' ) )
model.add( Dense( 6, activation = 'softmax' ) )

但是,我无法在其中加入training = False参数。


1
请将您的代码添加到问题中 - Andrey
@Andrey 代码已添加。@NicolasGervais 查询涉及到 training 属性而不是 trainable - Mrityu
@Mrityu training=False 应该是默认值。 - Andrey
@Mrityu training 参数通常用于使层在训练和评估中表现不同。据我所知,对于同一模型来说,它不应该有所不同。例如,通过添加 Dropout()training=False,实际上只是从模型中删除了该层。你能举个例子吗?什么情况下使用 training=False 会有意义呢? - Andrey
@Andrey 请参考此链接,并查看其中关于在猫狗数据集上进行端到端图像分类和微调示例中training = False的原因说明。 - Mrityu
显示剩余2条评论
2个回答

0
你可以尝试以下方法,这对我有效:
加载任何基础模型,包括_top = False,然后(例如):
 flat1 = Flatten()(base_model.layers[-1].output, training=False)
    
 class1 = Dense(5, activation='relu')(flat1)
    
 output = Dense(1, activation='sigmoid')(class1) 

 model = Model(inputs=base_model.inputs, outputs=output) 

0

我也曾经问过自己同样的问题。不幸的是,我没有找到一个清晰的解释。这里有关于迁移学习的Tensorflow指南。正如你所看到的,只有在调用模型或层时才能使用training = False。因此,我意识到你需要使用你写的第一段代码。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接