## Runs entirely from locally stored weights; a detection is drawn only if its confidence score is high enough.
import warnings

import cv2
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as T
from torchvision.io.image import read_image
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights

warnings.filterwarnings("ignore",category=ResourceWarning)
warnings.filterwarnings("ignore",category=DeprecationWarning)

def select_detections(prediction, categories, min_score=0.9):
    """Return the detections whose confidence is at least ``min_score``.

    Args:
        prediction: Mapping with parallel ``"boxes"``, ``"labels"`` and
            ``"scores"`` sequences, i.e. one element of the list returned by
            the detection model. Each box is four numbers
            ``(x1, y1, x2, y2)``.
        categories: Sequence mapping a label index to its category name.
        min_score: Detections scoring below this value are discarded.

    Returns:
        A list of ``(category_name, score, (x1, y1, x2, y2))`` tuples in the
        original order. Coordinates are truncated to ``int`` because the
        OpenCV drawing calls below only accept integer pixel positions.
    """
    kept = []
    for box, label, score in zip(
        prediction["boxes"], prediction["labels"], prediction["scores"]
    ):
        score = float(score)
        if score < min_score:
            continue
        x1, y1, x2, y2 = (int(v) for v in box)
        kept.append((categories[int(label)], score, (x1, y1, x2, y2)))
    return kept


if __name__ == "__main__":
    # Any test image will do (the original sample shows a woman riding a bicycle).
    img_path = "./jupyterlab/doc/ccc.jpg"
    # Checkpoint downloaded beforehand from the PyTorch website.
    weights_path = 'E:/study_2022/working_python/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth'

    img = read_image(img_path)  # decode with torchvision's own I/O helper

    # Take categories and preprocessing from the weights enum that matches the
    # Mask R-CNN model built below (the script previously used the Faster R-CNN enum).
    weights_info = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT

    # Build the network without downloading anything, then load the local checkpoint.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(
        weights=None, progress=False, weights_backbone=None
    )
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    # map_location="cpu" keeps loading independent of the device the file was saved from;
    # the model itself is never moved off the CPU in this script.
    myweights = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(myweights)
    model.eval()  # switch to inference behaviour

    preprocess = weights_info.transforms()
    batch = [preprocess(img)]
    with torch.no_grad():  # no gradients are needed for pure inference
        prediction = model(batch)[0]

    # Discard detections with a confidence below 90%.
    detections = select_detections(
        prediction, weights_info.meta["categories"], min_score=0.9
    )

    myimg = cv2.imread(img_path)
    if myimg is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"OpenCV could not read image: {img_path}")
    myimg = cv2.cvtColor(myimg, cv2.COLOR_BGR2RGB)  # matplotlib expects RGB order

    # The former per-detection cv2.addWeighted(myimg, 0.5, myimg, 0.5, 1) call blended
    # the image with itself (only adding 1 to every pixel each time), so it was dropped.
    for label, _score, (x1, y1, x2, y2) in detections:
        cv2.rectangle(myimg, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=3)
        cv2.putText(
            myimg, label, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 2,
            color=(255, 0, 0), thickness=3,
        )

    plt.figure(figsize=(7, 5))
    plt.imshow(myimg)
    plt.xticks([])
    plt.yticks([])
    plt.show()