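I want to use only one class, "person" (along with BG, i.e. background), for Mask RCNN object detection. I am running Mask RCNN from this link: https://github.com/matterport/Mask_RCNN. Is there a specific way to do this (editing particular files, creating an additional Python file, or just filtering the selection in the class_names array)? Any direction or solution would be much appreciated. Thanks.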
I have trained this same codebase for sheep. You need to do two things:
First, change the number of classes for both training and inference to 1 + 1 (BG and person):
```python
from mrcnn.config import Config

class SheepsConfig(Config):
    NAME = "sheeps"
    NUM_CLASSES = 1 + 1  # background + sheep (background + person in your case)

config = SheepsConfig()  # Don't forget to use this config while creating your model
config.display()
```
Second, you need to create a dataset to train on. You can use COCO as follows:
```python
import coco  # coco.py from the Mask_RCNN repo; it must be on your Python path
from pycocotools.coco import COCO

ct = COCO("/YourPathToCocoDataset/annotations/instances_train2014.json")
ct.getCatIds(['sheep'])
# The sheep category id is 20. Run this for 'person' instead and use that id (it is 1).

COCO_DIR = "/YourPathToCocoDataset/"
# This path contains the train2014, val2014 and annotations folders

# Training dataset
dataset_train = coco.CocoDataset()
dataset_train.load_coco(COCO_DIR, "train", class_ids=[20])
dataset_train.prepare()

# Validation dataset
dataset_val = coco.CocoDataset()
dataset_val.load_coco(COCO_DIR, "val", class_ids=[20])
dataset_val.prepare()
```
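Why does `NUM_CLASSES = 1 + 1` fit `class_ids=[20]` even though the COCO id is 20? As far as I understand the repo's `Dataset` class, `prepare()` renumbers the selected source classes into consecutive internal ids, with id 0 reserved for `BG`. The sketch below is a simplified, hypothetical re-implementation of that remapping (`build_class_map` is my own name, not an mrcnn function), meant only to illustrate the idea.

```python
# Hypothetical, simplified illustration of the class-id remapping; not mrcnn's code.
def build_class_map(source, source_class_ids, names):
    """Return (internal_ids, class_names, source->internal map).

    Internal id 0 is reserved for background ("BG"); the selected source
    classes are numbered 1, 2, ... in the order given.
    """
    class_names = ["BG"] + list(names)
    internal_ids = list(range(len(class_names)))
    source_map = {
        f"{source}.{src_id}": internal_id
        for internal_id, src_id in enumerate(source_class_ids, start=1)
    }
    return internal_ids, class_names, source_map


ids, names, mapping = build_class_map("coco", [20], ["sheep"])
print(ids)      # [0, 1]
print(names)    # ['BG', 'sheep']
print(mapping)  # {'coco.20': 1}
```

With `class_ids=[1]` (person) you get the same two-entry layout, which is what `NUM_CLASSES = 1 + 1` expects.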
Then just create your model as follows:
```python
import mrcnn.model as modellib

# Create model in training mode
# MODEL_DIR is the directory where logs and checkpoints are written
model = modellib.MaskRCNN(mode="training", config=config, model_dir=MODEL_DIR)
# COCO_MODEL_PATH is the path to the mask_rcnn_coco.h5 file from this repo
model.load_weights(COCO_MODEL_PATH, by_name=True,
                   exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",
                            "mrcnn_bbox", "mrcnn_mask"])
model.train(dataset_train, dataset_val,
            learning_rate=config.LEARNING_RATE,
            epochs=100,
            layers='heads')  # You can also use 'all' to train the whole network
```
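Why the `exclude` list in `load_weights`? The pretrained mask_rcnn_coco.h5 heads were built for 81 classes (80 COCO classes + BG), so their weight shapes no longer match once `NUM_CLASSES` is 2. The sketch below just computes those kernel shapes by hand, assuming the repo's default 1024-unit classifier FC layer and 256-channel mask head; `head_kernel_shapes` and `mismatched_layers` are hypothetical helpers, not part of mrcnn. As far as I can tell, `mrcnn_bbox` in the list is a Reshape layer without weights, so excluding it is harmless.

```python
# Illustrative sketch: kernel shapes of the class-dependent head layers,
# assuming FPN_CLASSIF_FC_LAYERS_SIZE = 1024 and a 256-channel mask head.
# Shapes are derived by hand, not read from the .h5 file.
def head_kernel_shapes(num_classes, fc_size=1024, mask_channels=256):
    return {
        "mrcnn_class_logits": (fc_size, num_classes),       # one logit per class
        "mrcnn_bbox_fc": (fc_size, 4 * num_classes),        # 4 box deltas per class
        "mrcnn_mask": (1, 1, mask_channels, num_classes),   # 1x1 conv, one mask per class
    }


def mismatched_layers(pretrained_classes, new_classes):
    """Names of head layers whose pretrained weights cannot be copied over."""
    old = head_kernel_shapes(pretrained_classes)
    new = head_kernel_shapes(new_classes)
    return sorted(name for name in old if old[name] != new[name])


print(head_kernel_shapes(81)["mrcnn_bbox_fc"])  # (1024, 324)
print(head_kernel_shapes(2)["mrcnn_bbox_fc"])   # (1024, 8)
print(mismatched_layers(81, 2))
# ['mrcnn_bbox_fc', 'mrcnn_class_logits', 'mrcnn_mask']
```

Every class-dependent head changes shape, so those layers start from fresh weights while the backbone and RPN are reused, which is also why `layers='heads'` is a sensible first training stage.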
Don't forget to use TensorFlow 1.x and Keras 2.1.0 :) I was able to train with these versions.
Use the following code in the run_on_opencv_image() function in the predictor.py file:

```python
import numpy as np  # add at the top of predictor.py if it is not already imported

predictions = self.coco_demo.compute_prediction(image)
top_predictions = self.coco_demo.select_top_predictions(predictions)

masks = top_predictions.get_field("mask")
boxes = top_predictions.bbox
label_indexs = top_predictions.get_field("labels").numpy()

x = np.where(label_indexs != 1)[0]  # indexes of labels which are not person
# remove items which are not of the person class
masks = np.delete(masks, x, axis=0)
boxes = np.delete(boxes, x, axis=0)
label_indexs = np.delete(label_indexs, x)
labels = self.convert_label_index_to_string(label_indexs)
```
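If you keep the pretrained 81-class COCO model and only want person detections at inference time (the "filtering" option from the question), the same idea can be applied to the result dict returned by matterport's `model.detect()`. Note that editing `class_names` alone only changes the displayed names; the detections themselves have to be filtered. A minimal sketch, assuming the documented result layout (`rois` [N, 4], `class_ids` [N], `scores` [N], `masks` [H, W, N]) and that 'person' sits at index 1 right after 'BG' in the demo's `class_names` list; `keep_only_class` and `PERSON_ID` are hypothetical names.

```python
import numpy as np

PERSON_ID = 1  # assumed index of 'person' in the COCO class_names list (after 'BG')


def keep_only_class(result, class_id=PERSON_ID):
    """Return a copy of a detect()-style result containing only one class."""
    keep = result["class_ids"] == class_id  # boolean mask over the N detections
    return {
        "rois": result["rois"][keep],
        "class_ids": result["class_ids"][keep],
        "scores": result["scores"][keep],
        "masks": result["masks"][:, :, keep],  # instances are on the last axis
    }


# Synthetic example: 3 detections (person, other class, person) on a 4x4 image.
r = {
    "rois": np.array([[0, 0, 2, 2], [1, 1, 3, 3], [0, 1, 4, 4]]),
    "class_ids": np.array([1, 17, 1]),
    "scores": np.array([0.9, 0.8, 0.7]),
    "masks": np.zeros((4, 4, 3), dtype=bool),
}
people = keep_only_class(r)
print(people["class_ids"])    # [1 1]
print(people["masks"].shape)  # (4, 4, 2)
```

Boolean indexing keeps all four arrays aligned in one step, and unlike an index list it still works when no detection matches (the arrays simply become empty).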
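You can also look at the balloon example written by the author of the GitHub repo you linked; it is very minimal and contains only one class (balloons). I suggest working through this tutorial: https://engineering.matterport.com/splash-of-color-instance-segmentation-with-mask-r-cnn-and-tensorflow-7c761e238b46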