mthoutai

  博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

欢迎来到机器学习的世界 
博客主页:卿云阁

欢迎关注点赞收藏⭐️留言

本文由卿云阁原创!

首发时间:2025年10月6日

✉️希望可以和大家一起完成进阶之路!

作者水平很有限,如果发现错误,请留言轰炸哦!万分感谢!


目录

实战内容

定义网络结构

代码实现

训练

数据读取和预处理

 定义模型、损失函数、优化器

指定设备开始训练

 可视化

完整代码

评估

UI页面

实战内容

                 本文主要介绍使用神经网络实现对水质的检测,最后我们设计了一个UI页面。

说明1:这个项目训练集和测试集使用是的相同的图片,因为我的数据太少了,所以这里我这样做

了。实际是不允许的

说明2:这个项目和之前我写的一个项目是相同的,只是下面的代码用的pytorch是2.8的所以,有

些代码做了些改动。

说明3:我还对最后的UI页面做了一些美化,具体的内容上也做了更加详细的介绍。下面一起来看

一下这个项目吧。

提取数据集

       链接:https://pan.baidu.com/s/1DSDl5uKF0qaoyVs3f-L7iQ?pwd=wy46 
       提取码:wy46

数据集介绍

思路:

定义网络结构(network.py):定义神经网络结构

训练(train.py):数据集的读取,训练

评估(evaluation.py):对模型评估

可视化界面(UI.py):可视化的UI页面


定义网络结构

代码实现
import torch
# 定义模型
class WaterQualityNet(torch.nn.Module):
    def __init__(self, input_size=32*32*3, hidden_size=128, num_classes=2):
        super(WaterQualityNet, self).__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size)
        self.fc3 = torch.nn.Linear(hidden_size, num_classes)
        self.relu = torch.nn.ReLU()
    def forward(self, x):
        x = x.view(x.size(0), -1)  # 展平成 (batch_size, 3072)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)  # 输出
        return x
if __name__ == '__main__':
    model = WaterQualityNet()
    test_input = torch.randn(2, 3, 32, 32)
    output = model(test_input)
    print("--- network.py 测试结果 ---")
    print(f"模型输出尺寸 (Batch, Classes): {output.shape}")

语言描述

步骤操作 / 组件输入数据形状输出数据形状核心作用
1输入数据 (x)(B,3,32,32)(B,3,32,32)原始图像数据(B 为 batch_size)。
2展平 (Flatten)(B,3,32,32)(B,3072)将图像从空间结构展平,适配 MLP 的全连接层。
3第一层全连接 (fc1)(B,3072)(B,128)将 3072 个输入特征线性映射到 128 个隐藏特征。
4ReLU 激活 (relu)(B,128)(B,128)引入非线性,激活神经元。
5第二层全连接 (fc2)(B,128)(B,128)在隐藏空间中进一步精炼特征。
6ReLU 激活 (relu)(B,128)(B,128)第二次引入非线性。
7第三层全连接 (fc3)(B,128)(B,2)将隐藏特征映射到最终的类别空间。
8输出(B,2)(B,2)最终输出 2 个类别的 Logits(原始分数)。

训练

数据读取和预处理
class WaterQualityDataset(Dataset):
    def __init__(self, root, transform=None):
        self.dataset = ImageFolder(root, transform=transform)
    def __getitem__(self, index):
        return self.dataset[index]
    def __len__(self):
        return len(self.dataset)

 ImageFolder 的作用: 这是一个极其方便的工具,专门设计用来处理按文件夹结构组织的图像

数据集。它会自动完成以下工作:遍历 root 目录下所有的子文件夹。将每个子文件夹的名字识别

为一个类别标签。将子文件夹内的所有图片文件路径收集起来。在内部创建一个从类别名到数字标

签(如 0,1,2...)的映射。因此,self.dataset 实际上已经是一个可以工作的 PyTorch Dataset 对

象,包含了所有的图片路径和对应的标签。

   imageFolder 的 __getitem__ 会自动执行以下步骤:找到 index 对应的图片文件路径。读取该

图片文件。应用在 __init__ 中传入的 transform 转换。返回一个元组:(图像的 Tensor, 对应的数字

标签)。

  len:返回数据集大小

# 数据预处理和加载
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

Resize((32, 32)):将所有图片缩放到 32×32 统一尺寸,便于模型处理。

ToTensor():将图片转换为 PyTorch 张量,并将像素值归一化到 [0,1] 区间。

