在TensorFlow中的批量归一化

5

我注意到TensorFlow的API中已经有了批量归一化函数。但是,有一件事我不明白,就是如何在训练和测试之间改变过程?

批量归一化在测试和训练期间表现不同。具体来说,在训练期间使用固定的均值和方差。

是否有一些好的示例代码可以参考?我看到了一些,但是由于作用域变量,它们变得令人困惑。


考虑使用高级API(例如tf.contrib.layers)中预定义的层。 - danijar
1个回答

9
您说得对,tf.nn.batch_normalization仅提供了实现批量归一化的基本功能。在训练过程中,您需要添加额外的逻辑来跟踪移动平均值和方差,并在推理时使用训练好的均值和方差。您可以查看这个示例进行一个非常通用的实现,但是一个不使用gamma的快速版本如下:
  beta = tf.Variable(tf.zeros(shape), name='beta')
  moving_mean = tf.Variable(tf.zeros(shape), name='moving_mean',
                                 trainable=False)
  moving_variance = tf.Variable(tf.ones(shape),
                                     name='moving_variance',
                                     trainable=False)
  control_inputs = []
  if is_training:
    mean, variance = tf.nn.moments(image, [0, 1, 2])
    update_moving_mean = moving_averages.assign_moving_average(
        moving_mean, mean, self.decay)
    update_moving_variance = moving_averages.assign_moving_average(
        moving_variance, variance, self.decay)
    control_inputs = [update_moving_mean, update_moving_variance]
  else:
    mean = moving_mean
    variance = moving_variance
  with tf.control_dependencies(control_inputs):
    return tf.nn.batch_normalization(
        image, mean=mean, variance=variance, offset=beta,
        scale=None, variance_epsilon=0.001)

非常感谢。还有一个快速的问题。带有伽马的版本真的更复杂吗?看起来你只需要为其初始化另一个tf.Variable即可,其他代码不应该改变,对吗? - user3358117
是的,您可以遵循我提供的链接中更一般的实现来添加“gamma”。 - keveman

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接