Tensorflow 2.0多输入自定义损失函数

16

我正在尝试优化使用以下两个损失函数的模型

def loss_1(pred, weights, logits):
    weighted_sparse_ce = kls.SparseCategoricalCrossentropy(from_logits=True)
    policy_loss = weighted_sparse_ce(pred, logits, sample_weight=advantages)

def loss_2(y_pred, y):
    return kls.mean_squared_error(y_pred, y)

然而,由于TensorFlow 2期望损失函数遵循以下形式

def fn(y_pred, y_true):
    ...

我正在使用一种解决方案来处理loss_1,其中我将predweights打包成单个张量,然后在调用model.fit时传递给loss_1,最后在loss_1中再次拆包。这不太优雅也不好用,因为predweights是不同的数据类型,所以这需要每次调用model.fit时进行额外的类型转换、打包和拆包。

此外,我知道在fit中有sample_weight参数,这有点像解决这个问题的方法。如果不是因为我正在使用两个损失函数并且只想将sample_weight应用于其中之一,这可能是可行的解决方法。即使这是一个解决办法,它能推广到其他类型的自定义损失函数吗?


所有这些都说了,我的问题简洁地说就是:

在TensorFlow 2中创建具有任意数量参数的损失函数的最佳方法是什么?

我尝试过的另一件事是传递tf.tuple,但这似乎也违反了TensorFlow对损失函数输入的要求。


使用闭包怎么样?基本上,你可以定义一个标准损失函数,我们称之为 inside_loss,它只在你的 loss_1 中使用 (y_truey_pred)。你可以将权重或对数率等任何参数传递给 loss_1。最后,你的 loss_1 将返回 inside_loss 函数。这很像我们如何自定义 Keras 损失函数。https://github.com/keras-team/keras/issues/2121 - zihaozhihao
1
@zihaozhihao 这是一个有趣的解决方案,但是当使用急切张量或 NumPy 数组作为输入时,它将无法工作。 - Jon Deaton
你是指 loss_1 的参数吗?如果是的话,我相信那应该可以工作。 - zihaozhihao
是的,对于 loss_1 是不行的,因为闭包捕获的数据在创建闭包时不可用。 - Jon Deaton
TF 2.0 要求损失函数的形式为 def fn(y_true, y_pred),即 y_true 是第一个参数。 - toliveira
3个回答

12

使用TF2中的自定义训练可以轻松解决此问题。您只需要在GradientTape上下文中计算两个组件的损失函数,然后使用生成的梯度调用优化器即可。例如,您可以创建一个名为custom_loss的函数,该函数根据每个参数计算两种损失:

def custom_loss(model, loss1_args, loss2_args):
  # model: tf.model.Keras
  # loss1_args: arguments to loss_1, as tuple.
  # loss2_args: arguments to loss_2, as tuple.
  with tf.GradientTape() as tape:
    l1_value = loss_1(*loss1_args)
    l2_value = loss_2(*loss2_args)
    loss_value = [l1_value, l2_value]
  return loss_value, tape.gradient(loss_value, model.trainable_variables)

# In training loop:
loss_values, grads = custom_loss(model, loss1_args, loss2_args)
optimizer.apply_gradients(zip(grads, model.trainable_variables))

通过这种方式,每个损失函数可以接受任意数量的急切张量,无论它们是模型的输入还是输出。每个损失函数的参数集可以不相交,就像这个例子所示。


7

进一步解释Jon的回答。如果你仍然想享受Keras模型的好处,可以扩展模型类并编写自己的自定义train_step:

from tensorflow.python.keras.engine import data_adapter

# custom loss function that takes two outputs of the model
# as input parameters which would otherwise not be possible
def custom_loss(gt, x, y):
    return tf.reduce_mean(x) + tf.reduce_mean(y)

class CustomModel(keras.Model):
    def compile(self, optimizer, my_loss):
        super().compile(optimizer)
        self.my_loss = my_loss

    def train_step(self, data):
        data = data_adapter.expand_1d(data)
        input_data, gt, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

        with tf.GradientTape() as tape:
            y_pred = self(input_data, training=True)
            loss_value = self.my_loss(gt, y_pred[0], y_pred[1])

        grads = tape.gradient(loss_value, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

        return {"loss_value": loss_value}

...

model = CustomModel(inputs=input_tensor0, outputs=[x, y])
model.compile(optimizer=tf.keras.optimizers.Adam(), my_loss=custom_loss)

我刚试了你的代码...但是出现了错误"ValueError: The model cannot be compiled because it has no loss to optimize."。使用Keras 2.3.0和tensorflow 2.2.0。 - zwep
通常这意味着你要么没有传入损失函数,要么传入的损失函数没有梯度可用于优化。例如,如果你的损失函数只返回一个标量。 - Jodo

0
在 tf 1.x 中,我们有 tf.nn.weighted_cross_entropy_with_logits 函数,它允许我们通过为每个类别添加额外的正权重来权衡召回率和精确度。在多标签分类中,它应该是一个 (N,) 张量或 numpy 数组。然而,在 tf 2.0 中,我还没有找到类似的损失函数,所以我用额外的参数 pos_w_arr 编写了自己的损失函数。
from tensorflow.keras.backend import epsilon

def pos_w_loss(pos_w_arr):
    """
    Define positive weighted loss function
    """
    def fn(y_true, y_pred):
        _epsilon = tf.convert_to_tensor(epsilon(), dtype=y_pred.dtype.base_dtype)
        _y_pred = tf.clip_by_value(y_pred, _epsilon, 1. - _epsilon)
        cost = tf.multiply(tf.multiply(y_true, tf.math.log(
            _y_pred)), pos_w_arr)+tf.multiply((1-y_true), tf.math.log(1-_y_pred))
        return -tf.reduce_mean(cost)
    return fn

不确定您的意思是使用急切张量或numpy数组作为输入时无法正常工作。如果我错了,请纠正我。


1
这个应该在TF 1.x中运行,其中“pos_w_arr”不是急切的张量。 在TF 2中,“pos_w_arr”在闭包创建时不可用,因此“pos_w_arr”必须是常量。 我感兴趣的是“pos_w_arr”在批处理间变化的情况。 - Jon Deaton
如果可能的话,您可以将pos_w_arr设置为tf.keras.Input - zihaozhihao
基本上,当你拟合模型时,x=[x_data,pos_w],其中 x_datapos_w 都是 Input - zihaozhihao
在我的情况下,pos_w_arr 需要是模型的 输出。那么 tf.keras.Input 是否合适呢? - Jon Deaton

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接