PyTorch 实现基于 AlexNet 的花分类
In [1]:
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
from PIL import Image
import os, json, sys
from tqdm import tqdm
In [2]:
class AlexNet(nn.Module):
    """AlexNet-style convolutional classifier using SELU activations.

    The model is a convolutional feature extractor followed by a
    three-layer fully connected classifier head.

    Args:
        num_classes: Number of logits produced by the last linear layer.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # Feature-extraction part (convolutional layers).
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
            nn.SELU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(48, 128, kernel_size=5, padding=2),
            nn.SELU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(128, 192, kernel_size=3, padding=1),
            nn.SELU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.SELU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),
            nn.SELU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # [128, 6, 6] for a 224x224 input
        )
        # Pin the feature map to 6x6 so the classifier's fixed 128 * 6 * 6
        # input also works for image sizes other than 224x224. The layer has
        # no parameters (state_dict keys are unchanged) and is the identity
        # when the map is already 6x6.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        # Classification part (fully connected layers).
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.4),
            nn.Linear(128 * 6 * 6, 2048),
            nn.SELU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(2048, 2048),
            nn.SELU(inplace=True),
            nn.Linear(2048, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: Image batch of shape ``(batch, 3, H, W)``.
        """
        x = self.features(x)
        x = self.avgpool(x)
        # Keep the batch dimension, flatten channels and spatial dims.
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x
In [3]:
# Root directory expected to contain the `train` and `val` image folders.
image_path = './花分类项目/flower_data'
assert os.path.exists(image_path), "image path does not exist"

# Preprocessing per split: random crop + horizontal flip for training,
# a fixed 224x224 resize for validation. Both finish with tensor conversion
# and per-channel normalization (mean 0.5, std 0.5).
data_transform = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]),
    val=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]),
)

# Create the training and validation datasets from their sub-directories.
train_dataset, validate_dataset = (
    datasets.ImageFolder(root=os.path.join(image_path, split),
                         transform=data_transform[split])
    for split in ('train', 'val')
)
train_num = len(train_dataset)

# Class name -> integer label mapping taken from the training dataset.
flower_list = train_dataset.class_to_idx
flower_list
Out[3]:
{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
In [4]:
# Number of samples per mini-batch.
batch_size = 32
# Invert the class-name -> index mapping so a predicted index can be turned
# back into its class name.
classify_dict = {index: name for name, index in flower_list.items()}
classify_dict
Out[4]:
{0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}
In [5]:
# Number of validation samples; used to turn the correct-prediction count
# into an accuracy.
val_num = len(validate_dataset)
# Use the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Training batches are reshuffled every epoch; validation order stays fixed.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=4)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=4)

# Size the output layer from the dataset's class mapping instead of a
# hard-coded 5, so the model stays consistent with the data on disk.
net = AlexNet(num_classes=len(flower_list))
In [6]:
# Training configuration.
epochs = 10
save_path = './AlexNet.pth'

net.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.0004)

best_acc = 0.0                  # best validation accuracy seen so far
train_step = len(train_loader)  # number of training batches per epoch

# Training process: one training pass and one validation pass per epoch.
for epoch in range(epochs):
    net.train()
    running_loss = 0.0
    train_bar = tqdm(train_loader, file=sys.stdout)
    for step, (images, labels) in enumerate(train_bar):
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_fn(outputs, labels.to(device))
        loss.backward()
        optimizer.step()
        # Accumulate the batch loss for the epoch average printed below.
        running_loss += loss.item()
        train_bar.desc = f'train epoch {epoch+1}/ {epochs} loss: {loss:.3f}'

    # Validation process: count correct predictions without tracking gradients.
    net.eval()
    acc = 0.0
    with torch.no_grad():
        val_bar = tqdm(validate_loader, file=sys.stdout)
        for val_images, val_labels in val_bar:
            outputs = net(val_images.to(device))
            # Index of the largest logit per sample is the predicted class.
            predict_y = torch.max(outputs, dim=1).indices
            acc += (predict_y == val_labels.to(device)).sum().item()

    val_accuracy = acc / val_num
    print(f'[epoch {epoch+1}] train_loss: {running_loss/train_step:.3f}, val_accuracy:{val_accuracy:.3f}')

    # Keep only the weights of the best-performing epoch on disk.
    if val_accuracy > best_acc:
        best_acc = val_accuracy
        torch.save(net.state_dict(), save_path)
