无法从DenseVariational获得合理的结果

3
我正在尝试使用以下数据集(正弦曲线)进行回归问题,数据集大小为500

数据集

首先,我尝试了两个具有10个单元的密集层。
model = tf.keras.Sequential([
        tf.keras.layers.Dense(10, activation='tanh'),
        tf.keras.layers.Dense(10, activation='tanh'),
        tf.keras.layers.Dense(1),
        tfp.layers.DistributionLambda(lambda t: tfd.Normal(loc=t, scale=1.))
    ])

训练时使用如下的负对数似然损失函数:
model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.01), loss=neg_log_likelihood)
model.fit(x, y, epochs=50)

结果绘图 无不确定性

接下来,我尝试了使用 DenseVariational 的类似环境。

model = tf.keras.Sequential([
        tfp.layers.DenseVariational(
            10, activation='tanh', make_posterior_fn=posterior,
            make_prior_fn=prior, kl_weight=1/N, kl_use_exact=True),
        tfp.layers.DenseVariational(
            10, activation='tanh', make_posterior_fn=posterior,
            make_prior_fn=prior, kl_weight=1/N, kl_use_exact=True),
        tfp.layers.DenseVariational(
            1, activation='tanh', make_posterior_fn=posterior,
            make_prior_fn=prior, kl_weight=1/N, kl_use_exact=True),
        tfp.layers.DistributionLambda(lambda t: tfd.Normal(loc=t, scale=1.))
    ])

由于参数数量翻倍,我已经尝试增加数据集大小和/或epoch大小,多达100倍都没有成功。通常的结果如下所示。

With uncertainty

我的问题是,如何获得与 Dense 层相当的 DenseVariational 结果?我也听说它对初始值很敏感。在这里是完整代码的链接。欢迎提出任何建议。

3个回答

3
你需要定义一个不同的代理后验分布。在Tensorflow贝叶斯线性回归示例中,https://colab.research.google.com/github/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Probabilistic_Layers_Regression.ipynb#scrollTo=VwzbWw3_CQ2z,你可以将后验均值场定义如下:
# Specify the surrogate posterior over `keras.layers.Dense` `kernel` and `bias`.
def posterior_mean_field(kernel_size, bias_size=0, dtype=None):
  n = kernel_size + bias_size
  c = np.log(np.expm1(1.))
  return tf.keras.Sequential([
      tfp.layers.VariableLayer(2 * n, dtype=dtype),
      tfp.layers.DistributionLambda(lambda t: tfd.Independent(
          tfd.Normal(loc=t[..., :n],
                     scale=1e-5 + 0.01*tf.nn.softplus(c + t[..., n:])),
          reinterpreted_batch_ndims=1)),
  ])

但请注意,我在Softplus前面加了0.01,降低了标准差的大小。试试这个。

比这更好的是使用采样初始化,就像DenseFlipout中默认使用的那样。https://www.tensorflow.org/probability/api_docs/python/tfp/layers/DenseFlipout?version=nightly

这里是与DenseVariational配合使用的同样的初始化器:

def random_gaussian_initializer(shape, dtype):
    n = int(shape / 2)
    loc_norm = tf.random_normal_initializer(mean=0., stddev=0.1)
    loc = tf.Variable(
        initial_value=loc_norm(shape=(n,), dtype=dtype)
    )
    scale_norm = tf.random_normal_initializer(mean=-3., stddev=0.1)
    scale = tf.Variable(
        initial_value=scale_norm(shape=(n,), dtype=dtype)
    )
    return tf.concat([loc, scale], 0)

现在,您只需更改后验平均场中的VariableLayer即可。
tfp.layers.VariableLayer(2 * n, dtype=dtype, initializer=lambda shape, dtype: random_gaussian_initializer(shape, dtype), trainable=True)

您现在从均值为-3,标准差为0.1的正态分布中进行取样,以输入到您的软加卷积神经网络中。使用后验平均场的均值计算,我们得到规模为Softplus(-3)=0.048587352,因此它非常小。通过取样,我们将初始化所有规模不同,但大致在该均值附近。


感谢您的建议@Perd。我无法使用tanh激活使其工作。但是,使用relu和进一步的微调,我能够获得良好的结果。 - Vijay Giri
你可能还想尝试一种非均场方法,使用完整协方差。这将需要更长的收敛时间,但应该更加灵活,因为它包括分布之间的相关性。因此,由于协方差矩阵是对称的,tfp.distributions.MultivariateNormalTriL 应该是正确的选择。 - Perd

0

我曾经也遇到过同样的问题,花了一段时间才意识到原因。

你在 Dense-NN 中的最后一层没有激活函数(tf.keras.layers.Dense(1)),而在 Variational-NN 中的最后一层使用了 tanh 作为激活函数(tfp.layers.DenseVariational( 1, activation='tanh' ...))。删除这个应该可以解决问题。 我还观察到,在这种情况下,relu 和尤其是 leaky-relu 比 tanh 更优秀。


谢谢@bayes2021。发现得真好。已经尝试过了。不过,用tanh仍然无法获得良好的结果。 - Vijay Giri
@VijayGiri,也许snake激活函数会很有趣。链接。我还遇到了DenseVariational层与周期函数以及几乎所有激活函数结合使用的问题。 - Fermat

0

根据@Perd的答案,我在后验中尝试了更低的标准偏差。

对于这个数据和NN架构,在使用tanh激活函数时,我无法获得更好的结果。但是,当使用relu激活函数和scale=1e-5 + 0.001 * tf.nn.softplus(c + t[..., n:])时,我能够得到最佳结果。

该模型似乎非常敏感于超参数。以下是不同后验scale值的结果

对于scale = 1e-5 + 0.01 * tf.nn.softplus(c + t[..., n:])的情况 0.01

对于scale = 1e-5 + 0.005 * tf.nn.softplus(c + t[..., n:])的情况 0.005

对于 scale=1e-5 + 0.002 * tf.nn.softplus(c + t[..., n:])) 0.002

对于 scale=1e-5 + 0.0015 * tf.nn.softplus(c + t[..., n:])) 0.0015

对于 scale=1e-5 + 0.001 * tf.nn.softplus(c + t[..., n:])) 0.001

对于 tanh 激活函数,仍然无法得到良好的结果 tanh

代码链接


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接