PyTorch图像分类实现工业缺陷检测
In [1]:
import os, sys
import torch, torchvision
import numpy as np
from torch.utils.data import Dataset, DataLoader
import cv2 as cv
import matplotlib.pyplot as plt
from torchvision import transforms, datasets
from tqdm import tqdm
In [2]:
# Defect class labels. Image file names start with one of these labels
# followed by an underscore, and a label's position in this list is the
# integer class index used for training and prediction.
# NOTE(review): the PS / RS descriptions below are translated from the
# original comments; confirm them against the dataset's own documentation.
defect_labels = [
    'In',  # inclusion
    'Sc',  # scratch
    'Cr',  # crack (crackle)
    'PS',  # pressed-in oxide scale
    'RS',  # pitting
    'Pa',  # spots / patches
]
num_classes = len(defect_labels)
In [3]:
class SurfaceDefectDataset(Dataset):
    """Surface-defect image dataset backed by a flat directory of image files.

    Every entry of ``root_dir`` must be named ``<label>_<rest>``, where
    ``<label>`` is one of the strings in the module-level ``defect_labels``
    list. The position of the label in that list is the class index.

    Each item is a dict with two keys:
        'image':  the image after ToTensor -> Normalize -> Resize((200, 200)).
        'defect': the integer class index of the image.
    """

    def __init__(self, root_dir):
        """Index the files found directly inside ``root_dir``.

        Args:
            root_dir: Directory that contains the image files.

        Raises:
            ValueError: If a file name does not start with a known label.
        """
        # The mean/std values are presumably the ImageNet channel statistics
        # that go with the pretrained ResNet weights -- TODO confirm.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
            transforms.Resize((200, 200)),
        ])
        self.defect_types = []
        self.images = []
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible between runs and machines.
        for file_name in sorted(os.listdir(root_dir)):
            # The label is the part of the file name before the first underscore.
            defect_class = file_name.split('_')[0]
            if defect_class not in defect_labels:
                raise ValueError(
                    f'unknown defect label {defect_class!r} in file name '
                    f'{file_name!r}; expected one of {defect_labels}')
            self.images.append(os.path.join(root_dir, file_name))
            self.defect_types.append(defect_labels.index(defect_class))

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, convert and transform the image at position ``idx``.

        Returns:
            dict with keys 'image' (transformed image) and 'defect' (class index).

        Raises:
            OSError: If OpenCV cannot read the image file.
        """
        image_path = self.images[idx]
        img = cv.imread(image_path)  # OpenCV decodes to BGR channel order
        if img is None:
            # cv.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise OSError(f'failed to read image: {image_path}')
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
        sample = {'image': self.transform(img), 'defect': self.defect_types[idx]}
        return sample
In [4]:
# Run on the first CUDA GPU when PyTorch reports one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
In [5]:
class SurfaceDectectResNet(torch.nn.Module):
    """ResNet-18 classifier whose final layer outputs ``num_classes`` scores."""

    def __init__(self, num_classes=1000, weights='DEFAULT'):
        """Build the ResNet-18 backbone and replace its final linear layer.

        Args:
            num_classes: Number of output classes of the replaced ``fc`` layer.
            weights: Passed straight to ``torchvision.models.resnet18``. The
                default ``'DEFAULT'`` keeps the previous behaviour of starting
                from torchvision's pretrained weights. Pass ``None`` to skip
                loading them, e.g. when a saved ``state_dict`` is loaded
                right after construction.
        """
        super().__init__()
        # The attribute name is part of every key in the saved state_dict,
        # so renaming it would break existing checkpoints.
        self.cnn_layers = torchvision.models.resnet18(weights=weights)
        in_features = self.cnn_layers.fc.in_features
        self.cnn_layers.fc = torch.nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return the backbone's output for the input batch ``x``."""
        return self.cnn_layers(x)
In [6]:
# Training split: images are read from the 'train' folder.
train_dataset = SurfaceDefectDataset('./enu_surface_defect/train')
train_num = len(train_dataset)

# Lookup table from class index to label string.
classify_dict = dict(enumerate(defect_labels))

batch_size = 32

# Wrap the training dataset in a DataLoader that yields shuffled mini-batches.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation split: images are read from the 'test' folder.
validate_dataset = SurfaceDefectDataset('./enu_surface_defect/test')
val_num = len(validate_dataset)
validate_loader = DataLoader(validate_dataset, batch_size=batch_size, shuffle=True)
# Network, moved to the selected device before the optimizer is created.
net = SurfaceDectectResNet(num_classes=num_classes)
net.to(device)

# Loss function and optimizer.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