train epoch 1/ 10 loss: 1.083: 100%|█████████████████████████████████████████████████| 104/104 [00:29<00:00, 3.56it/s] 100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:03<00:00, 3.27it/s] [epoch 1] train_loss: 1.498, val_accuracy:0.478 train epoch 2/ 10 loss: 0.851: 100%|█████████████████████████████████████████████████| 104/104 [00:28<00:00, 3.59it/s] 100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:03<00:00, 3.59it/s] [epoch 2] train_loss: 1.236, val_accuracy:0.505 train epoch 3/ 10 loss: 0.992: 100%|█████████████████████████████████████████████████| 104/104 [00:30<00:00, 3.45it/s] 100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:03<00:00, 3.25it/s] [epoch 3] train_loss: 1.159, val_accuracy:0.593 train epoch 4/ 10 loss: 0.681: 100%|█████████████████████████████████████████████████| 104/104 [00:28<00:00, 3.64it/s] 100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:03<00:00, 3.05it/s] [epoch 4] train_loss: 1.100, val_accuracy:0.569 train epoch 5/ 10 loss: 0.940: 100%|█████████████████████████████████████████████████| 104/104 [00:28<00:00, 3.67it/s] 100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:03<00:00, 3.38it/s] [epoch 5] train_loss: 1.120, val_accuracy:0.574 train epoch 6/ 10 loss: 0.985: 100%|█████████████████████████████████████████████████| 104/104 [00:28<00:00, 3.62it/s] 100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:03<00:00, 3.35it/s] [epoch 6] train_loss: 1.063, val_accuracy:0.560 train epoch 7/ 10 loss: 0.724: 100%|█████████████████████████████████████████████████| 104/104 [00:28<00:00, 3.69it/s] 100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:03<00:00, 3.30it/s] [epoch 7] train_loss: 1.072, 
val_accuracy:0.607 train epoch 8/ 10 loss: 1.221: 100%|█████████████████████████████████████████████████| 104/104 [00:28<00:00, 3.63it/s] 100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:04<00:00, 2.74it/s] [epoch 8] train_loss: 1.039, val_accuracy:0.563 train epoch 9/ 10 loss: 0.656: 100%|█████████████████████████████████████████████████| 104/104 [00:29<00:00, 3.57it/s] 100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:03<00:00, 3.36it/s] [epoch 9] train_loss: 1.018, val_accuracy:0.635 train epoch 10/ 10 loss: 0.481: 100%|████████████████████████████████████████████████| 104/104 [00:29<00:00, 3.54it/s] 100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:04<00:00, 2.96it/s] [epoch 10] train_loss: 1.021, val_accuracy:0.593
In [16]:
# Classify a single image with the weights saved during training and show
# the image with the predicted class and its probability as the title.
img_path = os.path.join(image_path, '3.jpeg')
assert os.path.exists(img_path), f'{img_path} does not exist'
# Force 3-channel RGB: the network and the Normalize transform expect three
# channels, so a grayscale / CMYK / RGBA file would otherwise fail. The
# `with` block closes the underlying file once the pixels are loaded.
with Image.open(img_path) as img_file:
    img = img_file.convert('RGB')
plt.imshow(img)
img = data_transform['val'](img)
# Add a leading batch dimension: [C, H, W] -> [1, C, H, W].
img = torch.unsqueeze(img, dim=0)

# Size the output layer from the class mapping rather than a hard-coded 5.
model = AlexNet(num_classes=len(classify_dict)).to(device)
assert os.path.exists(save_path), f'file {save_path} does not exist'
# map_location lets a checkpoint written on a GPU be loaded on a CPU-only
# machine (and vice versa).
model.load_state_dict(torch.load(save_path, map_location=device, weights_only=True))
model.eval()
with torch.no_grad():
    output = model(img.to(device))
    # Drop the batch dimension and move the logits to the CPU.
    output = torch.squeeze(output).cpu()
    predict = torch.softmax(output, dim=0)
    predict_class = torch.argmax(predict).numpy()

class_name = classify_dict[int(predict_class)]
print_res = f"class: {class_name}, prob: {predict[predict_class].numpy():.3f}"
plt.title(print_res)
plt.show()

In [ ]:

浙公网安备 33010602011771号