pytorch实践(九) 早停法防止过拟合

早停法:防止过拟合,当验证集 loss 多轮不再下降时,提前停止训练。

 

early_stopping.py

class EarlyStopping:
    def __init__(self, patience=3, min_delta=0):
        """
        参数:
        - patience: 忍耐轮数(如果验证集 loss 多轮不下降,就停止训练)
        - min_delta: 最小变化幅度(小于这个值就认为没有改进)
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss # 第一次调用时,初始化best_loss
        elif val_loss < self.best_loss - self.min_delta:  # 如果当前验证集loss比之前的best_loss降低超过了min_delta(即有明显改进)
            self.best_loss = val_loss
            self.counter = 0  # reset counter
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True  # # 如果连续多轮都没有提升,达到忍耐阈值,就设置early_stop为True

 

train_with_early_stopping.py

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from neural_network_model import NeuralNetwork
from torchvision.transforms import ToTensor
from early_stopping import EarlyStopping
from tqdm import tqdm

# 下载 FashionMNIST 训练集和测试集
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

# 用 DataLoader 封装
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
# batch_size=64:每次迭代从训练集中取出 64 个样本。
# shuffle=True:每轮训练(epoch)前会打乱数据顺序,提高训练效果,防止模型记住顺序。
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)


# 设置训练设备、实例化模型和损失函数
device = "cuda" if torch.cuda.is_available() else "cpu"
model = NeuralNetwork().to(device) # 创建该网络的一个实例对象并存储到设备

loss_fn = nn.CrossEntropyLoss()  #设置损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)


# 训练函数
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.train()

    total_loss = 0
    correct = 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    avg_loss = total_loss / num_batches
    accuracy = correct / size
    return avg_loss, accuracy

# 测试函数
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()

    total_loss = 0
    correct = 0

    with torch.no_grad():
        for X, y in tqdm(dataloader):
            X, y = X.to(device), y.to(device)

            pred = model(X)
            loss = loss_fn(pred, y)

            total_loss += loss.item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    avg_loss = total_loss / num_batches
    accuracy = correct / size
    return avg_loss, accuracy



# 准备画图数据
train_losses, train_accuracies = [], []
test_losses, test_accuracies = [], []


# 训练过程
early_stopping = EarlyStopping(patience=3) # 设置早停
epochs = 30
best_accuracy = 0
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")

    train_loss, train_acc = train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loss, test_acc = test_loop(test_dataloader, model, loss_fn)

    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    test_losses.append(test_loss)
    test_accuracies.append(test_acc)

     # 保存当前最佳模型
    if test_acc > best_accuracy:
        best_accuracy = test_acc
        torch.save(model.state_dict(), "best_model.pth")
        print("✅ 模型已保存(当前最佳准确率)")

    early_stopping(test_loss)
    if early_stopping.early_stop:
        print("🛑 验证集 loss 多轮未提升,训练提前停止!")
        break

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
    print("-" * 50)


# 设置支持中文显示
plt.rcParams['font.family'] = 'SimHei'
plt.rcParams['axes.unicode_minus'] = False

# 绘制 loss / accuracy 曲线
epochs_range = range(1, len(train_losses) + 1)

plt.figure(figsize=(10, 4))

# Loss 曲线
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_losses, label="Train Loss 训练损失")
plt.plot(epochs_range, test_losses, label="Test Loss 测试损失")
plt.xlabel("Epoch 训练轮数")
plt.ylabel("Loss 损失")
plt.title("Loss Curve 损失曲线")
plt.legend()
plt.grid(True)

# Accuracy 曲线
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_accuracies, label="Train Acc 训练准确率")
plt.plot(epochs_range, test_accuracies, label="Test Acc 测试准确率")
plt.xlabel("Epoch 训练轮数")
plt.ylabel("Accuracy 准确率")
plt.title("Accuracy Curve 准确率曲线")
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

 

输出效果:

image

 

posted @ 2025-07-25 16:23  daviyoung  阅读(167)  评论(0)    收藏  举报