在TF1中以图形模式运行时,我认为我需要通过feeddicts连接
我认为在使用
我可以假设使用函数式API进行训练时,Keras会自动处理这个问题吗?下面是使用函数式API重写的相同模型:
training=True
和training=False
,当我使用函数式API时。在TF2中,正确的做法是什么?我认为在使用
tf.keras.Sequential
时,这会自动处理。例如,我不需要在以下示例中指定training
,来自docs:model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10, activation='softmax')
])
# Model is the full model w/o custom layers
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {:0.4f}, Accuracy {:0.4f}".format(loss, acc))
我可以假设使用函数式API进行训练时,Keras会自动处理这个问题吗?下面是使用函数式API重写的相同模型:
inputs = tf.keras.Input(shape=((28,28,1)), name="input_image")
hid = tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1))(inputs)
hid = tf.keras.layers.MaxPooling2D()(hid)
hid = tf.keras.layers.Flatten()(hid)
hid = tf.keras.layers.Dropout(0.1)(hid)
hid = tf.keras.layers.Dense(64, activation='relu')(hid)
hid = tf.keras.layers.BatchNormalization()(hid)
outputs = tf.keras.layers.Dense(10, activation='softmax')(hid)
model_fn = tf.keras.Model(inputs=inputs, outputs=outputs)
# Model is the full model w/o custom layers
model_fn.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model_fn.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model_fn.evaluate(test_data)
print("Loss {:0.4f}, Accuracy {:0.4f}".format(loss, acc))
我不确定 hid = tf.keras.layers.BatchNormalization()(hid)
是否需要变成 hid = tf.keras.layers.BatchNormalization()(hid, training)
?
这些模型的 colab 可以在 这里 找到。
model_fn()
(tf.keras.Model#call
)的前向传递中设置它,以便BatchNormalization行为正确。我认为我需要子类化模型并显式定义前向传递调用,以便我可以将“training”传递给BN调用,类似于https://www.tensorflow.org/api_docs/python/tf/keras/Model中的示例。我还想知道在使用model_fn.fit()
时是否需要。 - cosentiyestf.keras.Sequential
时这是自动处理的。你确定吗?你有任何可以证明这一点的参考资料吗? - Nerxis