import os
import mmcv
import time
import torch
import numpy as np
import cv2
import PIL
import PIL.PngImagePlugin  # PngImageFile is referenced via PIL.PngImagePlugin below
from mmcv import Config
from mmcls.apis import inference_model, init_model, show_result_pyplot, train_model, set_random_seed, single_gpu_test
from mmcls.models import build_classifier
from mmcls.datasets import build_dataloader,build_dataset
from mmcv.runner import load_checkpoint
from mmcv.parallel import collate, scatter
from mmcls.datasets.pipelines import Compose
class MMClassification:
def sota(self):
pypath = os.path.abspath(__file__)
father = os.path.dirname(pypath)
models = os.path.join(father, 'models')
sota_model = []
for i in os.listdir(models):
if i[0] != '_':
sota_model.append(i)
return sota_model
def __init__(
self,
backbone='LeNet',
num_classes=-1,
dataset_path='../dataset/cls/hand_gray',
# dataset_type = 'ImageNet'
**kwargs,
):
if len(kwargs) != 0:
info = "Error Code: -501. No such parameter: " + next(iter(kwargs.keys()))
raise Exception(info)
# parent of the caller's current working directory (used for default save/checkpoint paths)
self.cwd = os.path.dirname(os.getcwd())
# absolute path of the directory containing this file
self.file_dirname = os.path.dirname(os.path.abspath(__file__))
self.save_fold = None
if backbone not in self.sota():
if os.path.exists(backbone):  # a config file path was passed in
self.config = backbone
self.cfg = Config.fromfile(self.config)
self.backbone = backbone
else:
info = "Error Code: -302. No such argument: " + backbone + ". Currently available backbones: " + str(self.sota()) + "."
raise Exception(info)
elif backbone in self.sota():
# default to LeNet, then overwrite with whatever config/checkpoint the backbone folder provides
self.config = os.path.join(self.file_dirname, 'models', 'LeNet/LeNet.py')
self.checkpoint = os.path.join(self.file_dirname, 'models', 'LeNet/LeNet.pth')
self.backbone = backbone
backbone_path = os.path.join(self.file_dirname, 'models', self.backbone)
ckpt_cfg_list = list(os.listdir(backbone_path))
for item in ckpt_cfg_list:
if item.endswith('.py') and not item.startswith('_'):
self.config = os.path.join(backbone_path, item)
elif item.endswith('.pth'):
self.checkpoint = os.path.join(backbone_path, item)
else:
# ignore unrecognized files in the backbone folder
pass
self.cfg = Config.fromfile(self.config)
self.dataset_path = dataset_path
self.lr = None
self.backbonedict = {
'MobileNet': os.path.join(self.file_dirname, 'models', 'MobileNet/MobileNet.py'),
'ResNet18': os.path.join(self.file_dirname, 'models', 'ResNet18/ResNet18.py'),
'ResNet50': os.path.join(self.file_dirname, 'models', 'ResNet50/ResNet50.py'),
'LeNet': os.path.join(self.file_dirname, 'models', 'LeNet/LeNet.py'),
# other backbones omitted
}
self.num_classes = num_classes
self.chinese_res = None
self.is_sample = False
self.image_type = ""
def train(self, random_seed=0, save_fold=None, distributed=False, validate=True, device="cpu",
metric='accuracy', save_best='auto', optimizer="SGD", epochs=100, lr=0.01, weight_decay=0.001,
checkpoint=None,batch_size=None,**kwargs):
if len(kwargs) != 0:
info = "Error Code: -501. No such parameter: " + next(iter(kwargs.keys()))
raise Exception(info)
if device not in ['cpu','cuda']:
info = "Error Code: -301. No such argument: "+ device
raise Exception(info)
is_cuda = torch.cuda.is_available()
if device == 'cpu' and is_cuda:
print("You can use 'device=cuda' to accelerate !")
elif device == 'cuda' and not is_cuda:
raise Exception("Error Code: -301. Your device doesn't support cuda.")
if validate not in [True, False]:
info = "Error Code: -303. No such argument: " + str(validate)
raise Exception(info)
if checkpoint is not None and checkpoint.split(".")[-1] != 'pth':
info = "Error Code: -202. Checkpoint file type error: " + checkpoint
raise Exception(info)
set_random_seed(seed=random_seed)
# load the model config
if self.backbone.split('.')[-1] == 'py':
self.cfg = Config.fromfile(self.backbone)
else:
self.cfg = Config.fromfile(self.backbonedict[self.backbone])
# if save_fold has not been set on the instance yet
if not self.save_fold:
# fall back to the default path when no save_fold is passed in
if not save_fold:
self.save_fold = os.path.join(self.cwd, 'checkpoints/cls_model')
# otherwise use the value passed in
else:
self.save_fold = save_fold
if self.num_classes != -1:
if 'num_classes' in self.cfg.model.backbone.keys():
self.cfg.model.backbone.num_classes = self.num_classes
else:
self.cfg.model.head.num_classes = self.num_classes
self.load_dataset(self.dataset_path)
if 'val_set' not in os.listdir(self.dataset_path):
print("Unable to validate during training due to lack of validation set !")
datasets = None
try:
datasets = [build_dataset(self.cfg.data.train)]
except FileNotFoundError as err:
if not os.path.exists(self.dataset_path):
info = "Error Code: -101. No such dataset directory:" + self.dataset_path
else:
err = str(err).split(":")[-1]
info = "Error Code: -201. Dataset file type error. No such file:"+ err
raise Exception(info)
self.cfg.work_dir = self.save_fold
# create the work directory
mmcv.mkdir_or_exist(os.path.abspath(self.cfg.work_dir))
# build the classifier
model = build_classifier(self.cfg.model)
if not checkpoint:
model.init_weights()
else:
try:
load_checkpoint(model, checkpoint, map_location=torch.device(device))
# model = init_model(self.cfg, checkpoint)
except FileNotFoundError:
info = "Error Code: -102. No such checkpoint file:" + checkpoint
raise Exception(info)
# attach the class names to the model for easier visualization
model.CLASSES = datasets[0].CLASSES
n_class = len(model.CLASSES)
# report top-1 accuracy when there are only a few classes, top-5 otherwise
if n_class <= 5:
self.cfg.evaluation.metric_options = {'topk': (1,)}
else:
self.cfg.evaluation.metric_options = {'topk': (5,)}
if optimizer == 'Adam':
self.cfg.optimizer = dict(type='Adam', lr=lr,betas=(0.9, 0.999),eps=1e-8, weight_decay=0.0001)
elif optimizer == 'Adagrad':
self.cfg.optimizer = dict(type='Adagrad',lr=lr, lr_decay=0)
# update the config with the passed-in hyperparameters
self.cfg.optimizer.lr = lr  # learning rate
self.cfg.optimizer.type = optimizer  # optimizer type
self.cfg.optimizer.weight_decay = weight_decay  # weight decay of the optimizer
self.cfg.evaluation.metric = metric  # evaluation metric
self.cfg.evaluation.save_best = save_best  # keep the best checkpoint
self.cfg.runner.max_epochs = epochs  # maximum number of training epochs
# log every 10 training iterations
self.cfg.log_config.interval = 10
self.cfg.gpu_ids = range(1)
self.cfg.seed = random_seed
self.cfg.device = device
if batch_size is not None:
self.cfg.data.samples_per_gpu = batch_size
train_model(
model,
datasets,
self.cfg,
distributed=distributed,
validate=validate,
timestamp=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
device=device,
meta=dict()
)
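# A minimal training sketch (illustrative only; the dataset path and class count
# below are placeholders and assume the folder layout expected by load_dataset()):
#
#     model = MMClassification(backbone='LeNet', num_classes=3)
#     model.load_dataset(path='../dataset/cls/hand_gray')
#     model.save_fold = 'checkpoints/cls_model/hand_gray'
#     model.train(epochs=10, lr=0.01, validate=True, device='cpu')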
def print_result(self, res=None):
if self.is_sample:
print("示例分类结果如下:")  # sample classification result:
sample_result = r"[{'标签': 2, '置信度': 1.0, '预测结果': 'scissors'}]"
print(sample_result)
else:
print("分类结果如下:")  # classification result:
print(self.chinese_res)
return self.chinese_res
def load_checkpoint(self,
checkpoint=None,
device='cpu',
**kwargs,
):
if len(kwargs) != 0:
info = "Error Code: -501. No such parameter: "+ next(iter(kwargs.keys()))
raise Exception(info)
if device not in ['cpu','cuda']:
info = "Error Code: -301. No such argument: "+ device
raise Exception(info)
is_cuda = torch.cuda.is_available()
if device == 'cpu' and is_cuda:
print("You can use 'device=cuda' to accelerate !")
elif device == 'cuda' and not is_cuda:
raise Exception("Error Code: -301. Your device doesn't support cuda.")
if checkpoint is not None and checkpoint.split(".")[-1] != 'pth':
info = "Error Code: -202. Checkpoint file type error: " + checkpoint
raise Exception(info)
# if not checkpoint:
# checkpoint = os.path.join(self.cwd, 'checkpoints/cls_model/hand_gray/latest.pth')
self.device = device
# read the class names stored in the checkpoint metadata
classed_name = torch.load(checkpoint, map_location='cpu')['meta']['CLASSES']
# classed_name = self.get_class(class_path)
# self.class_path = class_path
self.num_classes = len(classed_name)
if self.num_classes != -1:
if 'num_classes' in self.cfg.model.backbone.keys():
self.cfg.model.backbone.num_classes = self.num_classes
else:
self.cfg.model.head.num_classes = self.num_classes
checkpoint = os.path.abspath(checkpoint)  # resolve to an absolute path (needed when installed via pip)
self.checkpoint = checkpoint
try:
self.infer_model = init_model(self.cfg, checkpoint, device=device)
except FileNotFoundError:
info = "Error Code: -102. No such checkpoint file:"+ checkpoint
raise Exception(info)
self.infer_model.CLASSES = classed_name
def inference(self, device='cpu',
checkpoint=None,
image=None,
show=True,
save_fold='cls_result',
**kwargs,
):
if len(kwargs) != 0:
info = "Error Code: -501. No such parameter: " + next(iter(kwargs.keys()))
raise Exception(info)
if device not in ['cpu','cuda']:
info = "Error Code: -301. No such argument: "+ device
raise Exception(info)
is_cuda = torch.cuda.is_available()
if device == 'cpu' and is_cuda:
print("You can use 'device=cuda' to accelerate !")
elif device == 'cuda' and not is_cuda:
raise Exception("Error Code: -301. Your device doesn't support cuda.")
if image is None:  # no image passed in: return a sample result
self.is_sample = True
sample_return = """
{'pred_label': 2, 'pred_score': 0.9930743, 'pred_class': 'scissors'}
"""
return sample_return
self.is_sample = False
# if not isinstance(image, (str, np.ndarray)):
# if not isinstance(image, str):  # only a str image path used to be supported here
# info = "Error Code: -304. No such argument:" + image + "which is" + type(image)
# raise Exception(info)
# validate path inputs; ndarray and PIL inputs skip these file checks
if isinstance(image, str):
if not os.path.exists(image):
info = "Error Code: -103. No such file: " + image
raise Exception(info)
if os.path.isfile(image) and image.split(".")[-1].lower() not in ["png", "jpg", "jpeg", "bmp"]:
info = "Error Code: -203. File type error: " + image
raise Exception(info)
if checkpoint is not None and checkpoint.split(".")[-1] != 'pth':
info = "Error Code: -202. Checkpoint file type error: " + checkpoint
raise Exception(info)
if not checkpoint:
checkpoint = os.path.join(self.cwd, 'checkpoints/cls_model/hand_gray/latest.pth')
checkpoint = os.path.abspath(checkpoint)  # resolve to an absolute path (needed when installed via pip)
self.load_checkpoint(device= device, checkpoint=os.path.abspath(checkpoint))
return self.fast_inference(image=image, show=show,save_fold=save_fold, **kwargs)
def fast_inference(self, image, show=False, save_fold='cls_result',**kwargs):
if len(kwargs) != 0:
info = "Error Code: -501. No such parameter: " + next(iter(kwargs.keys()))
raise Exception(info)
# img_array = mmcv.imread(image, flag='color')
try:
self.infer_model
except AttributeError:
# load_checkpoint() must be called before inference
print("请先使用load_checkpoint()方法加载权重!")
return
print("========= begin inference ==========")
classed_name = self.infer_model.CLASSES
self.num_classes = len(classed_name)
results = []
dataset_path = os.getcwd()
if isinstance(image, PIL.PngImagePlugin.PngImageFile):  # image passed in as a PIL object
self.image_type = "pil"
image = np.array(image)
if isinstance(image, np.ndarray):
# the image is already decoded, so drop LoadImageFromFile from the test pipeline
if self.cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
self.cfg.data.test.pipeline.pop(0)
if self.image_type != "pil": self.image_type = "numpy"
print("{} image".format(self.image_type))
if self.backbone != "LeNet": # 单张图片 其他网络
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
img_array = mmcv.imread(image, flag='color')
result = inference_model(self.infer_model, img_array) # 此处的model和外面的无关,纯局部变量
else:
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
img_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
data = dict(img=img_gray)
test_pipeline = Compose(self.cfg.data.test.pipeline)
data = test_pipeline(data)
# data = scatter([data], [self.device])[0]
# wrap the single preprocessed sample (duplicated) in a minimal dataloader
data_loader = build_dataloader(
[data, data],
samples_per_gpu=self.cfg.data.samples_per_gpu,
workers_per_gpu=self.cfg.data.workers_per_gpu,
shuffle=False,
round_up=True)
result = self.batch_infer(self.infer_model, data_loader)
pred_idx = int(np.argmax(result[0]))
pred_class = classed_name[pred_idx].rstrip("\n")
result = {
'pred_label': pred_idx,
'pred_score': result[0][pred_idx],
'pred_class': pred_class,
}
self.infer_model.show_result(image, result, show=show, out_file=os.path.join(save_fold, "{}img.jpg".format(self.image_type)))
chinese_res = []
tmp = {}
if isinstance(result['pred_label'], np.int64):
result['pred_label'] = int(result['pred_label'])
if isinstance(result['pred_score'], np.float32):
result['pred_score'] = float(result['pred_score'])
tmp['标签'] = result['pred_label']
tmp['置信度'] = result['pred_score']
tmp['预测结果'] = result['pred_class']
chinese_res.append(tmp)
self.chinese_res = chinese_res
elif os.path.isfile(image):  # image passed in as a file path
if self.backbone != "LeNet":  # single image, non-LeNet backbone
img_array = mmcv.imread(image, flag='color')
result = inference_model(self.infer_model, img_array)
else:  # single image, LeNet
imagename = os.path.basename(image)
# write a temporary annotation file for the dataloader; a dummy "no.png"
# entry accompanies the real image (both are copied into cache/ below)
with open("test.txt", 'w') as f:
f.write(imagename + " 1\n")
f.write("no.png 0")
if not os.path.exists("cache"):
os.mkdir('cache')
import shutil
if not os.path.exists(os.path.join("cache", imagename)):
shutil.copyfile(image, os.path.join("cache", imagename))
shutil.copyfile(image, os.path.join("cache", "no.png"))
self.cfg.data.test.data_prefix = os.path.join(dataset_path, 'cache')
self.cfg.data.test.ann_file = os.path.join(dataset_path, 'test.txt')
# self.cfg.data.test.classes = os.path.abspath(self.class_path)
dataset = build_dataset(self.cfg.data.test)
# the extra round_up data will be removed during gpu/cpu collect
data_loader = build_dataloader(
dataset,
samples_per_gpu=self.cfg.data.samples_per_gpu,
workers_per_gpu=self.cfg.data.workers_per_gpu,
shuffle=False,
round_up=True)
result = self.batch_infer(self.infer_model, data_loader)
os.remove("test.txt")
shutil.rmtree("cache")
pred_idx = int(np.argmax(result[0]))
pred_class = classed_name[pred_idx].rstrip("\n")
result = {
'pred_label': pred_idx,
'pred_score': result[0][pred_idx],
'pred_class': pred_class,
}
self.infer_model.show_result(image, result, show=show, out_file=os.path.join(save_fold, os.path.split(image)[1]))
chinese_res = []
tmp = {}
if isinstance(result['pred_label'], np.int64):
result['pred_label'] = int(result['pred_label'])
if isinstance(result['pred_score'], np.float32):
result['pred_score'] = float(result['pred_score'])
tmp['标签'] = result['pred_label']
tmp['置信度'] = result['pred_score']
tmp['预测结果'] = result['pred_class']
# img.append(tmp)
chinese_res.append(tmp)
# print(chinese_res)
self.chinese_res = chinese_res
print("\n========= finish inference ==========")
return result
else:  # a directory of images was passed in
if self.backbone != "LeNet":  # directory, non-LeNet backbone
f = open("test.txt", 'w')
for image_name in os.listdir(image):
f.write(image_name)
f.write(" 1")
f.write('\n')
f.close()
self.cfg.data.test.data_prefix = image
self.cfg.data.test.ann_file = os.path.join(dataset_path, 'test.txt')
# self.cfg.data.test.classes = os.path.abspath(self.class_path)
dataset = build_dataset(self.cfg.data.test)
os.remove("test.txt")
# the extra round_up data will be removed during gpu/cpu collect
data_loader = build_dataloader(
dataset,
samples_per_gpu=self.cfg.data.samples_per_gpu,
workers_per_gpu=self.cfg.data.workers_per_gpu,
shuffle=False,
round_up=True)
else:  # directory, LeNet
dirname = os.path.basename(os.path.normpath(image))
import shutil
if os.path.exists(os.path.join(dataset_path, 'cache')):
shutil.rmtree(os.path.join(dataset_path, 'cache'))
os.mkdir(os.path.join(dataset_path, 'cache'))
shutil.copytree(image, os.path.join(dataset_path, 'cache', dirname))
# add empty placeholder folders so the number of sub-folders matches the number of classes
for i in range(len(classed_name) - 1):
dummy_folder = os.path.join(dataset_path, 'cache', 'dummy' + str(i))
os.mkdir(dummy_folder)
self.cfg.data.test.data_prefix = os.path.join(dataset_path, 'cache')
# self.cfg.data.test.classes = os.path.abspath(self.class_path)
dataset = build_dataset(self.cfg.data.test)
# the extra round_up data will be removed during gpu/cpu collect
data_loader = build_dataloader(
dataset,
samples_per_gpu=self.cfg.data.samples_per_gpu,
workers_per_gpu=self.cfg.data.workers_per_gpu,
shuffle=False,
round_up=True)
results_tmp = self.batch_infer(self.infer_model, data_loader)
if os.path.exists(os.path.join(dataset_path, 'cache')):
shutil.rmtree(os.path.join(dataset_path, 'cache'))
results = []
for i in range(len(results_tmp)):
pred_idx = int(np.argmax(results_tmp[i]))
pred_class = classed_name[pred_idx].rstrip("\n")
pred_label = pred_idx
pred_score = float(results_tmp[i][pred_idx])
tmp_result = {
'pred_label': pred_label,
'pred_score': pred_score,
'pred_class': pred_class,
}
results.append(tmp_result)
for i, img in enumerate(os.listdir(image)):
self.infer_model.show_result(os.path.join(image,img), results[i], out_file=os.path.join(save_fold, os.path.split(img)[1]))
# model.show_result(image, result, show=show, out_file=os.path.join(save_fold, os.path.split(image)[1]))
chinese_res = []
for i in range(len(results)):
tmp = {
'标签': results[i]['pred_label'],
'置信度': results[i]['pred_score'],
'预测结果': results[i]['pred_class']
}
# img.append(tmp)
chinese_res.append(tmp)
# print(chinese_res)
self.chinese_res = chinese_res
print("\n========= finish inference ==========")
return results
def load_dataset(self, path,**kwargs):
if len(kwargs) != 0:
info = "Error Code: -501. No such parameter: "+ next(iter(kwargs.keys()))
raise Exception(info)
self.dataset_path = path
if not isinstance(path, str):
info = "Error Code: -201. Dataset file type error, which should be <class 'str'> instead of " + str(type(path)) + "."
raise Exception(info)
if not os.path.exists(path):  # dataset directory does not exist
info = "Error Code: -101. No such dataset directory: " + path
raise Exception(info)
val_set = os.path.join(path, 'val_set')
val_txt = os.path.join(path, 'val.txt')
if os.path.exists(val_set) and os.path.exists(val_txt):
val_num = 0
for i in os.listdir(val_set):
val_num += len(os.listdir(os.path.join(val_set, i)))
with open(val_txt) as f:
if val_num != len(f.readlines()):
info = "Error Code: -201. Dataset file type error. The number of val set images does not match that in val.txt"
raise Exception(info)
self.cfg.img_norm_cfg = dict(
mean=[124.508, 116.050, 106.438],
std=[58.577, 57.310, 57.437],
to_rgb=True
)
# sanity check: make sure every image file in the dataset can be decoded
tot = 0
for root, dirs, files in os.walk(path):
for file in files:
if "txt" not in file:
impath = os.path.join(root, file)
# print("impath", impath)
if "jpg" in impath or "png" in impath:
# print(impath)
img = cv2.imread(impath)
try:
img.shape
except AttributeError:
info = "Error Code: -201. The image file {} is damaged.".format(impath)
raise Exception(info)
tot += 1
val_permit = True
if 'val.txt' not in os.listdir(path) and 'val_set' in os.listdir(path):
# print("生成val.txt")
val_permit, valtxt_path = self.generate_txt(path, "val")
# test_permit = True
# if 'test.txt' not in os.listdir(path) and 'test_set' in os.listdir(path):
# print("生成test.txt")
# test_permit, testtxt_path = self.generate_txt(path, "test")
class_permit = True
if 'classes.txt' not in os.listdir(path):
# generate classes.txt from the folder names under training_set
training_set = os.path.join(path, 'training_set')
content = sorted(os.listdir(training_set))
content = [i+"\n" for i in content]
try:
classestxt = open(os.path.join(path, "classes.txt"), mode='w')
except OSError:
# the dataset directory is not writable, fall back to a local dataset_txt folder
class_permit = False
dataset_txt = "dataset_txt"
if not os.path.exists(dataset_txt):
os.mkdir(dataset_txt)
classtxt_path = os.path.join(dataset_txt, "classes.txt")
classestxt = open(classtxt_path, mode='w')
classestxt.writelines(content)
classestxt.close()
self.cfg.data.train.data_prefix = os.path.join(self.dataset_path, 'training_set')
# self.cfg.data.train.classes = os.path.join(self.dataset_path, 'classes.txt')
self.cfg.data.val.data_prefix = os.path.join(self.dataset_path, 'val_set')
if val_permit:
self.cfg.data.val.ann_file = os.path.join(self.dataset_path, 'val.txt')
else:
self.cfg.data.val.ann_file = valtxt_path
# self.cfg.data.val.classes = os.path.join(self.dataset_path, 'classes.txt')
self.cfg.data.test.data_prefix = os.path.join(self.dataset_path, 'test_set')
# if test_permit:
self.cfg.data.test.ann_file = os.path.join(self.dataset_path, 'test.txt')
# else:
# self.cfg.data.test.ann_file = testtxt_path
# self.cfg.data.test.classes = os.path.join(self.dataset_path, 'classes.txt')
if class_permit:
self.cfg.data.train.classes = os.path.join(self.dataset_path, 'classes.txt')
self.cfg.data.val.classes = os.path.join(self.dataset_path, 'classes.txt')
self.cfg.data.test.classes = os.path.join(self.dataset_path, 'classes.txt')
else:
self.cfg.data.train.classes = classtxt_path
self.cfg.data.val.classes = classtxt_path
self.cfg.data.test.classes = classtxt_path
def generate_txt(self, path, split):
permit = True
val_set = os.path.join(path, split + '_set')
txt_path = os.path.join(path, split + ".txt")
try:
valtxt = open(txt_path, mode='w')
except OSError:
# the dataset directory is not writable, fall back to a local dataset_txt folder
permit = False
dataset_txt = "dataset_txt"
if not os.path.exists(dataset_txt):
os.mkdir(dataset_txt)
txt_path = os.path.join(dataset_txt, split + ".txt")
valtxt = open(txt_path, mode='w')
content = []
for label, i in enumerate(sorted(os.listdir(val_set))):
for j in sorted(os.listdir(os.path.join(val_set,i))):
# print(os.path.join(i,j), label)
content.append("{} {}\n".format(os.path.join(i,j), label))
valtxt.writelines(content)
valtxt.close()
return permit, txt_path
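# For reference, load_dataset() and generate_txt() expect an ImageNet-style folder
# layout roughly like the sketch below (names are illustrative); each generated
# annotation line has the form "<class_folder>/<image_name> <label_index>":
#
#     dataset_root/
#         classes.txt                  # one class name per line (generated if missing)
#         val.txt / test.txt           # "<folder>/<image> <label>" pairs
#         training_set/<class_name>/*.jpg|*.png
#         val_set/<class_name>/*.jpg|*.png
#         test_set/<class_name>/*.jpg|*.png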
def get_class(self, class_path):
classes = []
with open(class_path, 'r') as f:
for name in f:
classes.append(name.strip('\n'))
return classes
def batch_infer(self, model, data_loader):
results_tmp = []
model.eval()
dataset = data_loader.dataset
# start the progress bar manually so nothing is printed before the first update
prog_bar = mmcv.ProgressBar(task_num=len(dataset), start=False)
from mmcv.utils.timer import Timer
prog_bar.file.flush()
prog_bar.timer = Timer()
for i, data in enumerate(data_loader):
if self.device == "cuda": data = scatter(data, [self.device])[0]
with torch.no_grad():
result = model(return_loss=False, **data)
results_tmp.extend(result)
batch_size = data['img'].size(0)
for _ in range(batch_size):
prog_bar.update()
return results_tmp
def convert(self, checkpoint=None, backend="ONNX", out_file="convert_model.onnx"):
if not (backend == "ONNX" or backend == 'onnx'):
print("Sorry, we only suport ONNX up to now.")
return
state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))
classes_list = state_dict['meta']['CLASSES']
self.num_classes = len(classes_list)
if self.backbone == 'LeNet':
from mmcls.models.backbones import LeNet5
from collections import OrderedDict
# same as LeNet5, but with a softmax appended so the exported model outputs probabilities
class LeNet5_SoftMax(LeNet5):
def forward(self, x):
x = self.features(x)
if self.num_classes > 0:
x = self.classifier(x.squeeze())
x = torch.softmax(x, dim=0)
return (x, )
model = LeNet5_SoftMax(num_classes=self.num_classes)
new_state_dict = OrderedDict()
for key in state_dict['state_dict']:
# strip the leading "backbone." prefix from the mmcls checkpoint keys
new_state_dict[key[9:]] = state_dict['state_dict'][key]
model.load_state_dict(new_state_dict)
dummy_input = torch.randn(1, 1, 32, 32)
try:
torch.onnx.export(model, dummy_input, out_file)
print(f'Successfully exported ONNX model: {out_file}')
except Exception:
print('Please use a checkpoint trained by MMEdu.')
else:
ashape = [224,224]
if len(ashape) == 1:
input_shape = (1, 3, ashape[0], ashape[0])
elif len(ashape) == 2:
input_shape = (
1,
3,
) + tuple(ashape)
else:
raise ValueError('invalid input shape')
self.cfg.model.pretrained = None
self.cfg.model.head.num_classes = self.num_classes
# build the model and load checkpoint
classifier = build_classifier(self.cfg.model)
if checkpoint:
load_checkpoint(classifier, checkpoint, map_location='cpu')
else:
load_checkpoint(classifier, self.checkpoint, map_location='cpu')
pytorch2onnx(
classifier,  # the classifier built above
input_shape,
output_file=out_file,
do_simplify=False,
verify=False)
with open(out_file.replace(".onnx", ".py"), "w+") as f:
gen0 = """
import onnxruntime as rt
import BaseData
import numpy as np
import cv2
tag =
"""
gen1 = """
sess = rt.InferenceSession('
"""
gen2 = """', None)
input_name = sess.get_inputs()[0].name
out_name = sess.get_outputs()[0].name
cap = cv2.VideoCapture(0)
ret_flag,Vshow = cap.read()
dt = BaseData.ImageData(Vshow, backbone="
"""
gen3 = """")
input_data = dt.to_tensor()
pred_onx = sess.run([out_name], {input_name: input_data})
ort_output = pred_onx[0]
idx = np.argmax(ort_output, axis=1)[0]
print('result:' + tag[idx])
"""
# if class_path != None:
gen = gen0.strip("\n") + str(classes_list)+ "\n" + gen1.strip("\n")+out_file+ gen2.strip("\n") + str(self.backbone) + gen3
# else:
# gen = gen0.strip("tag = \n") + "\n\n" + gen1.strip("\n")+out_file+ gen2.strip("\n") + str(self.backbone) + gen3.replace("tag[idx]", "idx")
f.write(gen)
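# A conversion sketch (illustrative only; the checkpoint path is a placeholder):
#
#     model = MMClassification(backbone='LeNet')
#     model.convert(checkpoint='checkpoints/cls_model/hand_gray/latest.pth',
#                   backend='ONNX', out_file='cls_model.onnx')
#
# Besides the .onnx file, convert() also writes a small companion script
# (out_file with a .py suffix) that runs the exported model with onnxruntime.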
# model deployment helpers
def _demo_mm_inputs(input_shape, num_classes):
"""Create a superset of inputs needed to run test or train batches.
Args:
input_shape (tuple):
input batch dimensions
num_classes (int):
number of semantic classes
"""
(N, C, H, W) = input_shape
rng = np.random.RandomState(0)
imgs = rng.rand(*input_shape)
gt_labels = rng.randint(
low=0, high=num_classes, size=(N, 1)).astype(np.uint8)
mm_inputs = {
'imgs': torch.FloatTensor(imgs).requires_grad_(True),
'gt_labels': torch.LongTensor(gt_labels),
}
return mm_inputs
def pytorch2onnx(model,
input_shape,
opset_version=11,
dynamic_export=False,
show=False,
output_file='tmp.onnx',
do_simplify=False,
verify=False):
"""Export Pytorch model to ONNX model and verify the outputs are same
between Pytorch and ONNX.
Args:
model (nn.Module): Pytorch model we want to export.
input_shape (tuple): Use this input shape to construct
the corresponding dummy input and execute the model.
opset_version (int): The onnx op version. Default: 11.
show (bool): Whether print the computation graph. Default: False.
output_file (string): The path to where we store the output ONNX model.
Default: `tmp.onnx`.
verify (bool): Whether compare the outputs between Pytorch and ONNX.
Default: False.
"""
from functools import partial
import onnxruntime as rt
from mmcv.onnx import register_extra_symbolics
model.cpu().eval()
if hasattr(model.head, 'num_classes'):
num_classes = model.head.num_classes
# Some backbones use `num_classes=-1` to disable top classifier.
elif getattr(model.backbone, 'num_classes', -1) > 0:
num_classes = model.backbone.num_classes
else:
raise AttributeError('Cannot find "num_classes" in both head and '
'backbone, please check the config file.')
mm_inputs = _demo_mm_inputs(input_shape, num_classes)
imgs = mm_inputs.pop('imgs')
img_list = [img[None, :] for img in imgs]
# replace original forward function
origin_forward = model.forward
model.forward = partial(model.forward, img_metas={}, return_loss=False)
register_extra_symbolics(opset_version)
# support dynamic shape export
if dynamic_export:
dynamic_axes = {
'input': {
0: 'batch',
2: 'width',
3: 'height'
},
'probs': {
0: 'batch'
}
}
else:
dynamic_axes = {}
with torch.no_grad():
torch.onnx.export(
model, (img_list, ),
output_file,
input_names=['input'],
output_names=['probs'],
export_params=True,
keep_initializers_as_inputs=True,
dynamic_axes=dynamic_axes,
verbose=show,
opset_version=opset_version)
print(f'Successfully exported ONNX model: {output_file}')
model.forward = origin_forward
if do_simplify:
import onnx
import onnxsim
from mmcv import digit_version
min_required_version = '0.4.0'
assert digit_version(onnxsim.__version__) >= digit_version(
min_required_version
), f'Requires to install onnxsim>={min_required_version}'
model_opt, check_ok = onnxsim.simplify(output_file)
if check_ok:
onnx.save(model_opt, output_file)
print(f'Successfully simplified ONNX model: {output_file}')
else:
print('Failed to simplify ONNX model.')
if verify:
# check by onnx
import onnx
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)
# test the dynamic model
if dynamic_export:
dynamic_test_inputs = _demo_mm_inputs(
(input_shape[0], input_shape[1], input_shape[2] * 2,
input_shape[3] * 2), model.head.num_classes)
imgs = dynamic_test_inputs.pop('imgs')
img_list = [img[None, :] for img in imgs]
# check the numerical value
# get pytorch output
pytorch_result = model(img_list, img_metas={}, return_loss=False)[0]
# get onnx output
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [
node.name for node in onnx_model.graph.initializer
]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 1)
sess = rt.InferenceSession(output_file)
onnx_result = sess.run(
None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
if not np.allclose(pytorch_result, onnx_result):
raise ValueError(
'The outputs are different between PyTorch and ONNX')
print('The outputs are the same between PyTorch and ONNX')
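if __name__ == "__main__":
    # Minimal end-to-end inference sketch (illustrative only): the checkpoint and
    # image paths below are placeholders and must point to real files.
    cls = MMClassification(backbone='LeNet')
    result = cls.inference(image='test_img.png',
                           checkpoint='checkpoints/cls_model/hand_gray/latest.pth',
                           device='cpu', show=False)
    cls.print_result()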