Tensorflow 目标检测:继续训练

3
假设我训练了一个像ResNet这样的预训练网络,并在 pipeline.config 文件中将其设置为检测模式下的 fine_tune_checkpoint_type 属性。据我所知,这意味着我们采用模型的预训练权重,但分类和框预测头除外。此外,这意味着我们可以创建自己类型的标签,这些标签将成为我们要创建/训练的模型的分类和框预测头。
现在,假设我训练了25000个步骤,想稍后继续训练,而不让模型忘记任何东西。我是否应该在 pipeline.config 中将 fine_tune_checkpoint_type 更改为 full 以便继续训练(当然还需要加载正确的检查点文件),还是应该仍将其设置为 detection
编辑:
这基于此处找到的信息:https://github.com/tensorflow/models/blob/master/research/object_detection/protos/train.proto
  //   1. "classification": Restores only the classification backbone part of
  //        the feature extractor. This option is typically used when you want
  //        to train a detection model starting from a pre-trained image
  //        classification model, e.g. a ResNet model pre-trained on ImageNet.
  //   2. "detection": Restores the entire feature extractor. The only parts
  //        of the full detection model that are not restored are the box and
  //        class prediction heads. This option is typically used when you want
  //        to use a pre-trained detection model and train on a new dataset or
  //        task which requires different box and class prediction heads.
  //   3. "full": Restores the entire detection model, including the
  //        feature extractor, its classification backbone, and the prediction
  //        heads. This option should only be used when the pre-training and
  //        fine-tuning tasks are the same. Otherwise, the model's parameters
  //        may have incompatible shapes, which will cause errors when
  //        attempting to restore the checkpoint.

因此,分类仅提供特征提取器的分类骨干部分。这意味着模型将从网络的许多部分开始重新学习。 检测恢复整个特征提取器,但“最终结果”将被遗忘,这意味着我们可以添加自己的类别并从头开始学习这些分类。 完全恢复所有内容,甚至包括类别和框预测权重。然而,只要我们不添加或删除任何类别/标签,这就没问题了。
这正确吗?
2个回答

3

没错,你理解得很正确。
piepline.config 中设置 fine_tune_checkpoint_type: full ,以保留模型在上一个检查点学到的所有内容。


这个答案不应该被选为正确答案,因为它是错误的。从2020年开始,TF2模型会自动从之前训练的检查点中重新加载自定义数据集。除了重新启动训练过程,你不需要做任何事情,训练器会自动检测并加载你自己的检查点。只要你有一个现有的检查点,它就会起作用。选择fine_tune_checkpoint_type为full将阻止你继续训练。 - undefined

3
是的,你可以通过设置fine_tune_checkpoint_type这个变量来配置需要恢复的变量。可选项是 detection 和 classification。将其设置为 detection,基本上可以从检查点中恢复所有变量;而将其设置为 classification,则只能从 feature_extractor 范围内恢复变量(即所有后端网络中的层,如 VGG、Resnet、MobileNet 等,它们被称为特征提取器)。
更多信息请点击此处

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接