我正在使用TensorFlow估计器,明确使用方法tf.estimator.train_and_evaluate()
。
训练中有一个早停挂钩,即tf.contrib.estimator.stop_if_no_decrease_hook
,但我遇到了一个问题,训练损失太过跳动,无法使用早停挂钩。
有谁知道如何使用tf.estimator
基于评估损失进行早停?
我正在使用TensorFlow估计器,明确使用方法tf.estimator.train_and_evaluate()
。
训练中有一个早停挂钩,即tf.contrib.estimator.stop_if_no_decrease_hook
,但我遇到了一个问题,训练损失太过跳动,无法使用早停挂钩。
有谁知道如何使用tf.estimator
基于评估损失进行早停?
estimator = tf.estimator.Estimator(model_fn, model_dir)
os.makedirs(estimator.eval_dir()) # TODO This should not be expected IMO.
early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
estimator,
metric_name='loss',
max_steps_without_decrease=1000,
min_steps=100)
tf.estimator.train_and_evaluate(
estimator,
train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
eval_spec=tf.estimator.EvalSpec(eval_input_fn))
但是如果对你不起作用,最好使用tf.estimator.experimental.stop_if_no_decrease_hook。
例如:
estimator = ...
# Hook to stop training if loss does not decrease in over 100000 steps.
hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 100000)
train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 10000)
,则挂钩将考虑在10k步骤范围内发生的评估。