Python: training a KAN network for flower classification
Reference: Blealtan/efficient-kan: An efficient pure-PyTorch implementation of Kolmogorov-Arnold Network (KAN).
Someone on Bilibili used this code to train a flower classifier: https://www.bilibili.com/video/BV1PS421o77Z/?spm_id_from=333.1387.favlist.content.click
The explanation was quite good, so I purchased his training code.
After some adjustments, I wanted to try it out on real predictions:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
from src.efficient_kan import KAN  # make sure the efficient-kan source is available
import os
import shutil
from PIL import Image

# Select device (GPU/CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------- 1. Data preparation --------------------
# Preprocessing (with data augmentation for training)
transform_train = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomHorizontalFlip(),   # random horizontal flip
    transforms.RandomRotation(15),       # random rotation
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])  # ImageNet normalization
])
transform_val = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Load the datasets (ImageFolder: one subfolder per class)
trainset = datasets.ImageFolder(root='G:/kuozhi/efficient-kan/examples/train', transform=transform_train)
valset = datasets.ImageFolder(root='G:/kuozhi/efficient-kan/examples/val', transform=transform_val)

# Number of classes and class names
num_classes = len(trainset.classes)
class_names = trainset.classes
print(f"Number of classes: {num_classes}")
print(f"Class names: {class_names}")

# Data loaders
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
valloader = DataLoader(valset, batch_size=32, shuffle=False)

# -------------------- 2. Model definition --------------------
# Input dimension: a 64x64 RGB image flattened to 64*64*3 = 12288
model = KAN([12288, 64, num_classes])  # input layer -> hidden layer -> output layer (num_classes)
model.to(device)

# -------------------- 3. Training configuration --------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)  # learning-rate decay

# -------------------- 4. Training loop --------------------
num_epochs = 10
train_loss_history, val_loss_history = [], []
train_acc_history, val_acc_history = [], []

for epoch in range(num_epochs):
    # Training phase
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in tqdm(trainloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        images = images.view(-1, 12288).to(device)  # flatten the input
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    train_loss = running_loss / len(trainloader)
    train_acc = correct / total
    train_loss_history.append(train_loss)
    train_acc_history.append(train_acc)

    # Validation phase
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 12288).to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_loss = val_loss / len(valloader)
    val_acc = correct / total
    val_loss_history.append(val_loss)
    val_acc_history.append(val_acc)

    # Update the learning rate
    scheduler.step()

    # Log progress
    print(f"Epoch {epoch + 1}/{num_epochs}: "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
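The script above keeps the trained model in memory for the prediction step further down, so it never writes a checkpoint. If you want to reuse the trained KAN later without retraining, a minimal sketch of saving the weights would be the following (the file name kan_flowers.pth and the idea of bundling class_names into the checkpoint are my own additions, not part of the purchased code):

# Sketch: persist the trained weights together with the class names so prediction
# can run in a later session. 'kan_flowers.pth' is an illustrative file name.
checkpoint = {
    "state_dict": model.state_dict(),
    "class_names": class_names,
}
torch.save(checkpoint, "kan_flowers.pth")

Saving the class names alongside the weights avoids having to reload the training ImageFolder just to recover the index-to-name mapping at prediction time.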
# -------------------- 5. Result visualization --------------------
# Plot the loss and accuracy curves
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_loss_history, label='Train Loss')
plt.plot(val_loss_history, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(train_acc_history, label='Train Acc')
plt.plot(val_acc_history, label='Val Acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

# -------------------- Predict and sort images into class folders --------------------
def classify_and_save_images(model, input_dir, output_dir, transform, class_names, device):
    """
    Predict the class of every image in input_dir and copy it into the matching class folder under output_dir.
    :param model: trained model
    :param input_dir: directory of mixed images (e.g. './hunhe')
    :param output_dir: output directory for classified images (e.g. './classified_flowers')
    :param transform: preprocessing pipeline
    :param class_names: list of class names
    :param device: device (cuda/cpu)
    """
    # Create the output directories (one subfolder per class)
    os.makedirs(output_dir, exist_ok=True)
    for class_name in class_names:
        os.makedirs(os.path.join(output_dir, class_name), exist_ok=True)

    # Iterate over the images in the input directory
    model.eval()
    for filename in os.listdir(input_dir):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            img_path = os.path.join(input_dir, filename)

            # Load and preprocess the image
            img = Image.open(img_path).convert('RGB')
            img_tensor = transform(img).unsqueeze(0)             # add batch dimension
            img_tensor = img_tensor.view(-1, 12288).to(device)   # flatten

            # Predict
            with torch.no_grad():
                output = model(img_tensor)
                _, pred = torch.max(output, 1)
                predicted_class = class_names[pred.item()]

            # Copy into the predicted class folder
            dst_dir = os.path.join(output_dir, predicted_class)
            shutil.copy(img_path, dst_dir)
            print(f"Image '{filename}' classified as: {predicted_class}")

# -------------------- Usage example --------------------
# Preprocessing for prediction (must match the validation transform used during training)
transform_predict = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Run the classifier
classify_and_save_images(
    model=model,
    input_dir='G:/kuozhi/efficient-kan/examples/hunhe',                 # directory of mixed images
    output_dir='G:/kuozhi/efficient-kan/examples/classified_flowers',   # output directory per class
    transform=transform_predict,
    class_names=class_names,   # taken from trainset.classes
    device=device
)
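If the folder-sorting step needs to run as a standalone script (for example on another day, without retraining first), a hedged sketch of reloading the checkpoint saved earlier before calling classify_and_save_images could look like this; the checkpoint name and paths are illustrative, not from the original code:

# Sketch: standalone prediction setup, assuming 'kan_flowers.pth' was written by the
# training script above.
import torch
from src.efficient_kan import KAN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt = torch.load("kan_flowers.pth", map_location=device)
class_names = ckpt["class_names"]

model = KAN([12288, 64, len(class_names)])  # must match the trained architecture exactly
model.load_state_dict(ckpt["state_dict"])
model.to(device)
model.eval()

# classify_and_save_images(...) can then be called exactly as in the usage example above.

Note that the layer sizes passed to KAN have to match the trained architecture, otherwise load_state_dict will raise a mismatch error.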
