问题
Tensorflow 包括函数 tf.train.ExponentialMovingAverage
,它允许我们对参数应用移动平均值,这对于稳定模型测试非常有帮助。
但是,我发现将其应用到一般模型上有些困难。迄今为止,我最成功的方法(如下所示)是编写一个函数装饰器,然后将整个 NN 放入一个函数中。
然而,这样做有几个缺点。首先,会复制整个图形,其次,需要在一个函数内定义 NN。
有更好的方法吗?
当前实现
def ema_wrapper(is_training, decay=0.99):
"""Use Exponential Moving Average of parameters during testing.
Parameters
----------
is_training : bool or `tf.Tensor` of type bool
EMA is applied if ``is_training`` is False.
decay:
Decay rate for `tf.train.ExponentialMovingAverage`
"""
def function(fun):
@functools.wraps(fun)
def fun_wrapper(*args, **kwargs):
# Regular call
with tf.variable_scope('ema_wrapper', reuse=False) as scope:
result_train = fun(*args, **kwargs)
# Set up exponential moving average
ema = tf.train.ExponentialMovingAverage(decay=decay)
var_class = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope.name)
ema_op = ema.apply(var_class)
# Add to collection so they are updated
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema_op)
# Getter for the variables with EMA applied
def ema_getter(getter, name, *args, **kwargs):
var = getter(name, *args, **kwargs)
ema_var = ema.average(var)
return ema_var if ema_var else var
# Call with EMA applied
with tf.variable_scope('ema_wrapper', reuse=True,
custom_getter=ema_getter):
result_test = fun(*args, **kwargs)
# Return the correct version depending on if we're training or not
return tf.cond(is_training,
lambda: result_train, lambda: result_test)
return fun_wrapper
return function
示例用法:
@ema_wrapper(is_training)
def neural_network(x):
# If is_training is False, we will use an EMA of a instead
a = tf.get_variable('a', [], tf.float32)
return a * x