tf.keras.layers.BatchNormalization的trainable=False似乎无法更新其内部移动均值和方差。

5
我正在尝试找出BatchNormalization层在TensorFlow中的行为方式。我想出了下面这段代码,据我所知应该是一个完全有效的keras模型,但是BatchNormalization的平均值和方差似乎没有被更新。
从文档https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization中可以看到:
在BatchNormalization层的情况下,在该层上设置trainable = False意味着该层随后将在推理模式下运行(这意味着它将使用移动平均值和移动方差来规范化当前批次,而不是使用当前批次的平均值和方差)。
我期望模型每次预测调用时返回不同的值。然而,我看到的是10次返回完全相同的值。有人能解释一下为什么BatchNormalization层不会更新其内部值吗?
import tensorflow as tf
import numpy as np

if __name__ == '__main__':

    np.random.seed(1)
    x = np.random.randn(3, 5) * 5 + 0.3

    bn = tf.keras.layers.BatchNormalization(trainable=False, epsilon=1e-9)
    z = input = tf.keras.layers.Input([5])
    z = bn(z)

    model = tf.keras.Model(inputs=input, outputs=z)

    for i in range(10):
        print(x)
        print(model.predict(x))
        print()

我使用TensorFlow 2.1.0

1个回答

10

好的,我在我的假设中找到了错误。移动平均值在训练期间被更新,而不是我之前认为的推理时期更新。这很有道理,因为在推理期间更新移动平均值可能会导致不稳定的生产模型(例如,高度病态输入样本(如其生成分布与网络训练的分布截然不同的长序列)可能潜在地导致网络偏差并在有效输入样本上表现更差)。

当您微调预训练模型并希望在训练期间冻结网络的某些层时,可训练参数非常有用。因为当您调用model.predict(x)(甚至model(x)model(x, training=False))时,该层自动使用移动平均值而不是批量平均值。

下面的代码清晰地演示了这一点:

import tensorflow as tf
import numpy as np

if __name__ == '__main__':

    np.random.seed(1)
    x = np.random.randn(10, 5) * 5 + 0.3

    z = input = tf.keras.layers.Input([5])
    z = tf.keras.layers.BatchNormalization(trainable=True, epsilon=1e-9, momentum=0.99)(z)

    model = tf.keras.Model(inputs=input, outputs=z)
    
    # a dummy loss function
    model.compile(loss=lambda x, y: (x - y) ** 2)

    # a dummy fit just to update the batchnorm moving averages
    model.fit(x, x, batch_size=3, epochs=10)
    
    # first predict uses the moving averages from training
    pred = model(x).numpy()
    print(pred.mean(axis=0))
    print(pred.var(axis=0))
    print()
    
    # outputs the same thing as previous predict
    pred = model(x).numpy()
    print(pred.mean(axis=0))
    print(pred.var(axis=0))
    print()
    
    # here calling the model with training=True results in update of moving averages
    # furthermore, it uses the batch mean and variance as in training, 
    # so the result is very different
    pred = model(x, training=True).numpy()
    print(pred.mean(axis=0))
    print(pred.var(axis=0))
    print()
    
    # here we see again that the moving averages are used but they differ slightly after
    # the previous call, as expected
    pred = model(x).numpy()
    print(pred.mean(axis=0))
    print(pred.var(axis=0))
    print()

最终,我发现文档(https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization)提到:

  1. 使用带有批归一化的模型进行推理时,通常(但不总是)最好使用累积统计数据而不是小批量统计数据。可以通过在调用模型时传递training=False或使用model.predict来实现。

希望这能帮助未来遇到类似误解的人。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接