1 #coding= utf-8
2 import os
3 import torch
4 from data_pipe import get_data
5 from model import SimpleNet
6 import numpy as np
7 import cv2
8 from PIL import Image
9
10
class Infer(object):
    """Binary image classifier wrapping a trained ``SimpleNet``.

    The network weights are loaded once at construction. Each image is
    classified by thresholding the model's single output value to 0 or 1.
    """

    # (height, width) the network is fed. The original code reshaped every
    # image to (1, 3, 32, 32), so this is the size the model expects.
    INPUT_SIZE = (32, 32)

    def __init__(self, model_path="./models/model_10.pth"):
        """Build the network and load its trained weights.

        Args:
            model_path: Path to a ``state_dict`` checkpoint saved with
                ``torch.save``. Defaults to the previously hard-coded path.
        """
        self.model = SimpleNet()
        # map_location="cpu" lets a checkpoint saved from a GPU load on a
        # CPU-only machine; inference below uses CPU tensors anyway.
        state_dict = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def _infer(self, img_tensor, threshold=0.5):
        """Run the model on one image and return 1 or 0.

        Args:
            img_tensor: A batch holding exactly one image, (1, 3, H, W).
            threshold: Decision boundary; the comparison is strict (``>``).

        Returns:
            1 if the model output is greater than ``threshold``, else 0.
        """
        with torch.no_grad():
            output = self.model(img_tensor)
        # .item() makes the single-output assumption explicit (it raises if
        # the model returns more than one value). NOTE(review): presumably the
        # output is a probability (e.g. sigmoid) - confirm against SimpleNet.
        return 1 if output.item() > threshold else 0

    @staticmethod
    def _to_tensor(img):
        """Convert an HxWx3 uint8 array to a (1, 3, H, W) float tensor in [0, 1]."""
        tensor = torch.from_numpy(np.ascontiguousarray(img))
        # HWC -> CHW, scale to [0, 1], then add the batch dimension. Unlike a
        # hard-coded reshape, unsqueeze never reinterprets pixel layout.
        return (tensor.permute(2, 0, 1).float() / 255.0).unsqueeze(0)

    def _load_image(self, img_path):
        """Read ``img_path`` as an RGB array of size INPUT_SIZE, or None if unreadable."""
        img = cv2.imread(img_path)  # 3-channel BGR, or None when decoding fails
        if img is None:
            return None
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        height, width = self.INPUT_SIZE
        if img.shape[:2] != (height, width):
            # cv2.resize takes (width, height). NOTE(review): assumes training
            # images were brought to the same size - confirm against data_pipe.
            img = cv2.resize(img, (width, height))
        return img

    def predict(self, path):
        """Classify every readable image directly inside directory ``path``.

        For each entry, prints its path and then its 0/1 prediction. Entries
        that cv2 cannot decode (subdirectories, non-image files) are reported
        and skipped instead of aborting the whole run.

        Args:
            path: Directory containing the images.

        Returns:
            A list of ``(img_path, result)`` tuples in sorted filename order.
        """
        results = []
        # os.listdir order is arbitrary; sort for reproducible output.
        for name in sorted(os.listdir(path)):
            img_path = os.path.join(path, name)
            print(img_path)
            img = self._load_image(img_path)
            if img is None:
                print("skipped: not a readable image")
                continue
            result = self._infer(self._to_tensor(img))
            print(result)
            results.append((img_path, result))
        return results
37
38
if __name__ == "__main__":
    # Script entry point: classify every image in the default test folder.
    image_dir = "./test_images"
    predictor = Infer()
    predictor.predict(image_dir)