import numpy as np
import matplotlib.pyplot as plt
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# ======================== 1. Device configuration and hyperparameters ========================

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
batch_size = 128
num_classes = 10  # MNIST class indices 0-9 (reported as digits "1-10" by this script)
num_epochs = 5
lr = 0.001

# ======================== 2. Data preprocessing ========================

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # standard MNIST mean and std
])

# Load the MNIST dataset (downloaded into ./data on first run)
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# ======================== 3. Model definitions ========================

# 3.1 Shared CNN feature extractor

class CNNFeatureExtractor(nn.Module):
    """Three conv/ReLU/max-pool stages that map 1-channel images to flat feature vectors.

    For 28x28 inputs (MNIST) the spatial size goes 28 -> 14 -> 7 -> 3, so each
    image yields 128 * 3 * 3 = 1152 features. `flatten_dim` records that size
    for the downstream linear classifiers.
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # -> 32 x 14 x 14
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # -> 64 x 7 x 7
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # -> 128 x 3 x 3 (pooling floors 7 / 2 to 3)
        )
        self.flatten_dim = 128 * 3 * 3  # flattened feature size for 28x28 inputs

    def forward(self, x):
        """Return (batch, flatten_dim) features for a (batch, 1, 28, 28) input."""
        x = self.features(x)
        # flatten(1) keeps the batch dimension intact. By contrast,
        # view(-1, flatten_dim) would silently fold a wrong-sized feature map
        # into extra "batch" rows.
        return torch.flatten(x, 1)

# 3.2 CNN + Softmax classifier

class CNNSoftmax(nn.Module):
    """Shared CNN feature extractor followed by a linear layer and a softmax.

    forward() returns a tuple (probabilities, logits): the softmax over dim 1
    and the raw linear outputs. Train on the logits, e.g. with
    nn.CrossEntropyLoss, which applies log-softmax internally.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.feature_extractor = CNNFeatureExtractor()
        self.classifier = nn.Linear(self.feature_extractor.flatten_dim, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        features = self.feature_extractor(x)
        logits = self.classifier(features)
        return self.softmax(logits), logits

# 3.3 CNN + Sigmoid classifier

class CNNSigmoid(nn.Module):
    """Shared CNN feature extractor followed by a linear layer and a per-class sigmoid.

    forward() returns a tuple (probabilities, logits): element-wise sigmoid
    scores and the raw linear outputs. Train on the logits, e.g. with
    nn.BCEWithLogitsLoss and one-hot targets.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.feature_extractor = CNNFeatureExtractor()
        self.classifier = nn.Linear(self.feature_extractor.flatten_dim, num_classes)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        features = self.feature_extractor(x)
        logits = self.classifier(features)
        return self.sigmoid(logits), logits

# 3.4 CNN + SVM (trained with a multi-class hinge loss)

class CNNSVM(nn.Module):
    """Shared CNN feature extractor followed by a linear scoring layer.

    forward() returns the raw class scores only (no tuple). Intended to be
    trained with MultiClassHingeLoss.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.feature_extractor = CNNFeatureExtractor()
        self.classifier = nn.Linear(self.feature_extractor.flatten_dim, num_classes)

    def forward(self, x):
        features = self.feature_extractor(x)
        return self.classifier(features)

# Multi-class SVM hinge loss

class MultiClassHingeLoss(nn.Module):
    """Multi-class hinge loss averaged over the batch.

    loss = (1 / N) * sum_i sum_{j != y_i} max(0, margin - (s[i, y_i] - s[i, j]))

    Args:
        margin: required gap between the true-class score and every other score.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, logits, labels):
        """Compute the loss for (N, C) scores `logits` and (N,) integer `labels`."""
        one_hot = torch.zeros_like(logits).scatter(1, labels.unsqueeze(1), 1)
        correct_logit = (logits * one_hot).sum(dim=1, keepdim=True)
        # clamp(min=0) is the hinge. Unlike comparing against a tensor built on
        # a global device, it runs on whatever device `logits` lives on.
        margins = torch.clamp(self.margin - (correct_logit - logits), min=0)
        # Mask out j == y_i (that term is always `margin`, not a violation).
        return (margins * (1 - one_hot)).sum() / logits.size(0)

# ======================== 4. Training and evaluation functions ========================

# 4.1 Generic CNN training loop

def train_cnn_model(model, criterion, optimizer, train_loader, num_epochs, is_sigmoid=False):
    """Train `model` in place and print the mean loss after every epoch.

    Args:
        model: network to train. It may return either raw logits, or a tuple
            (probabilities, logits) as CNNSoftmax / CNNSigmoid do. The loss is
            always computed on the logits.
        criterion: loss module, called as criterion(logits, target).
        optimizer: optimizer over `model`'s parameters.
        train_loader: DataLoader yielding (images, integer class labels).
        num_epochs: number of full passes over `train_loader`.
        is_sigmoid: if True, labels are first converted to one-hot float
            targets (as nn.BCEWithLogitsLoss requires).

    Returns:
        (model, train_time), where train_time is the wall-clock seconds spent training.
    """
    model.to(device)
    model.train()
    start_time = time.time()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            outputs = model(data)
            logits = outputs[1] if isinstance(outputs, tuple) else outputs

            if is_sigmoid:
                # One-hot target matching the logits' shape, dtype and device.
                target_onehot = torch.zeros_like(logits).scatter(1, target.unsqueeze(1), 1)
                loss = criterion(logits, target_onehot)
            else:
                loss = criterion(logits, target)

            loss.backward()
            optimizer.step()
            # The losses used in this script are mean-reduced per batch, so
            # weight by batch size to get a per-sample average for the epoch.
            total_loss += loss.item() * data.size(0)

        avg_loss = total_loss / len(train_loader.dataset)
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}')

    train_time = time.time() - start_time
    return model, train_time

# 4.2 Model evaluation

def calculate_metrics(y_true, y_pred, num_classes):
    """Compute accuracy and macro-averaged precision, recall and F1.

    Args:
        y_true: 1-D integer array of ground-truth labels, 1-based (1..num_classes).
        y_pred: 1-D integer array of predicted labels, using the same labeling.
        num_classes: number of classes.

    Returns:
        (accuracy, precision, recall, f1). Precision, recall and F1 are
        unweighted means over all classes. Accuracy is 0.0 for empty input.
    """
    y_true = np.asarray(y_true) - 1
    y_pred = np.asarray(y_pred) - 1

    # (num_classes, n_samples) boolean masks, one row per class.
    classes = np.arange(num_classes)[:, None]
    true_is = y_true[None, :] == classes
    pred_is = y_pred[None, :] == classes
    TP = np.sum(true_is & pred_is, axis=1)
    FP = np.sum(~true_is & pred_is, axis=1)
    FN = np.sum(true_is & ~pred_is, axis=1)

    # The 1e-8 terms avoid 0/0 for classes that never occur or are never predicted.
    precision_per_cls = TP / (TP + FP + 1e-8)
    recall_per_cls = TP / (TP + FN + 1e-8)
    f1_per_cls = 2 * (precision_per_cls * recall_per_cls) / (precision_per_cls + recall_per_cls + 1e-8)
    precision = np.mean(precision_per_cls)
    recall = np.mean(recall_per_cls)
    f1 = np.mean(f1_per_cls)
    accuracy = np.sum(TP) / len(y_true) if len(y_true) else 0.0
    return accuracy, precision, recall, f1


def evaluate_model(model, test_loader, is_cnn_svm=False):
    """Evaluate `model` on `test_loader`.

    The predicted class is the argmax of the model's scores. Models returning
    a tuple (probabilities, logits) use the probabilities; models returning a
    single tensor (e.g. CNNSVM) use it directly.

    Args:
        model: trained classifier.
        test_loader: DataLoader yielding (images, integer class labels 0..C-1).
        is_cnn_svm: kept for backward compatibility. Prediction no longer
            depends on it, since the output type is detected automatically.

    Returns:
        (accuracy, precision, recall, f1) from calculate_metrics.
    """
    model.to(device)
    model.eval()
    pred_batches = []
    target_batches = []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            scores = outputs[0] if isinstance(outputs, tuple) else outputs
            preds = torch.argmax(scores, dim=1)
            # Shift 0-9 -> 1-10 to match the script's "digits 1-10" labeling.
            pred_batches.append((preds + 1).cpu().numpy())
            target_batches.append((target + 1).cpu().numpy())

    all_preds = np.concatenate(pred_batches) if pred_batches else np.array([], dtype=np.int64)
    # Targets come from the labels, not the predictions. Building them from
    # all_preds would make every metric trivially perfect.
    all_targets = np.concatenate(target_batches) if target_batches else np.array([], dtype=np.int64)

    return calculate_metrics(all_targets, all_preds, num_classes)

# ======================== 5. Model training ========================

print("=" * 30 + " Training CNN+Softmax " + "=" * 30)
model_softmax = CNNSoftmax(num_classes)
criterion_softmax = nn.CrossEntropyLoss()
optimizer_softmax = optim.Adam(model_softmax.parameters(), lr=lr)
model_softmax, time_softmax = train_cnn_model(
    model_softmax, criterion_softmax, optimizer_softmax, train_loader, num_epochs)

print("\n" + "=" * 30 + " Training CNN+Sigmoid " + "=" * 30)
model_sigmoid = CNNSigmoid(num_classes)
# BCEWithLogitsLoss needs one-hot float targets; is_sigmoid=True makes
# train_cnn_model build them from the integer labels.
criterion_sigmoid = nn.BCEWithLogitsLoss()
optimizer_sigmoid = optim.Adam(model_sigmoid.parameters(), lr=lr)
model_sigmoid, time_sigmoid = train_cnn_model(
    model_sigmoid, criterion_sigmoid, optimizer_sigmoid, train_loader, num_epochs, is_sigmoid=True)

print("\n" + "=" * 30 + " Training CNN+SVM " + "=" * 30)
model_cnn_svm = CNNSVM(num_classes)
criterion_cnn_svm = MultiClassHingeLoss()
optimizer_cnn_svm = optim.Adam(model_cnn_svm.parameters(), lr=lr)
model_cnn_svm, time_cnn_svm = train_cnn_model(
    model_cnn_svm, criterion_cnn_svm, optimizer_cnn_svm, train_loader, num_epochs)

# ======================== 6. Model evaluation ========================

print("\n" + "=" * 30 + " Model Evaluation " + "=" * 30)
acc_softmax, pre_softmax, rec_softmax, f1_softmax = evaluate_model(model_softmax, test_loader)
acc_sigmoid, pre_sigmoid, rec_sigmoid, f1_sigmoid = evaluate_model(model_sigmoid, test_loader)
acc_cnn_svm, pre_cnn_svm, rec_cnn_svm, f1_cnn_svm = evaluate_model(model_cnn_svm, test_loader, is_cnn_svm=True)

# Collect the results (one entry per model, in the order of `models`)
models = ['CNN+Softmax', 'CNN+Sigmoid', 'CNN+SVM']
accuracy_list = [acc_softmax, acc_sigmoid, acc_cnn_svm]
precision_list = [pre_softmax, pre_sigmoid, pre_cnn_svm]
recall_list = [rec_softmax, rec_sigmoid, rec_cnn_svm]
f1_list = [f1_softmax, f1_sigmoid, f1_cnn_svm]
time_list = [time_softmax, time_sigmoid, time_cnn_svm]

# Print the results table

print("\n" + "=" * 90)
print(f"{'Model':<18} {'Accuracy':<12} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Train Time(s)':<12}")
print("=" * 90)
for row_model, row_acc, row_pre, row_rec, row_f1, row_time in zip(
        models, accuracy_list, precision_list, recall_list, f1_list, time_list):
    print(f"{row_model:<18} {row_acc:<12.4f} {row_pre:<12.4f} {row_rec:<12.4f} {row_f1:<12.4f} {row_time:<12.2f}")
print("=" * 90)

# ======================== 7. Result visualization (bar colors are customizable) ========================

def plot_metrics_comparison(models, accuracy, precision, recall, f1, time_list):
    """Draw, save and show a two-panel comparison figure for the trained models.

    Left panel: grouped bars of accuracy / precision / recall / F1 per model.
    Right panel: training time per model, each bar annotated with its value.
    The figure is saved to 'mnist_digit_recognition_comparison.png' (300 dpi),
    shown, and then closed.

    Args:
        models: model display names, one per bar group.
        accuracy, precision, recall, f1: per-model scores, aligned with `models`.
        time_list: per-model training time in seconds, aligned with `models`.
    """
    fig, ax = plt.subplots(1, 2, figsize=(18, 7))
    x = np.arange(len(models))
    width = 0.2

    # ---- Subplot 1: classification metrics (change the HEX codes to recolor) ----
    colors_metrics = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D']  # blue / purple / orange / red
    metric_series = [
        ('Accuracy', accuracy),
        ('Precision', precision),
        ('Recall', recall),
        ('F1-Score', f1),
    ]
    # Offsets -1.5w, -0.5w, +0.5w, +1.5w center the four bars on each tick.
    for i, ((label, values), color) in enumerate(zip(metric_series, colors_metrics)):
        offset = (i - (len(metric_series) - 1) / 2) * width
        ax[0].bar(x + offset, values, width, label=label, color=color, edgecolor='black', alpha=0.8)
    ax[0].set_xlabel('Models', fontsize=12, fontweight='bold')
    ax[0].set_ylabel('Score', fontsize=12, fontweight='bold')
    ax[0].set_title('Classification Metrics Comparison (Digits 1-10)', fontsize=14, fontweight='bold')
    ax[0].set_xticks(x)
    ax[0].set_xticklabels(models, rotation=15, fontsize=10)
    ax[0].legend(fontsize=10)
    ax[0].grid(axis='y', linestyle='--', alpha=0.7)

    # ---- Subplot 2: training time (change the HEX codes to recolor) ----
    colors_time = ['#38A3A5', '#57CC99', '#80ED99']  # dark green / light green / mint
    bars = ax[1].bar(models, time_list, color=colors_time, edgecolor='black', alpha=0.8)
    ax[1].set_xlabel('Models', fontsize=12, fontweight='bold')
    ax[1].set_ylabel('Training Time (s)', fontsize=12, fontweight='bold')
    ax[1].set_title('Training Time Comparison', fontsize=14, fontweight='bold')
    ax[1].tick_params(axis='x', rotation=15, labelsize=10)
    ax[1].grid(axis='y', linestyle='--', alpha=0.7)
    # Value labels are offset in screen points, so the gap above each bar does
    # not depend on the y-axis scale (a fixed +0.5 s offset does).
    for bar in bars:
        height = bar.get_height()
        ax[1].annotate(f'{height:.2f}',
                       xy=(bar.get_x() + bar.get_width() / 2., height),
                       xytext=(0, 3), textcoords='offset points',
                       ha='center', va='bottom', fontsize=10)

    plt.tight_layout()
    plt.savefig('mnist_digit_recognition_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # release the figure once it has been saved and shown

# Draw and save the comparison figure

plot_metrics_comparison(models, accuracy_list, precision_list, recall_list, f1_list, time_list)
print("3013")  # NOTE(review): purpose of this marker output is unclear; confirm with the author

# Originally posted on 2025-12-26 00:36 by 汤圆233.