Python KAN 网络花朵分类训练

参考:Blealtan/efficient-kan —— An efficient pure-PyTorch implementation of Kolmogorov-Arnold Network (KAN)(https://github.com/Blealtan/efficient-kan)。

哔哩哔哩(B 站)上有人用这份代码做了花朵分类训练:https://www.bilibili.com/video/BV1PS421o77Z/?spm_id_from=333.1387.favlist.content.click

讲得挺好,于是购买了他的训练代码。

我在此基础上做了一些调整,并尝试用训练好的模型对实际图片进行预测:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
from src.efficient_kan import KAN  # 确保efficient-kan库已安装
import os
import shutil
from PIL import Image

# Select the compute device: a CUDA GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# -------------------- 1. Data preparation --------------------
# Training preprocessing: resize to 64x64, light augmentation (random horizontal
# flip, random rotation within +/-15 degrees), then tensor conversion and
# normalization with the ImageNet per-channel mean/std.
transform_train = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomHorizontalFlip(),  # random horizontal flip
    transforms.RandomRotation(15),  # random rotation in [-15, 15] degrees
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
])

# Validation preprocessing: same resize/normalization but no random
# augmentation, so validation images are processed deterministically.
transform_val = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Dataset root. Expected layout: <DATA_ROOT>/train/<class_name>/<images> and
# <DATA_ROOT>/val/<class_name>/<images>; ImageFolder derives each label from
# the name of the sub-folder an image sits in.
DATA_ROOT = 'G:/kuozhi/efficient-kan/examples'
trainset = datasets.ImageFolder(root=f'{DATA_ROOT}/train', transform=transform_train)
valset = datasets.ImageFolder(root=f'{DATA_ROOT}/val', transform=transform_val)

# ImageFolder numbers classes by their sorted folder names, independently per
# split. If train and val do not contain exactly the same class folders, the
# same integer label would mean different flowers in each split and the
# validation metrics would be silently wrong, so fail early instead.
if valset.class_to_idx != trainset.class_to_idx:
    raise ValueError(
        "train/val class folders differ: "
        f"train={trainset.classes}, val={valset.classes}"
    )

# Number of classes and their names (index i <-> class_names[i]).
num_classes = len(trainset.classes)
class_names = trainset.classes
print(f"Number of classes: {num_classes}")
print(f"Class names: {class_names}")

# Data loaders: shuffle only the training split.
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
valloader = DataLoader(valset, batch_size=32, shuffle=False)

# -------------------- 2. Model definition --------------------
# Each 64x64 RGB image is flattened into 3 * 64 * 64 = 12288 features before it
# reaches the network. Layer widths: 12288 inputs -> 64 hidden -> num_classes outputs.
# nn.Module.to() moves the parameters in place and returns the module itself.
model = KAN([12288, 64, num_classes]).to(device)

# -------------------- 3. Training configuration --------------------
# AdamW (lr 1e-3, weight decay 1e-4). ExponentialLR multiplies the learning
# rate by gamma=0.9 on every scheduler.step() call.
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
# Multi-class classification loss on raw model outputs.
criterion = nn.CrossEntropyLoss()

# -------------------- 4. Training loop --------------------
num_epochs = 10
train_loss_history, val_loss_history = [], []
train_acc_history, val_acc_history = [], []


