如何使用Tensorflow Object Detection API继续训练物体检测模型?

7
我正在使用Tensorflow Object Detection API进行迁移学习,训练一个目标检测模型。具体来说,我使用model zoo中的ssd_mobilenet_v1_fpn_coco,并使用提供的示例pipeline,当然,我已经用实际链接替换了占位符,包括我的训练和评估tfrecords和标签。
使用上述pipeline,我能够成功地在约5000张图像(以及相应的边界框)上训练模型(如果相关,则主要使用Google的TPU ML引擎)。
现在,我准备了额外的约2000张图像,并希望在不从头开始重新训练的情况下,使用这些新图像继续训练我的模型(训练初始模型花费了大约6小时的TPU时间)。我该怎么做?
2个回答

7

您有两个选项,在这两个选项中,都需要更改新数据集的 train_input_reader input_path

  1. 在训练配置中指定要微调的检查点时,请指定您经过训练的模型的检查点
train_config{
    fine_tune_checkpoint: <path_to_your_checkpoint>
    fine_tune_checkpoint_type: "detection"
    load_all_detection_checkpoint_vars: true
}
  1. 只需保持相同的配置(除了train_input_reader),并使用之前模型相同的model_dir。这样,API将创建一个图形,并检查model_dir中是否已存在符合该图形的检查点。如果是这样-它将恢复并继续训练。

编辑:由于错误,fine_tune_checkpoint_type先前设置为true,而通常应该是“detection”或“classification”,在这种特定情况下应该是“detection”。感谢Krish的指出。


你是简单地添加了更多已经存在的类别示例,还是引入了新的类别?如果是新的类别,你在管道中使用了什么值作为num_classes的值?此外,你是创建了一个新的classes.pbtxt文件还是只是追加了新的类别? - Adrian Hood Sr

1
我还没有在新数据集上重新训练过目标检测模型,但是看起来在配置文件中增加训练步骤的数量train_config.num_steps并在tfrecord文件中添加图像应该就足够了。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接