批量归一化 - 在 TensorFlow 中提取运行均值和运行方差

3
我正在尝试查看通过GCMLE(saved_model.pb,assets/*和variables/*)导出的已训练TensorFlow模型的运行平均值和方差。这些值在图中保存在哪里?我可以从tf.GraphKeys.TRAINABLE_VARIABLES访问gamma/beta值,但我无法在任何tf.GraphKeys.MODEL_VARIABLES中找到运行时平均值和方差。运行平均值和方差是否存储在其他地方?
我知道在测试时(即Modes.EVAL),运行平均值和方差用于归一化传入的数据,然后使用gamma和beta对归一化数据进行缩放和移位。我正在尝试查看推理时间需要的所有变量,但是我找不到运行平均值和方差。它们只在测试时而不是推理时间(Modes.PREDICT)使用吗?如果是这样,那就解释了为什么我在导出的模型中找不到它们,但我预计它们会在那里。
根据tf.GraphKeys,我尝试了其他方法,如tf.GraphKeys.MOVING_AVERAGE_VARIABLES,但它们也是空的。我还在批处理规范的文档中看到了这行“注意:在训练时,需要更新移动均值和移动方差。默认情况下,更新操作放置在tf.GraphKeys.UPDATE_OPS中,因此它们需要作为train_op的依赖项添加。”所以我尝试查看我保存的模型中的tf.GraphKeys.UPDATE_OPS,它们包含一个assign op batch_normalization/AssignMovingAvg:0,但仍不清楚我从哪里获取这个值。
2个回答

1
看起来移动平均值和移动方差存储在tf.GraphKeys.GLOBAL_VARIABLES中,似乎之所以在MODEL_VARIABLES中没有显示任何内容是因为需要使用tf.contrib.framework.local_variable.

0
除了 #reese0106 的回答之外,
如果您想要删除 BatchNorm 中的 moving_mean 和 moving_variance,
您可以按以下名称对它们进行索引。
vars = tf.global_variables() # shows every variable being used.
vars_moving_mean_variance = []
for var in vars:
    if ("moving_mean" in var.name) or ("moving_variance" in var.name):
        vars_moving_mean_variance.append(var)

print(vars_moving_mean_variance)


附言:感谢您的问题和答案。我也解决了自己的问题。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接