def _run_epoch(loader, train, desc=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With ``train=True`` the model is put in training mode and the optimizer
    steps after every batch. With ``train=False`` the model is put in eval mode
    and gradients are disabled. ``mean_loss`` is averaged per sample, not per
    batch, so a smaller final batch does not skew it. A tqdm progress bar is
    shown when ``desc`` is given.

    Raises ValueError if ``loader`` yields no samples.
    """
    model.train(train)
    total_loss, correct, total = 0.0, 0, 0
    batches = tqdm(loader, desc=desc) if desc is not None else loader
    with torch.set_grad_enabled(train):
        for images, labels in batches:
            # Flatten each image to one feature row. The width is taken from the
            # tensor itself, so a wrong input size fails in the model instead of
            # silently regrouping the batch as view(-1, 12288) could.
            images = images.flatten(start_dim=1).to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # CrossEntropyLoss returns the batch mean; re-weight by batch size
            # so the epoch figure is a true per-sample mean.
            total_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("data loader produced no samples")
    return total_loss / total, correct / total


for epoch in range(num_epochs):
    # Training phase (with progress bar), then validation phase.
    train_loss, train_acc = _run_epoch(
        trainloader, train=True, desc=f"Epoch {epoch + 1}/{num_epochs}"
    )
    val_loss, val_acc = _run_epoch(valloader, train=False)

    train_loss_history.append(train_loss)
    train_acc_history.append(train_acc)
    val_loss_history.append(val_loss)
    val_acc_history.append(val_acc)

    # Decay the learning rate once per epoch.
    scheduler.step()

    # Per-epoch log line.
    print(f"Epoch {epoch + 1}/{num_epochs}: "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

# -------------------- 5. Visualization --------------------
# Plot per-epoch loss (left) and accuracy (right) for the train and val splits.
# The x axis is numbered from 1 so it lines up with the "Epoch k/N" training log
# (plotting the bare lists would label the first epoch as 0).
epoch_axis = range(1, len(train_loss_history) + 1)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epoch_axis, train_loss_history, label='Train Loss')
plt.plot(epoch_axis, val_loss_history, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epoch_axis, train_acc_history, label='Train Acc')
plt.plot(epoch_axis, val_acc_history, label='Val Acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()


# -------------------- Predict and sort images into class folders --------------------
def classify_and_save_images(model, input_dir, output_dir, transform, class_names, device,
                             extensions=('.png', '.jpg', '.jpeg')):
    """
    Predict a class for each image in a folder and copy it into a per-class sub-folder.

    One sub-folder per entry of ``class_names`` is created under ``output_dir``.
    Every regular file directly inside ``input_dir`` (not recursive) whose name
    ends with one of ``extensions`` (case-insensitive) is opened with PIL,
    converted to RGB, preprocessed with ``transform``, flattened to one row and
    classified by ``model``. The original file is then copied (not moved) into
    ``output_dir/<predicted class>/``. Files are processed in sorted-name order.
    Files that PIL cannot open are reported and skipped, so they do not abort
    the whole run.

    :param model: trained model; it is called with a (1, features) tensor, and the
        arg-max over dim 1 of its output is used as an index into ``class_names``
    :param input_dir: folder of mixed, unsorted images (e.g. './hunhe')
    :param output_dir: destination root for the sorted copies (e.g. './classified_flowers')
    :param transform: preprocessing applied to each PIL image; it should match the
        deterministic (validation-time) preprocessing used during training
    :param class_names: class names in label-index order (e.g. ``trainset.classes``)
    :param device: device the input tensor is moved to (should be the model's device)
    :param extensions: accepted file-name suffix or tuple of suffixes, compared case-insensitively
    :return: dict mapping each classified file name to its predicted class name
    """
    # Create the output root and one sub-folder per class.
    os.makedirs(output_dir, exist_ok=True)
    for class_name in class_names:
        os.makedirs(os.path.join(output_dir, class_name), exist_ok=True)

    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    predictions = {}
    model.eval()
    # Inference only: disable autograd once for the whole folder.
    with torch.no_grad():
        for filename in sorted(os.listdir(input_dir)):
            img_path = os.path.join(input_dir, filename)
            # Skip non-image names and anything that is not a regular file
            # (for example a sub-folder called "x.jpg").
            if not filename.lower().endswith(suffixes) or not os.path.isfile(img_path):
                continue

            # Load and preprocess. Image.open keeps the file handle open until
            # the image is closed, so use it as a context manager. convert()
            # returns a new, fully loaded image that stays valid after the close.
            try:
                with Image.open(img_path) as raw:
                    img = raw.convert('RGB')
            except OSError as err:  # includes PIL.UnidentifiedImageError
                print(f"Skipping '{filename}': cannot read image ({err})")
                continue

            # Add a batch dimension, then flatten to (1, features) just as the
            # training loop flattens its batches.
            img_tensor = transform(img).unsqueeze(0).reshape(1, -1).to(device)

            # Predict: index of the highest score along the class dimension.
            output = model(img_tensor)
            predicted_class = class_names[output.argmax(dim=1).item()]

            # Copy the original file into the folder of its predicted class.
            dst_dir = os.path.join(output_dir, predicted_class)
            shutil.copy(img_path, dst_dir)
            predictions[filename] = predicted_class
            print(f"Image '{filename}' classified as: {predicted_class}")

    return predictions


# -------------------- Usage example --------------------
# Prediction must use exactly the deterministic preprocessing applied to the
# validation set (resize + normalize, no random augmentation). Reusing
# transform_val instead of re-declaring the same pipeline guarantees the two
# can never drift apart.
transform_predict = transform_val

# Sort the images in the mixed folder into per-class folders.
classify_and_save_images(
    model=model,
    input_dir='G:/kuozhi/efficient-kan/examples/hunhe',  # folder of mixed, unsorted images
    output_dir='G:/kuozhi/efficient-kan/examples/classified_flowers',  # one sub-folder per class is created here
    transform=transform_predict,
    class_names=class_names,  # from trainset.classes, so indices match the training labels
    device=device,
)

 

posted @ 2025-04-23 17:28  秋刀鱼CCC  Views(61)  Comments(0)    收藏  举报