首先,您必须为损失命名,以便将其提供给早期停止调用。如果您的估计器中的损失变量命名为“loss”,则该行
copyloss = tf.identity(loss, name="loss")
就在它下面可以起作用。
接下来,使用以下代码创建一个 hook。
class EarlyStopping(tf.train.SessionRunHook):
def __init__(self,smoothing=.997,tolerance=.03):
self.lowestloss=float("inf")
self.currentsmoothedloss=-1
self.tolerance=tolerance
self.smoothing=smoothing
def before_run(self, run_context):
graph = ops.get_default_graph()
self.lossop=graph.get_operation_by_name("loss")
self.element = self.lossop.outputs[0]
return tf.train.SessionRunArgs([self.element])
def after_run(self, run_context, run_values):
loss=run_values.results[0]
if(self.currentsmoothedloss<0):
self.currentsmoothedloss=loss*1.5
self.currentsmoothedloss=self.currentsmoothedloss*self.smoothing+loss*(1-self.smoothing)
if(self.currentsmoothedloss<self.lowestloss):
self.lowestloss=self.currentsmoothedloss
if(self.currentsmoothedloss>self.lowestloss+self.tolerance):
run_context.request_stop()
print("REQUESTED_STOP")
raise ValueError('Model Stopping because loss is increasing from EarlyStopping hook')
这段代码比较指数平滑的损失验证与其最低值,如果超过允许误差,就停止训练。如果停止得太早,增加允许误差和平滑处理会使它稍后停止。保持平滑处理小于1,否则它永远不会停止。如果您想根据其他条件停止,请将after_run中的逻辑替换为其他内容。现在,将此钩子添加到评估规范中。您的代码应该类似于这样:
eval_spec=tf.estimator.EvalSpec(input_fn=lambda:eval_input_fn(batchsize),steps=100,hooks=[EarlyStopping()])
重要提示:在train_and_evaluate调用中,函数run_context.request_stop()失效,无法停止训练。因此,我引发了一个值错误来停止训练。所以你需要像这样将train_and_evaluate调用包装在try catch块中:
try:
tf.estimator.train_and_evaluate(classifier,train_spec,eval_spec)
except ValueError as e:
print("training stopped")
如果您不这样做,在训练停止时代码将崩溃并显示错误信息。
>>> tf.contrib.estimator.stop_if_no_decrease_hook Traceback (most recent call last): File "<stdin>", line 1, in <module> AttributeError: module 'tensorflow.contrib.estimator' has no attribute 'stop_if_no_decrease_hook'
- Eric H.NotFoundError: Key signal_early_stopping/STOP not found in checkpoint [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_INT64, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
。 - Glrsrun_every_secs=None, run_every_steps=50
进行设置,否则您的验证错误只会每隔大约10分钟才被考虑进去。 - Jill-Jênn Vie