如何使用tf.estimator.train_and_evaluate实现在评估损失方面进行早停?

4

我正在使用TensorFlow估计器,明确使用方法tf.estimator.train_and_evaluate()。 训练中有一个早停挂钩,即tf.contrib.estimator.stop_if_no_decrease_hook,但我遇到了一个问题,训练损失太过跳动,无法使用早停挂钩。 有谁知道如何使用tf.estimator基于评估损失进行早停?

1个回答

2
您可以按照以下方式使用tf.contrib.estimator.stop_if_no_decrease_hook
您可以使用如下代码:
estimator = tf.estimator.Estimator(model_fn, model_dir)

os.makedirs(estimator.eval_dir())  # TODO This should not be expected IMO.

early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
    estimator,
    metric_name='loss',
    max_steps_without_decrease=1000,
    min_steps=100)

tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
    eval_spec=tf.estimator.EvalSpec(eval_input_fn))

但是如果对你不起作用,最好使用tf.estimator.experimental.stop_if_no_decrease_hook

例如:

estimator = ...
# Hook to stop training if loss does not decrease in over 100000 steps.
hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 100000)
train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
tf.estimator.train_and_evaluate(estimator, train_spec, ...)

早停钩子使用评估结果来决定何时终止训练,但您需要传递要监视的训练步骤数量,并记住在该步骤数中将发生多少次评估。
如果将挂钩设置为hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 10000),则挂钩将考虑在10k步骤范围内发生的评估。
有关文档的更多信息,请参见此处,您可以使用所有早期停止函数,请参考此处

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接