pytorch实践(九) 早停法防止过拟合
早停法:防止过拟合,当验证集 loss 多轮不再下降时,提前停止训练。
early_stopping.py
class EarlyStopping:
    """Stop training once the validation loss has gone `patience`
    consecutive epochs without a meaningful improvement.

    Call the instance with the current validation loss after every
    epoch; check the `early_stop` flag to decide whether to break.
    """

    def __init__(self, patience=3, min_delta=0):
        # patience:  how many non-improving epochs to tolerate before stopping
        # min_delta: smallest loss decrease that still counts as an improvement
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0          # consecutive epochs without improvement
        self.best_loss = None     # lowest validation loss seen so far
        self.early_stop = False   # set to True once patience is exhausted

    def __call__(self, val_loss):
        # An epoch "improves" on the first call, or when the loss drops
        # below the previous best by more than min_delta.
        improved = self.best_loss is None or val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
train_with_early_stopping.py
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from neural_network_model import NeuralNetwork
from torchvision.transforms import ToTensor
from early_stopping import EarlyStopping
from tqdm import tqdm

# Download the FashionMNIST training and test sets.
training_data = datasets.FashionMNIST(
    root="data", train=True, download=True, transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root="data", train=False, download=True, transform=ToTensor()
)

# Wrap the datasets in DataLoaders.
# batch_size=64: each iteration draws 64 samples.
# shuffle=True: reshuffle the training data every epoch so the model
# cannot memorize sample order.
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
# Fix: evaluation data must not be shuffled (was shuffle=True) — the order
# has no effect on the metrics, and shuffle=False keeps evaluation deterministic.
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)

# Pick the training device, then build the model, loss function and optimizer.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = NeuralNetwork().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)


def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one training epoch and return (average loss, accuracy)."""
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.train()
    total_loss = 0
    correct = 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    avg_loss = total_loss / num_batches
    accuracy = correct / size
    return avg_loss, accuracy


def test_loop(dataloader, model, loss_fn):
    """Evaluate the model on `dataloader` and return (average loss, accuracy)."""
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for X, y in tqdm(dataloader):
            X, y = X.to(device), y.to(device)
            pred = model(X)
            loss = loss_fn(pred, y)
            total_loss += loss.item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    avg_loss = total_loss / num_batches
    accuracy = correct / size
    return avg_loss, accuracy


# Per-epoch history, collected for the plots below.
train_losses, train_accuracies = [], []
test_losses, test_accuracies = [], []

# Training loop with early stopping: stop once the test loss has not
# improved for `patience` consecutive epochs.
early_stopping = EarlyStopping(patience=3)
epochs = 30
best_accuracy = 0
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    train_loss, train_acc = train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loss, test_acc = test_loop(test_dataloader, model, loss_fn)
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    test_losses.append(test_loss)
    test_accuracies.append(test_acc)
    # Checkpoint the weights whenever test accuracy reaches a new best.
    if test_acc > best_accuracy:
        best_accuracy = test_acc
        torch.save(model.state_dict(), "best_model.pth")
        print("✅ 模型已保存(当前最佳准确率)")
    early_stopping(test_loss)
    if early_stopping.early_stop:
        print("🛑 验证集 loss 多轮未提升,训练提前停止!")
        break
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
    print("-" * 50)

# Configure matplotlib so Chinese characters in labels render correctly.
plt.rcParams['font.family'] = 'SimHei'
plt.rcParams['axes.unicode_minus'] = False

# Plot the loss / accuracy curves over the epochs actually run.
epochs_range = range(1, len(train_losses) + 1)
plt.figure(figsize=(10, 4))

# Loss curve.
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_losses, label="Train Loss 训练损失")
plt.plot(epochs_range, test_losses, label="Test Loss 测试损失")
plt.xlabel("Epoch 训练轮数")
plt.ylabel("Loss 损失")
plt.title("Loss Curve 损失曲线")
plt.legend()
plt.grid(True)

# Accuracy curve.
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_accuracies, label="Train Acc 训练准确率")
plt.plot(epochs_range, test_accuracies, label="Test Acc 测试准确率")
plt.xlabel("Epoch 训练轮数")
plt.ylabel("Accuracy 准确率")
plt.title("Accuracy Curve 准确率曲线")
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()
输出效果:


浙公网安备 33010602011771号