# 先贴代码，后续找时间对代码进行解释，并给出相应的测试代码。
# 这次使用的是 ResNet-50 在 ImageNet-1k（1000 类）上的预训练模型。
import os
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset
from torchvision import transforms, datasets
from matplotlib import pyplot as plt
from tqdm import tqdm
def plot_train_history(epochs_history, train_acc_history, val_acc_history, train_loss_history, val_loss_history):
    """Show accuracy (left) and loss (right) curves for training vs. validation.

    Args:
        epochs_history: x-axis values shared by every curve (one per epoch).
        train_acc_history: per-epoch training accuracy.
        val_acc_history: per-epoch validation accuracy.
        train_loss_history: per-epoch training loss.
        val_loss_history: per-epoch validation loss.
    """
    # (subplot index, train label, val label, title, legend corner, train ys, val ys)
    panels = (
        (1, 'Training Accuracy', 'Validation Accuracy', 'Training and Validation Accuracy',
         'lower right', train_acc_history, val_acc_history),
        (2, 'Training Loss', 'Validation Loss', 'Training and Validation Loss',
         'upper right', train_loss_history, val_loss_history),
    )
    plt.figure(figsize=(10, 10))
    for position, train_label, val_label, title, corner, train_ys, val_ys in panels:
        plt.subplot(1, 2, position)
        plt.plot(epochs_history, train_ys, label=train_label)
        plt.plot(epochs_history, val_ys, label=val_label)
        plt.legend(loc=corner)
        plt.title(title)
    plt.show()
def _make_loader(root, transform, batch_size, num_workers, shuffle):
    """Build an ImageFolder dataset rooted at *root* and wrap it in a DataLoader."""
    dataset = datasets.ImageFolder(root=root, transform=transform)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                       num_workers=num_workers)


def _build_model(checkpoint_path, device):
    """Return a ResNet-50 with a 2-class head where only the head is trainable.

    Starts from torchvision's pretrained weights; if *checkpoint_path* exists, its
    weights are loaded on top (either a state_dict, or a whole pickled module as
    older versions of this script saved). The model is moved to *device*.
    """
    # NOTE(review): newer torchvision versions deprecate `pretrained=` in favour of
    # `weights=`; kept here for compatibility with older installs — confirm target version.
    model = torchvision.models.resnet50(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 2)
    # Freeze every top-level child except the new classifier head (`fc`), so only
    # the head is fine-tuned.
    for name, child in model.named_children():
        if name != 'fc':
            for param in child.parameters():
                param.requires_grad = False
    if os.path.exists(checkpoint_path):
        print("Previous model existed, loading model...")
        # map_location lets a checkpoint written on GPU be loaded on a CPU-only machine.
        state = torch.load(checkpoint_path, map_location=device)
        if isinstance(state, nn.Module):  # legacy checkpoint: the whole module was saved
            state = state.state_dict()
        model.load_state_dict(state)
    return model.to(device)


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over *loader*: train when *optimizer* is given, otherwise evaluate.

    Returns:
        (mean per-sample loss, accuracy as a fraction in [0, 1]).
    """
    training = optimizer is not None
    # Must be set every epoch: evaluation switches the model to eval mode.
    model.train(training)
    total_loss = 0.0
    correct = 0
    seen = 0
    progress = tqdm(loader)
    with torch.set_grad_enabled(training):
        for inputs, labels in progress:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            batch = labels.size(0)
            # criterion returns the batch mean; weight it by the real batch size so the
            # epoch average is per-sample even when the last batch is smaller.
            total_loss += loss.item() * batch
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch
            progress.set_postfix(acc='%.3f' % (correct / seen), loss='%.4f' % loss.item())
    if seen == 0:
        return 0.0, 0.0
    return total_loss / seen, correct / seen


def train(epochs=10, batch_size=64, num_workers=0, lr=0.05,
          data_root=os.path.join('catVSdog', 'data'), checkpoint_path='resnet50.pt',
          min_accuracy=0.7):
    """Fine-tune the head of a pretrained ResNet-50 as a 2-class image classifier.

    Reads ImageFolder directories ``<data_root>/train`` and ``<data_root>/validation``.
    After each epoch the model's state_dict is saved to *checkpoint_path* whenever its
    validation accuracy beats the best so far (initially *min_accuracy*). Finally plots
    per-epoch accuracy (in percent) and loss for both splits.

    Args:
        epochs: number of passes over the training set.
        batch_size: mini-batch size for both loaders.
        num_workers: DataLoader worker processes.
        lr: Adam learning rate for the trainable head.
        data_root: directory containing the ``train`` and ``validation`` folders.
        checkpoint_path: file the best model is resumed from and saved to.
        min_accuracy: validation accuracy (fraction in [0, 1]) a model must exceed
            before it is saved.
    """
    data_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Both splits share one relative root (the original train path started with a
    # backslash, i.e. was absolute, unlike the validation path).
    train_loader = _make_loader(os.path.join(data_root, 'train'), data_transform,
                                batch_size, num_workers, shuffle=True)
    # Evaluation is order-independent, so the validation set is not shuffled.
    val_loader = _make_loader(os.path.join(data_root, 'validation'), data_transform,
                              batch_size, num_workers, shuffle=False)

    use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')
    model = _build_model(checkpoint_path, device)
    print('gpu is available' if use_gpu else 'gpu is unavailable')
    print(model)

    criterion = nn.CrossEntropyLoss()
    # Only the unfrozen head requires gradients, so only it is handed to the optimizer.
    # NOTE(review): lr=0.05 looks high for Adam — consider tuning it down.
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)

    train_loss_history, train_acc_history = [], []
    val_loss_history, val_acc_history = [], []
    best_accuracy = min_accuracy
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        print("-" * 10, "train result", "-" * 10)
        print('train %d epoch loss: %.3f acc: %.3f ' % (epoch + 1, train_loss, 100 * train_acc))

        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        print("-" * 10, "validation result", "-" * 10)
        print('test %d epoch loss: %.3f acc: %.3f ' % (epoch + 1, val_loss, 100 * val_acc))

        # Plain Python floats (accuracy in percent) so matplotlib can plot them
        # regardless of which device the model ran on.
        train_loss_history.append(train_loss)
        train_acc_history.append(100 * train_acc)
        val_loss_history.append(val_loss)
        val_acc_history.append(100 * val_acc)

        # Compare a fraction with a fraction (0.7 means 70 %).
        if val_acc > best_accuracy:
            best_accuracy = val_acc
            # Save where _build_model looks for it, so later runs actually resume.
            torch.save(model.state_dict(), checkpoint_path)
        else:
            print("This epoch does not find a better model than last one.")

    plot_train_history(epochs_history=np.arange(1, epochs + 1),
                       train_acc_history=train_acc_history, val_acc_history=val_acc_history,
                       train_loss_history=train_loss_history, val_loss_history=val_loss_history)
if __name__ == '__main__':
    # Start training only when run as a script, not when imported as a module.
    train()