Tensorflow,Keras:如何在具有停止梯度的Keras层中设置add_loss?

4

问题 1

我们知道可以使用tf.stop_gradient(B)来防止变量B在反向传播中被训练。但我不知道如何在特定损失函数中停止B

简单地说,假设我们的损失函数为:

loss = categorical_crossentropy + my_loss
B = tf.stop_gradient(B)

在这里,categorical_crossentropymy_loss 都依赖于 B。如果我们对 B 设置停止梯度,它们都会将 B 视为常数。

但是,我如何只针对 my_lossB 设置停止梯度,同时保持 categorical_crossentropy 不变呢?就像这样:B = tf.stop_gradient(B, myloss)

我实现的代码如下:

my_loss = ...
B = tf.stop_gradient(B)
categorical_crossentropy = ...
loss = categorical_crossentropy + my_loss

这能行吗?或者说,如何使其工作?


问题2

好的,如果Q1可以解决,我的最终问题是如何在自定义层中实现它?

具体来说,假设我们有一个自定义层,其中包括可训练权重AB以及仅适用于该层的自身损失my_loss

class My_Layer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(My_Layer, self).__init__(**kwargs)
    def build(self, input_shape):
        self.w = self.add_weight(name='w', trainable=True)
        self.B = self.add_weight(name='B', trainable=True)
        my_loss = w * B
        # tf.stop_gradient(w)
        self.add_loss(my_loss)

如何使 w 只对模型损失(MSE,交叉熵等)进行训练,而使 B 仅对 my_loss 进行训练?

如果我添加 tf.stop_gradient(w),那么这会停止 w 对于 my_loss 吗?还是会停止模型的最终损失?

1个回答

3

问题 1

当你运行y = tf.stop_gradient(x)时,你创建了一个名为StopGradient的操作,其输入为x,输出为y。此操作的行为类似于恒等式,即x的值与y的值相同,除了梯度不从y流向x

如果你想让梯度只从一些损失函数流向B,你只需这样做:

B_no_grad = tf.stop_gradient(B)
loss1 = get_loss(B)  # B will be updated because of loss1
loss2 = get_loss(B_no_grad)   # B will not be updated because of loss2 

当您考虑正在构建的计算图时,事情应该变得清晰。 stop_gradient 允许您为任何张量(不仅仅是变量)创建一个“身份”节点,该节点不允许梯度流经它。

问题2

如果您使用指定字符串的模型损失(例如model.compile(loss='categorical_crossentropy', ...)),因为您无法控制其构造方式,所以我不知道如何做到这一点。但是,您可以通过使用 add_loss 添加损失或使用模型输出自己构建模型级别的损失来实现。对于前者,只需创建一些使用普通变量的损失和一些使用 *_no_grad 版本的损失,将它们全部添加到 add_loss() 中,并使用 loss=None 编译您的模型。


这里只是随意猜测,但是 https://keras.io/losses/ 上说你可以将一个函数传递给 compile 而不仅仅是一个字符串,例如 model.compile(loss=losses.mean_squared_error, optimizer='sgd') 而不是 model.compile(loss='mean_squared_error', optimizer='sgd') - HeyWatchThis
2
如果我想让从loss1到层1-2的梯度仅影响权重(而不是层3-4),并且从loss2到层3-4的梯度仅影响层3-4(而不是层1-2),该怎么办?例如,当您尝试为变分自编码器添加两个不同的自定义正则化项时,一个仅适用于编码器,另一个仅适用于解码器,但它们必须共同训练时,这正是您想要的。 - Kristof
@Kristof 你解决过这个问题吗?如何让 loss1 只影响第一层和第二层,而 loss2 只影响第三层和第四层? - Mike Martin

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接