当存在多个输出时,如何只训练网络中的一个输出?

29

我正在Keras中使用多输出模型。

model1 = Model(input=x, output=[y2, y3])

model1.compile((optimizer='sgd', loss=cutom_loss_function)

我的custom_loss函数是

def custom_loss(y_true, y_pred):
   y2_pred = y_pred[0]
   y2_true = y_true[0]

   loss = K.mean(K.square(y2_true - y2_pred), axis=-1)
   return loss

我只想在输出 y2 上训练网络。

当使用多个输出时,损失函数中的 y_predy_true 参数的形状/结构是什么?我能像上面那样访问它们吗?是 y_pred[0] 还是 y_pred[:,0]

3个回答

23

我只想在输出 y2 上训练网络。

根据 Keras函数API指南,您可以使用以下代码实现:

model1 = Model(input=x, output=[y2,y3])   
model1.compile(optimizer='sgd', loss=custom_loss_function,
                  loss_weights=[1., 0.0])

当使用多个输出时,损失函数中的y_pred和y_true参数的形状/结构是什么?我可以像上面那样访问它们吗?是y_pred[0]还是y_pred[:,0]?

在Keras的多输出模型中,损失函数分别应用于每个输出。伪代码如下:

loss = sum( [ loss_function( output_true, output_pred ) for ( output_true, output_pred ) in zip( outputs_data, outputs_model ) ] )

我觉得在多个输出上执行损失函数的功能不可用。可能可以通过将损失函数作为网络的一层来实现。


1
在Keras的多输出模型中,损失函数会分别应用于每个输出。我有一个类似的问题,我需要分别获取两个独立输出的y_true和y_pred值。我该如何解决这个问题? - Eka
6
除非最近框架已更改,否则最简单的解决方案是将输出连接成一个单一的损失函数,然后在那里处理它们。 - Sharapolas
@Sharapolas,你有这个语句“最简单的解决方案是将输出连接成一个单一的损失函数,然后在那里处理它们”的实际示例吗? - ihavenoidea

5

如果自定义损失函数无法应用于您要忽略的输出,例如它们具有错误的形状,那么通常接受的答案将不起作用。在这种情况下,您可以为这些输出分配一个虚拟损失函数:

labels = [labels_for_relevant_output, dummy_labels_for_ignored_output]

def dummy_loss(y_true, y_pred):
    return 0.0

model.compile(loss = [custom_loss_function, dummy_loss])
model.fit(x, labels)

请注意,有时还需要更改度量指标,以便它们指定属于哪个输出。这可以通过传递一个度量字典来完成,其中键是要映射到的层/输出名称。 - Jon Nordby

2
Sharapolas的回答是正确的。然而,有一种比使用层更好的方法来构建具有多个输出的模型的复杂相互依赖的自定义损失函数。
我知道的方法是从不调用model.compile,而只调用model._make_predict_function()。从那里开始,您可以继续通过在其中调用model.output来构建自定义优化器方法。这将为您提供所有输出,在您的情况下为[y2,y3]。在进行操作时,获取一个keras.optimizer并使用它的get_update方法,使用您的model.trainable_weights和loss。最后,返回一个带有所需输入列表(在您的情况下仅为model.input)和刚刚从optimizer.get_update调用中获得的更新的keras.function。现在,此函数替换了model.fit。
上述方法通常用于PolicyGradient算法,如A3C或PPO。以下是我试图解释的示例:https://github.com/Hyeokreal/Actor-Critic-Continuous-Keras/blob/master/a2c_continuous.py。查看build_model和critic_optimizer方法,并阅读kreas.backend.function文档以了解发生的情况。
我发现这种方法经常出现会话管理问题,并且目前似乎根本无法在tf-2.0 keras中工作。因此,如果有人知道一种方法,请告诉我。我来这里寻找一个 :)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接