Tensorflow,负KL散度

7

我正在使用变分自编码器(Variational Autoencoder)类型的模型进行工作,我的损失函数的一部分是 KL 散度,它衡量了一个均值为0、方差为1的正态分布与另一个正态分布之间的差异程度。后者的均值和方差由我的模型预测。

我按如下方式定义了损失:

def kl_loss(mean, log_sigma):
    normal=tf.contrib.distributions.MultivariateNormalDiag(tf.zeros(mean.get_shape()),
                                                           tf.ones(log_sigma.get_shape()))
    enc_normal = tf.contrib.distributions.MultivariateNormalDiag(mean,
                                                                     tf.exp(log_sigma),
                                                                     validate_args=True,
                                                                     allow_nan_stats=False,
                                                                     name="encoder_normal")
    kl_div = tf.contrib.distributions.kl_divergence(normal,
                                                    enc_normal,
                                                    allow_nan_stats=False,
                                                    name="kl_divergence")
return kl_div

输入是长度为N的无约束向量。
log_sigma.get_shape() == mean.get_shape()

现在在训练期间,我观察到负的KL散度,在几千次迭代后达到-10的值。您可以在下面看到Tensorboard训练曲线: KL散度曲线 KL散度曲线放大图 现在这对我来说似乎很奇怪,因为在某些条件下KL散度应该是正的。 我知道我们需要“仅当P和Q都总和为1并且如果存在P(i)> 0,则Q(i)> 0时才定义K-L散度。”(参见https://mathoverflow.net/questions/43849/how-to-ensure-the-non-negativity-of-kullback-leibler-divergence-kld-metric-rela),但我不知道在我的情况下如何违反它。非常感谢您的任何帮助!

你的最后一层使用的激活函数是什么? - Vikash Singh
最后一层是一个带有线性(无)激活函数(https://www.tensorflow.org/api_docs/python/tf/layers/conv3d)和核大小为1的3D卷积层。我将结果张量展平,前半部分成为我的均值,后半部分成为log_sigma。 - Prook
那么最后一层的输出可以大于1吗? - Vikash Singh
1
嗯,我知道。从代码片段和我的帖子中可以很清楚地看出,我正在使用对角协方差矩阵初始化两个多元正态分布。其中一个分布的均值和方差由我的网络输出确定。唯一的限制是sigma是正数。这通过取exp(log-sigma)来确保。那你具体在问些什么呢? - Prook
4
这与输入无关,但你显然没有仔细阅读我所写的内容。上面的函数应该对任何输入返回正的KL散度。我怀疑这是一个数值问题,可能与Tensorflow中实现该函数的方式有关。我只是想听听其他人对这个问题的看法,特别是那些遇到过同样问题的人。 - Prook
显示剩余3条评论
1个回答

1
面临相同的问题。 这是由于使用了浮点精度所导致的。 如果您注意到负值接近0并且被限制在小的负值范围内。向损失添加一个小的正值是一种解决方法。

2
嗨!请在您的回答中更详细地说明并提供解决方案,或者您可以将此答案移至评论部分。 - Aniket Tiratkar

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接