import os

import numpy as np
import torch

from Unet import UNET
# Hide every GPU from CUDA so the export runs on CPU only. The variable CUDA
# actually reads is CUDA_VISIBLE_DEVICES (plural); the singular spelling is
# ignored. torch initializes CUDA lazily, so setting it after `import torch`
# still takes effect as long as no CUDA call has been made yet.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
6
def main(model_path="/xxx.pth.tar", output="pathto/xxx.onnx"):
    """Export the checkpoint at *model_path* to ONNX at *output*, then validate it.

    The default paths are placeholders; replace them (or pass arguments)
    before running.
    """
    demo = Demo(model_path=model_path, output=output)
    demo.inference()
    # Validate the file that was just written rather than a separately
    # hard-coded path, which previously pointed somewhere else entirely.
    check_onnx(onnx_pth=demo.output_path)
11
12
13
def check_onnx(onnx_pth):
    """Load the ONNX file at *onnx_pth*, validate it, and print its graph.

    ``onnx.checker.check_model`` raises if the model's IR is not well formed;
    on success a human-readable dump of the graph is printed to stdout.
    """
    # Kept local so importing this module does not require the onnx package.
    import onnx

    loaded = onnx.load(onnx_pth)
    onnx.checker.check_model(loaded)
    readable = onnx.helper.printable_graph(loaded.graph)
    print(readable)
23
class WrappedModel(torch.nn.Module):
    """Wrap a network so its raw outputs are passed through ``torch.sigmoid``.

    Exporting the wrapper (instead of the bare network) puts the sigmoid into
    the traced ONNX graph, so consumers receive values in (0, 1) directly.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        logits = self.model(x)
        return torch.sigmoid(logits)
34
35
class Demo:
    """Load a trained UNET checkpoint and export it, with a sigmoid head, to ONNX."""

    def __init__(self, model_path, output):
        """
        Args:
            model_path: checkpoint file to load with ``torch.load``. Either a
                dict with a ``"state_dict"`` entry (the original author notes
                it also holds ``"optimizer"``) or a bare state dict.
            output: destination path of the exported ``.onnx`` file.
        """
        self.model_path = model_path
        self.output_path = output
        # Default so helpers work even if init_torch_tensor() was not called first.
        self.device = 'cpu'

    def init_torch_tensor(self):
        """Select the CPU as the export device and make float32 the default dtype."""
        self.device = 'cpu'
        # Replaces the deprecated torch.set_default_tensor_type('torch.FloatTensor').
        torch.set_default_dtype(torch.float32)

    def init_model(self, in_channels, out_channels):
        """Build a freshly initialized UNET on ``self.device``."""
        return UNET(in_channels=in_channels, out_channels=out_channels).to(self.device)

    @staticmethod
    def _extract_state_dict(states):
        """Return the model weights from a loaded checkpoint object.

        Accepts a training checkpoint dict holding a ``"state_dict"`` key, or a
        bare state dict (e.g. saved via ``torch.save(model.state_dict())``).
        """
        if isinstance(states, dict) and "state_dict" in states:
            return states["state_dict"]
        return states

    def resume(self, model, path):
        """Load weights from *path* into *model* and wrap it with a sigmoid head.

        Returns:
            A ``WrappedModel`` around *model*.

        Raises:
            FileNotFoundError: if *path* does not exist. Failing here is
                deliberate: returning None would only crash later in
                ``inference`` with an unrelated AttributeError.
        """
        if not os.path.exists(path):
            raise FileNotFoundError("Checkpoint not found: " + path)
        states = torch.load(path, map_location=self.device)
        # strict=False tolerates key mismatches, but a silent mismatch would
        # export randomly initialized layers, so report anything that did not match.
        result = model.load_state_dict(self._extract_state_dict(states), strict=False)
        if result.missing_keys:
            print("Warning: keys missing from checkpoint:", result.missing_keys)
        if result.unexpected_keys:
            print("Warning: unexpected keys in checkpoint:", result.unexpected_keys)

        model_sig = WrappedModel(model)
        print("Resume from " + path)
        return model_sig

    def _dummy_input(self, height=512, width=512, channels=3):
        """Build a random (1, channels, height, width) float32 tensor with values in [0, 1].

        Used only as the example input for tracing. If the model was trained
        with another normalization (e.g. (x / 255 - 0.5) / 0.5), the values
        presumably do not affect the exported graph — confirm for this UNET.
        """
        img = np.random.randint(0, 256, size=(height, width, channels), dtype=np.uint8)
        img = img.astype(np.float32) / 255.0
        img = img.transpose((2, 0, 1))  # HWC -> CHW
        return torch.from_numpy(img).unsqueeze(0).float().to(self.device)

    def inference(self):
        """Load the checkpoint and export the sigmoid-wrapped model to ``self.output_path``.

        Batch, height and width are declared as dynamic axes, so the ONNX model
        is not fixed to the 1x3x512x512 tracing example (the network itself may
        still constrain which spatial sizes work).
        """
        self.init_torch_tensor()
        model = self.init_model(in_channels=3, out_channels=2)
        model_sig = self.resume(model, self.model_path)
        # Switch layers such as dropout / batch-norm to inference behaviour before tracing.
        model_sig.eval()
        img = self._dummy_input()
        # dynamic_axes: each key must be a name from input_names/output_names;
        # each value maps a dimension index to the symbolic name for that dimension.
        dynamic_axes = {
            'input': {0: 'batch_size', 2: 'height', 3: 'width'},
            'output': {0: 'batch_size', 2: 'height', 3: 'width'},
        }
        with torch.no_grad():
            torch.onnx.export(
                model_sig,
                img,
                self.output_path,
                input_names=['input'],
                output_names=['output'],
                dynamic_axes=dynamic_axes,
                keep_initializers_as_inputs=False,
                export_params=True,
                verbose=True,
                opset_version=11,
            )
94
if __name__ == "__main__":
    # Run the export and validation only when executed as a script, not on import.
    main()