如何在tf.estimator中使用早停法?

22
我正在使用TensorFlow 1.4中的tf.estimatortf.estimator.train_and_evaluate功能非常出色,但我需要提前停止训练。如何添加此功能?我假设这里会有一些tf.train.SessionRunHook。我看到一个旧的contrib包含有ValidationMonitor,它似乎有一个提前停止的功能,但在1.4中似乎已经没有了。或者未来的首选方式是依赖于tf.keras (其中提前停止非常容易) 而不是tf.estimator/tf.layers/tf.data吗?
4个回答

33

好消息!tf.estimator现在在主分支上支持早停功能,看起来它将会出现在1.10版本中。

estimator = tf.estimator.Estimator(model_fn, model_dir)

os.makedirs(estimator.eval_dir())  # TODO This should not be expected IMO.

early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
    estimator,
    metric_name='loss',
    max_steps_without_decrease=1000,
    min_steps=100)

tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
    eval_spec=tf.estimator.EvalSpec(eval_input_fn))

1
这看起来很有前途,但似乎不在r1.9中(我认为今天是稳定版本)>>> tf.contrib.estimator.stop_if_no_decrease_hook Traceback (most recent call last): File "<stdin>", line 1, in <module> AttributeError: module 'tensorflow.contrib.estimator' has no attribute 'stop_if_no_decrease_hook' - Eric H.
1
我尝试了1.10版本,但是出现了以下错误:NotFoundError: Key signal_early_stopping/STOP not found in checkpoint [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_INT64, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]] - Glrs
@TasosGlrs 你确定你没有尝试过添加钩子并从早期运行中创建的现有检查点继续,而该运行没有使用钩子,因此缺少必要的键吗? - oens
2
我也遇到了关于“signal_early_stopping not found”的类似问题。看起来是因为“early_stopping”钩子只能放置在TrainSpec钩子中。如果在EvalSpec钩子中使用,就会出现这个错误。 - LeckieNi
如果您不想遇到意外情况,建议将 run_every_secs=None, run_every_steps=50 进行设置,否则您的验证错误只会每隔大约10分钟才被考虑进去。 - Jill-Jênn Vie
显示剩余3条评论

3

首先,您必须为损失命名,以便将其提供给早期停止调用。如果您的估计器中的损失变量命名为“loss”,则该行

copyloss = tf.identity(loss, name="loss")

就在它下面可以起作用。

接下来,使用以下代码创建一个 hook。

class EarlyStopping(tf.train.SessionRunHook):
    def __init__(self,smoothing=.997,tolerance=.03):
        self.lowestloss=float("inf")
        self.currentsmoothedloss=-1
        self.tolerance=tolerance
        self.smoothing=smoothing
    def before_run(self, run_context):
        graph = ops.get_default_graph()
        #print(graph)
        self.lossop=graph.get_operation_by_name("loss")
        #print(self.lossop)
        #print(self.lossop.outputs)
        self.element = self.lossop.outputs[0]
        #print(self.element)
        return tf.train.SessionRunArgs([self.element])
    def after_run(self, run_context, run_values):
        loss=run_values.results[0]
        #print("loss "+str(loss))
        #print("running average "+str(self.currentsmoothedloss))
        #print("")
        if(self.currentsmoothedloss<0):
            self.currentsmoothedloss=loss*1.5
        self.currentsmoothedloss=self.currentsmoothedloss*self.smoothing+loss*(1-self.smoothing)
        if(self.currentsmoothedloss<self.lowestloss):
            self.lowestloss=self.currentsmoothedloss
        if(self.currentsmoothedloss>self.lowestloss+self.tolerance):
            run_context.request_stop()
            print("REQUESTED_STOP")
            raise ValueError('Model Stopping because loss is increasing from EarlyStopping hook')

这段代码比较指数平滑的损失验证与其最低值,如果超过允许误差,就停止训练。如果停止得太早,增加允许误差和平滑处理会使它稍后停止。保持平滑处理小于1,否则它永远不会停止。如果您想根据其他条件停止,请将after_run中的逻辑替换为其他内容。现在,将此钩子添加到评估规范中。您的代码应该类似于这样:
eval_spec=tf.estimator.EvalSpec(input_fn=lambda:eval_input_fn(batchsize),steps=100,hooks=[EarlyStopping()])#

重要提示:在train_and_evaluate调用中,函数run_context.request_stop()失效,无法停止训练。因此,我引发了一个值错误来停止训练。所以你需要像这样将train_and_evaluate调用包装在try catch块中:

try:
    tf.estimator.train_and_evaluate(classifier,train_spec,eval_spec)
except ValueError as e:
    print("training stopped")

如果您不这样做,在训练停止时代码将崩溃并显示错误信息。


这似乎没有进行早停?如果我正确理解您的代码,您正在监控训练损失而不是验证损失。 - Carl Thomé
1
这个与EvalSpec相关联,因此它正在监控验证损失。如果训练时间足够长,它将进行早期停止。如果停止得不够快,您可能需要降低平滑值到0.99并降低容差。 - user3806120

2

是的,有一个tf.train.StopAtStepHook

这个钩子请求在执行了一定数量的步骤或到达最后一步之后停止。只能指定其中一个选项。

您也可以扩展它并根据步骤结果实现自己的停止策略。

class MyHook(session_run_hook.SessionRunHook):
  ...
  def after_run(self, run_context, run_values):
    if condition:
      run_context.request_stop()

tf.train.StopAtStepHook似乎不能实现早停?但是,我猜我可以自己编写一个钩子来评估验证集,只是我期望它作为TensorFlow 1.4的内置功能。谢谢! - Carl Thomé
@CarlThomé 我明白你的意思。你是对的,tensorflow目前只捆绑了一些简单的会话钩子,并建议使用自己的钩子插入复杂的决策。 - Maxim
2
哪个变量可以帮助我在每一步中捕获 after_run 函数中的损失,以实现早停? - abhishek jha

1
另一种不使用hooks的选择是创建一个tf.contrib.learn.Experiment(即使在contrib中,它似乎也支持新的tf.estimator.Estimator)。然后通过适当自定义的continuous_eval_predicate_fn方法进行训练(显然是实验性的),使用continuous_train_and_eval。根据tensorflow文档,continuous_eval_predicate_fn是一个断言函数,用于确定是否在每次迭代后继续评估。该函数使用上次评估运行的eval_results调用。对于提前停止,使用一个自定义函数,该函数保留当前最佳结果和计数器作为状态,并在达到提前停止条件时返回False。注意:这种方法将使用tensorflow 1.7中弃用的方法(从那个版本开始,所有的tf.contrib.learn都已被弃用:https://www.tensorflow.org/api_docs/python/tf/contrib/learn)。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接