一个多输出深度学习模型是如何训练的?

8
我觉得我不理解多输出网络。
虽然我了解其实现方式,并且成功地训练了一个这样的模型,但我不理解多输出深度学习网络是如何训练的。我的意思是,在训练期间,网络内部会发生什么?
keras functional api guide中的这个网络为例:

enter image description here

您可以看到两个输出(aux_output和main_output)。反向传播是如何工作的?
我的直觉是模型会进行两次反向传播,一次针对每个输出。然后,每个反向传播都会更新前面层的权重。 但事实并非如此:从here(SO)得到的信息表明,尽管有多个输出,但只有一个反向传播;使用的损失根据输出进行加权。
但仍然不清楚网络及其辅助分支如何训练;辅助分支的权重是如何更新的,因为它没有直接连接到主输出?在辅助分支的根部和主输出之间的网络部分是否受到损失加权的影响?还是加权只影响与辅助输出相连的网络部分?

另外,我正在寻找有关此主题的好文章。我已经阅读了有关GoogLeNet / Inception的文章(v1, v2-v3),因为这个网络使用辅助分支。


这篇文章可能会对你有所帮助。它展示了反向传播背后的数学原理,让你明白它其实很简单,只是被人们复杂化了。 - omoshiroiii
1
感谢您的提问,但它更关注机器学习的理论方面,因此不适合在SO上发布。您可以在CrossValidated上获得更好的答案机会。此外,在SO上请求文章、教程或站外资源是不允许的。 - today
我理解您的观点。这个问题与理解Keras函数API的工作方式密切相关,因为这种架构是其主要示例。所以我认为在这里发布这个问题是相关的。如果我在赏金结束后没有找到答案,我会将问题移动到CrossValidated! - Baptiste Pouthier
1个回答

6

Keras的计算基于图形,并且只使用一个优化器

优化器也是图形的一部分,在其计算中,它会获取整个权重组的梯度。(不是针对每个输出的两个梯度组,而是针对整个模型的一个梯度组)。

从数学上讲,这并不太复杂,你有一个由下列内容组成的最终损失函数:

loss = (main_weight * main_loss) + (aux_weight * aux_loss) #you choose the weights in model.compile

你定义了所有的权重,以及一系列其他可能的权重(样本权重、类别权重、正则化项等)。
其中:
- `main_loss` 是一个由 `main_true_output_data` 和 `main_model_output` 构成的函数。 - `aux_loss` 是一个由 `aux_true_output_data` 和 `aux_model_output` 构成的函数。
梯度仅为所有权重的 `loss` 对权重 `weight_i` 的偏导数。
优化器获取梯度后,只执行一次优化步骤。 问题:

如何更新辅助分支的权重,因为它没有直接连接到主输出?

你有两个输出数据集:一个用于 `main_output`,另一个用于 `aux_output`。你必须将它们传递给 `model.fit(inputs, [main_y, aux_y], ...)` 中的 `fit`。
你还有两个损失函数,一种用于每个输出,其中 `main_loss` 采用 `main_y` 和 `main_out`,而 `aux_loss` 采用 `aux_y` 和 `aux_out`。
这两个损失相加,得到 `loss = (main_weight * main_loss) + (aux_weight * aux_loss)`。
梯度计算是针对函数 `loss` 进行的,而此函数连接到整个模型。
- `aux` 项将影响反向传播中的 `lstm_1` 和 `embedding_1`。 - 因此,在下一次前向传递(权重更新后)中,它将影响主分支。 (它是否更好还是更差只取决于辅助输出是否有用)

对于根据辅助分支和主要输出连接的网络部分,损失的加权考虑了吗?或者说加权只影响与辅助输出相连的网络部分?

权重是纯数学计算。你会在 `compile` 中定义它们:
model.compile(optimizer=one_optimizer, 

              #you choose each loss   
              loss={'main_output':main_loss, 'aux_output':aux_loss},

              #you choose each weight
              loss_weights={'main_output': main_weight, 'aux_output': aux_weight}, 

              metrics = ...)

并且损失函数将在 loss = (weight1 * loss1) + (weight2 * loss2) 中使用它们。
其余部分是每个权重的数学计算 ∂(loss)/∂(weight_i)


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接