train_dataset = WaterQualityDataset('D:/dataset', transform=transform)
test_dataset = WaterQualityDataset('D:/dataset', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

train_dataset 和 test_dataset 从 D:/dataset加载数据,使用 DataLoader 进行批量加载。

batch_size=1:每次加载 1 张图片(可以适当增大 batch_size 提高训练效率)。

shuffle=True:打乱训练集顺序,提升泛化能力。

num_workers=0:数据加载的 线程数,如果在 GPU 训练时,可以设大一点加速加载。

 定义模型、损失函数、优化器
# 初始化WaterQualityNet模型
model = WaterQualityNet()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

指定设备开始训练
num_epochs = 10
loss_history = []  # 记录每个 epoch 的损失
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()   # 清空梯度
        outputs = model(images) # 前向传播
        loss = criterion(outputs, labels)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数
        running_loss += loss.item()
    epoch_loss = running_loss / len(train_loader)  # 计算平均损失
    loss_history.append(epoch_loss)
    print(f'Training Loss: {epoch_loss}')

model.train():进入训练模式(启用 Dropout、BatchNorm 等)。

遍历 train_loader:

数据转移到 GPU (images.to(device), labels.to(device))

前向传播 (outputs = model(images))

计算损失 (loss = criterion(outputs, labels))

反向传播 (loss.backward())

更新参数 (optimizer.step())

记录每个 epoch 的平均损失,存入 loss_history。

 可视化
plt.figure(figsize=(8, 5))
plt.plot(range(1, num_epochs + 1), loss_history, marker='o', linestyle='-', color='b')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Over Epochs')
plt.show()

plt.plot() 画出 损失值随 epoch 变化的趋势

横轴是 训练轮数(epoch)

纵轴是 损失值(Loss)

如果损失曲线不断下降,说明模型在收敛

如果损失曲线震荡或者上升,说明可能存在学习率过大、数据不稳定等问题

完整代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from tqdm import tqdm
from network import WaterQualityNet
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# 定义水质数据集
class WaterQualityDataset(Dataset):
    def __init__(self, root, transform=None):
        self.dataset = ImageFolder(root, transform=transform)
    def __getitem__(self, index):
        return self.dataset[index]
    def __len__(self):
        return len(self.dataset)
# 数据预处理和加载
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])
train_dataset = WaterQualityDataset('D:/dataset', transform=transform)
test_dataset = WaterQualityDataset('D:/dataset', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
# 初始化WaterQualityNet模型
model = WaterQualityNet()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
num_epochs = 10
loss_history = []  # 存储损失值
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    epoch_loss = running_loss / len(train_loader)
    loss_history.append(epoch_loss)
    print(f'Training Loss: {epoch_loss}')
torch.save(model, 'water_quality_full_model.pth')
print('模型已经保存为 water_quality_full_model.pth')
# 可视化损失曲线
plt.figure(figsize=(8, 5))
plt.plot(range(1, num_epochs + 1), loss_history, marker='o', linestyle='-', color='b')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Over Epochs')
plt.show()


评估

import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
# 数据加载
test_loader = DataLoader(
    datasets.ImageFolder('D:/dataset1', transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])),
    batch_size=32, shuffle=False, num_workers=0
)
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载模型
model = torch.load('water_quality_full_model.pth')
model.to(device)
model.eval()
# 计算 Top-1 Accuracy
correct, total = 0, 0
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Evaluating'):
        outputs = model(images.to(device))
        correct += (outputs.argmax(1) == labels.to(device)).sum().item()
        total += labels.size(0)
# 输出准确率
print(f"✅ Top-1 Accuracy: {correct / total * 100:.2f}%")

 correct 用于累加预测正确的样本数;total用于累加总样本数,这里我也想解释一下为啥不用

softmax,对原始分数(Logits)取最大值,得到的预测类别对 Softmax 概率取最大值得到的预测

类别是完全一样的。既然结果一样,我们当然选择计算量最小的方式:跳过 Softmax 这一步。


UI页面