# Training schedule and checkpoint bookkeeping.
epochs = 8
save_path = './model.pth'
best_acc = 0.0
# Number of mini-batches per epoch; used to average the training loss.
train_steps = len(train_loader)
# Training: each epoch runs one optimisation pass over the training loader,
# then one evaluation pass over the validation loader.
for epoch in range(epochs):
    net.train()
    running_loss = 0.0
    train_bar = tqdm(train_loader, file=sys.stdout)
    for batch in train_bar:
        inputs = batch['image'].to(device)
        targets = batch['defect'].to(device)
        optimizer.zero_grad()
        loss = loss_fn(net(inputs), targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Show the current batch loss next to the progress bar.
        train_bar.desc = f'train epoch[{epoch+1}/{epochs}] loss:{loss:.3f}'

    # Validation: count correct top-1 predictions with gradients disabled.
    net.eval()
    correct = 0.0
    with torch.no_grad():
        for val_batch in tqdm(validate_loader, file=sys.stdout):
            logits = net(val_batch['image'].to(device))
            predicted = torch.max(logits, dim=1).indices
            expected = val_batch['defect'].to(device)
            correct += (predicted == expected).sum().item()
    val_accuracy = correct / val_num
    print(f'[epoch {epoch+1} train_loss: {running_loss/train_steps:.3f},'
          f'val_accuracy:{val_accuracy:.3f}')

    # Keep only the weights of the best epoch so far (strict improvement).
    if val_accuracy > best_acc:
        best_acc = val_accuracy
        torch.save(net.state_dict(), save_path)
train epoch[1/8] loss:0.624: 100%|█████████████████████████████████████████████████████| 56/56 [00:42<00:00, 1.30it/s] 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 5.95it/s] [epoch 1 train_loss: 0.182,val_accuracy:1.000 train epoch[2/8] loss:0.039: 100%|█████████████████████████████████████████████████████| 56/56 [00:42<00:00, 1.30it/s] 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6.41it/s] [epoch 2 train_loss: 0.022,val_accuracy:1.000 train epoch[3/8] loss:0.717: 100%|█████████████████████████████████████████████████████| 56/56 [00:42<00:00, 1.32it/s] 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6.02it/s] [epoch 3 train_loss: 0.025,val_accuracy:1.000 train epoch[4/8] loss:0.029: 100%|█████████████████████████████████████████████████████| 56/56 [00:44<00:00, 1.25it/s] 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6.27it/s] [epoch 4 train_loss: 0.028,val_accuracy:1.000 train epoch[5/8] loss:0.061: 100%|█████████████████████████████████████████████████████| 56/56 [00:42<00:00, 1.31it/s] 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6.13it/s] [epoch 5 train_loss: 0.007,val_accuracy:1.000 train epoch[6/8] loss:0.220: 100%|█████████████████████████████████████████████████████| 56/56 [00:43<00:00, 1.30it/s] 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6.21it/s] [epoch 6 train_loss: 0.008,val_accuracy:1.000 train epoch[7/8] loss:0.015: 100%|█████████████████████████████████████████████████████| 56/56 [00:43<00:00, 1.29it/s] 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6.15it/s] [epoch 7 train_loss: 0.009,val_accuracy:1.000 train 
epoch[8/8] loss:0.002: 100%|█████████████████████████████████████████████████████| 56/56 [00:42<00:00, 1.33it/s] 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6.04it/s] [epoch 8 train_loss: 0.005,val_accuracy:1.000
In [7]:
# Single-image inference with the checkpoint written by the training loop.

# Same preprocessing pipeline as SurfaceDefectDataset.
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((200, 200)),
])

img_path = './enu_surface_defect/PS_185.bmp'
img = cv.imread(img_path)  # OpenCV decodes to BGR channel order
if img is None:
    # cv.imread returns None instead of raising when the file cannot be read.
    raise OSError(f'failed to read image: {img_path}')
img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
plt.imshow(img)

img = data_transform(img)
# Add a leading batch dimension of size 1.
img = torch.unsqueeze(img, dim=0)

model = SurfaceDectectResNet(num_classes=num_classes).to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
model.load_state_dict(torch.load(save_path, map_location=device, weights_only=True))
model.eval()
with torch.no_grad():
    output = model(img.to(device))
    # Drop the batch dimension and move the scores to the CPU.
    output = torch.squeeze(output).cpu()
    predict = torch.softmax(output, dim=0)
    predict_class = int(torch.argmax(predict).numpy())

# Show the predicted label and its softmax probability as the plot title.
print_res = f'class: {classify_dict[predict_class]}, prob:{predict[predict_class].numpy()}'
plt.title(print_res)
plt.show()


浙公网安备 33010602011771号