利用detection2中已经训练配置好的模型进行目标检测和分割
在detection2中以及配置且训练好的模型存储在detection2/model_zoo中,其中./config中就存储了相应模型的数据以及配置文件。
一、导入相应的库
1 import cv2 2 import detectron2 3 from detectron2.utils.logger import setup_logger 4 # import some common libraries 5 import numpy as np 6 import os, json, cv2, random 7 # import some common detectron2 utilities 8 from detectron2 import model_zoo 9 from detectron2.engine import DefaultPredictor 10 from detectron2.config import get_cfg 11 from detectron2.utils.visualizer import Visualizer 12 from detectron2.data import MetadataCatalog, DatasetCatalog
所使用的几个包就是cv2、detection2中若干个函数
二、导入图片以及训练好的模型
1 setup_logger() 2 im = cv2.imread("/remote-home/hthuang/detectron2/demo/input.jpeg") 3 cfg = get_cfg()#获取默认的配置参数 4 # add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library 5 cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")) 6 cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model 7 # Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well 8 cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
其中setup_logger()是设置日志,接着通过opencv中的imread函数对图片进行读取。
而在detection2中对模型的参数进行了一系列的默认设置,默认配置参数在detection2/config/defaults.py文件中,可以通过cfg = get_cfg()来获取默认的配置参数
接着在这个配置cfg中添加模型的结构配置文件参数,配置函数为cfg.merge_from_file(model_zoo.get_config_file,其中COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml文件的名字也代表了这个模型的结构是怎么样的
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5,设置的是IOU阈值为0.5
接着通过cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url来向模型中添加训练好的参数
三、模型预测
1 predictor = DefaultPredictor(cfg) 2 outputs = predictor(im) 3 print(outputs["instances"].pred_classes) 4 print(outputs["instances"].pred_boxes)
通过DefaultPredictor(cfg)这个函数生成默认的预测器,之后再通过这个预测器对图片进行目标检测和分割
四、可视化
1 v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2) 2 out = v.draw_instance_predictions(outputs["instances"].to("cpu")) 3 cv2.imwrite("b.jpg", out.get_image()[:, :, ::-1])
通过detection2中自带的可视化函数Visualizer进行可视化,将图片和模型预测的目标anchor以及分类结果结合并存储下来。由于我的结果是在实验室的服务器上跑的,没有屏幕,当使用cv2.imshow()函数时将会报错

这个时候只要将cv2.imshow函数删去,同时使用cv2.imwrite()函数保存下来再打开即可。
最后的结果为:



浙公网安备 33010602011771号