利用detection2中已经训练配置好的模型进行目标检测和分割

  在detection2中以及配置且训练好的模型存储在detection2/model_zoo中,其中./config中就存储了相应模型的数据以及配置文件。

一、导入相应的库

 1 import cv2
 2 import detectron2
 3 from detectron2.utils.logger import setup_logger
 4 # import some common libraries
 5 import numpy as np
 6 import os, json, cv2, random
 7 # import some common detectron2 utilities
 8 from detectron2 import model_zoo
 9 from detectron2.engine import DefaultPredictor
10 from detectron2.config import get_cfg
11 from detectron2.utils.visualizer import Visualizer
12 from detectron2.data import MetadataCatalog, DatasetCatalog

 

  所使用的几个包就是cv2、detection2中若干个函数

二、导入图片以及训练好的模型

1 setup_logger()
2 im = cv2.imread("/remote-home/hthuang/detectron2/demo/input.jpeg")
3 cfg = get_cfg()#获取默认的配置参数
4 # add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
5 cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
6 cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
7 # Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
8 cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")

  其中setup_logger()是设置日志,接着通过opencv中的imread函数对图片进行读取。

  而在detection2中对模型的参数进行了一系列的默认设置,默认配置参数在detection2/config/defaults.py文件中,可以通过cfg = get_cfg()来获取默认的配置参数

  接着在这个配置cfg中添加模型的结构配置文件参数,配置函数为cfg.merge_from_file(model_zoo.get_config_file,其中COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml文件的名字也代表了这个模型的结构是怎么样的

  cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5,设置的是IOU阈值为0.5

  接着通过cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url来向模型中添加训练好的参数

三、模型预测

1 predictor = DefaultPredictor(cfg)
2 outputs = predictor(im)
3 print(outputs["instances"].pred_classes)
4 print(outputs["instances"].pred_boxes)

  通过DefaultPredictor(cfg)这个函数生成默认的预测器,之后再通过这个预测器对图片进行目标检测和分割

四、可视化

 

1 v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
2 out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
3 cv2.imwrite("b.jpg", out.get_image()[:, :, ::-1])

 

  通过detection2中自带的可视化函数Visualizer进行可视化,将图片和模型预测的目标anchor以及分类结果结合并存储下来。由于我的结果是在实验室的服务器上跑的,没有屏幕,当使用cv2.imshow()函数时将会报错

 

 

   这个时候只要将cv2.imshow函数删去,同时使用cv2.imwrite()函数保存下来再打开即可。

  最后的结果为:

 

 

posted @ 2022-02-23 16:03  maple_hx  阅读(1168)  评论(0)    收藏  举报