如果您正在使用使用内部计数器的tf.metrics操作,则可以在验证集上使用批处理。以下是一个简化的示例:
model = create_model()
tf.summary.scalar('cost', model.cost_op)
acc_value_op, acc_update_op = tf.metrics.accuracy(labels,predictions)
summary_common = tf.summary.merge_all()
summary_valid = tf.summary.merge([
tf.summary.scalar('accuracy', acc_value_op),
])
with tf.Session() as sess:
train_writer = tf.summary.FileWriter(logs_path + '/train',
sess.graph)
valid_writer = tf.summary.FileWriter(logs_path + '/valid')
在训练过程中,只需使用您的训练写手来撰写常见摘要。
summary = sess.run(summary_common)
train_writer.add_summary(summary, tf.train.global_step(sess, gstep_op))
train_writer.flush()
在每次验证后,使用valid-writer编写两个摘要:
gstep, summaryc, summaryv = sess.run([gstep_op, summary_common, summary_valid])
valid_writer.add_summary(summaryc, gstep)
valid_writer.add_summary(summaryv, gstep)
valid_writer.flush()
在使用tf.metrics时,不要忘记在每个验证步骤之前重置内部计数器(本地变量)。
batch_size=128
。最后一个批次的大小不是128,因此不能简单地应用平均值。 - huangbiubiu