import sys
import torch
import torchvision.transforms as transforms
from PIL import Image
from PyQt5.QtWidgets import QApplication, QWidget, QLabel, QPushButton, QFileDialog, QVBoxLayout, QMessageBox
from PyQt5.QtGui import QPixmap, QFont
from PyQt5.QtCore import Qt
# --- 1. 模型导入与加载 ---
# 导入自定义模型类 WaterQualityNet
try:
    # 假设模型类在 network.py 中
    # 注意:在实际运行环境中,您需要确保 network.py 存在
    # from network import WaterQualityNet
    # 占位符类,用于通过检查,如果实际项目中没有 network.py
    class WaterQualityNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # 这是一个虚拟的初始化,确保 torch.load 不会失败(如果模型路径正确)
            print("注意:WaterQualityNet 类已定义占位符。如果模型加载失败,请提供实际的模型定义。")
        def forward(self, x):
            # 虚拟前向传播
            return torch.rand(x.size(0), 2)
except ImportError:
    # 错误处理:如果无法找到模型定义文件
    print("错误:未找到 WaterQualityNet 类,请检查 network.py 导入路径是否正确!")
    sys.exit(1)
# PyTorch 设备设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 模型文件名
model_path = "water_quality_full_model.pth"
model = None
try:
    # 核心修复:加载完整模型对象(含自定义类)。
    # 必须设置 weights_only=False,以兼容自定义类加载。
    # 警告: 实际运行时,torch.load 要求 WaterQualityNet 类定义必须完整且可用
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.to(device)
    model.eval()  # 切换到评估模式
    print(f"模型成功加载到 {device} 设备")
except Exception as e:
    # 在 GUI 环境下使用 QMessageBox 更好
    # QMessageBox.critical(None, "模型加载错误", f"无法加载模型文件:{model_path}\n错误详情:{str(e)}")
    print(f"模型加载失败,使用 WaterQualityNet 占位符可能导致加载失败。请确保模型文件和类定义匹配。错误详情:{str(e)}")
    # 如果模型加载失败,为了演示 UI,我们使用一个虚拟模型
    model = WaterQualityNet().to(device)
    model.eval()
# 图像预处理(与训练时保持一致)
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])
# 预测结果映射
QUALITY_MAPPING = {
    0: " 干净水 (一类)",
    1: "⚠️ 轻微污染水 (二类)"
    # 请根据您的实际训练标签进行调整
}
# --- 2. PyQt5 UI 设计与样式 ---
class WaterQualityApp(QWidget):
    def __init__(self):
        super().__init__()
        self.image_path = None
        self.init_ui()
        self.apply_style()
    def init_ui(self):
        # 布局
        layout = QVBoxLayout()
        layout.setSpacing(15)
        # 标题
        self.title_label = QLabel("智能水质图像检测系统")
        self.title_label.setObjectName("TitleLabel")
        layout.addWidget(self.title_label, alignment=Qt.AlignCenter)
        # 图片展示区
        self.image_label = QLabel("等待上传图片...")
        self.image_label.setObjectName("ImageLabel")
        layout.addWidget(self.image_label, stretch=1)
        # 按钮:上传图片
        self.upload_button = QPushButton(" 上传水质图片")
        self.upload_button.setObjectName("UploadButton")
        self.upload_button.clicked.connect(self.load_image)
        layout.addWidget(self.upload_button)
        # 按钮:检测水质
        self.detect_button = QPushButton(" 检测水质类别")
        self.detect_button.setObjectName("DetectButton")
        self.detect_button.clicked.connect(self.detect_water_quality)
        layout.addWidget(self.detect_button)
        # 结果标签
        self.result_label = QLabel("结果将在这里显示...")
        self.result_label.setObjectName("ResultLabel")
        layout.addWidget(self.result_label, alignment=Qt.AlignCenter)
        self.setLayout(layout)
        self.setWindowTitle("智能水质检测系统")
        self.setGeometry(100, 100, 700, 700)
    def apply_style(self):
        # 统一的 CSS 样式表,更易于维护和美化
        self.setStyleSheet("""
            QWidget {
                background-color: #f8f9fa; /* 浅灰色背景 */
                font-family: 'Inter', sans-serif;
            }
            #TitleLabel {
                font-size: 28px;
                font-weight: bold;
                color: #007bff; /* 蓝色主色调 */
                padding: 15px;
            }
            #ImageLabel {
                border: 3px dashed #ced4da; /* 虚线边框 */
                border-radius: 10px;
                background-color: #ffffff; /* 白色内容背景 */
                min-height: 300px;
                font-size: 16px;
                color: #6c757d;
            }
            QPushButton {
                font-size: 16px;
                padding: 12px;
                border-radius: 8px;
                border: none;
                font-weight: 500;
                color: white;
            }
            #UploadButton {
                background-color: #28a745; /* 绿色 */
            }
            #UploadButton:hover {
                background-color: #218838;
            }
            #DetectButton {
                background-color: #007bff; /* 蓝色 */
            }
            #DetectButton:hover {
                background-color: #0069d9;
            }
            #ResultLabel {
                font-size: 22px;
                padding: 10px;
                border-radius: 8px;
                font-weight: bold;
                background-color: #e9ecef; /* 结果区域背景 */
                color: #495057;
                margin-top: 15px;
            }
        """)
    def load_image(self):
        """打开文件对话框,选择水质图像并显示"""
        file_dialog = QFileDialog()
        # 限制只选择图片文件
        image_path, _ = file_dialog.getOpenFileName(
            self,
            "选择水质图片",
            "",
            "图片文件 (*.png *.jpg *.jpeg *.bmp)"
        )
        if image_path:
            # 尝试加载图片
            try:
                pixmap = QPixmap(image_path)
                # 根据 image_label 的大小缩放图片,保持纵横比
                scaled_pixmap = pixmap.scaled(
                    self.image_label.size(),
                    Qt.KeepAspectRatio,
                    Qt.SmoothTransformation
                )
                self.image_label.setPixmap(scaled_pixmap)
                self.image_label.setAlignment(Qt.AlignCenter)
                self.image_label.setText("") # 清除 '等待上传图片...' 文本
                self.image_path = image_path
                # 重置结果标签
                self.result_label.setStyleSheet("font-size: 22px; padding: 10px; border-radius: 8px; font-weight: bold; background-color: #e9ecef; color: #495057; margin-top: 15px;")
                self.result_label.setText("图片已上传,请点击 '检测水质类别'...")
            except Exception as e:
                 QMessageBox.warning(self, "文件错误", f"无法加载图片文件: {str(e)}")
    def detect_water_quality(self):
        """执行模型推理,预测水质类别"""
        if self.image_path and model is not None:
            try:
                # 1. 预处理
                image = Image.open(self.image_path).convert("RGB")
                image_tensor = transform(image).unsqueeze(0).to(device)
                # 2. 模型推理
                with torch.no_grad():
                    output = model(image_tensor)
                    probabilities = torch.softmax(output, dim=1)
                    _, predicted_class_idx = torch.max(output, 1)
                # --- 详细输出模型转换过程 (新增内容) ---
                print("\n" + "="*50)
                print("--- 模型输出到最终预测结果的转换 ---")
                # 1. 原始 Logits (模型输出)
                # output.squeeze().tolist() 将 [1, N] 维度的 Tensor 转换为 Python 列表
                print(f"1. 原始 Logits (模型输出,未归一化分数): {output.squeeze().tolist()}")
                # 2. Softmax 概率
                print(f"2. Softmax 概率 (置信度分布): {probabilities.squeeze().tolist()}")
                # 3. 预测索引
                predicted_index = predicted_class_idx.item()
                print(f"3. 预测索引 (最大概率对应的类别编号): {predicted_index}")
                # ---------------------------------------------
                # 3. 结果映射
                water_quality = QUALITY_MAPPING.get(predicted_index, "未知类别")
                confidence = probabilities[0, predicted_index].item()
                print(f"4. 映射结果: {water_quality} (置信度: {confidence*100:.2f}%)")
                print("="*50)
                # 4. 显示结果(美化样式)
                if predicted_index == 0: # 假设 0 是干净水(合格)
                    color = "#28a745" # 绿色
                    bg_color = "#d4edda" # 浅绿色背景
                else:
                    color = "#dc3545" # 红色
                    bg_color = "#f8d7da" # 浅红色背景
                self.result_label.setStyleSheet(f"color: {color}; font-size: 24px; margin-top: 20px; font-weight: bold; background-color: {bg_color}; border: 1px solid {color}; padding: 15px; border-radius: 8px;")
                self.result_label.setText(f"{water_quality} | 置信度: {confidence*100:.2f}%")
            except Exception as e:
                # 错误处理
                self.result_label.setStyleSheet("color: #6c757d; font-size: 18px; margin-top: 20px; background-color: #fff3cd;")
                self.result_label.setText(f"❌ 检测失败,请检查模型或输入:{str(e)[:50]}...")
                print(f"预测错误详情:{e}")
        else:
              QMessageBox.warning(self, "操作提示", "请先上传一张图片!")
# 运行 PyQt5 应用
if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = WaterQualityApp()
    window.show()
    sys.exit(app.exec_())

posted on 2025-11-05 12:02  mthoutai  阅读(20)  评论(0)    收藏  举报