CIFAR-10图像分类实验报告
CIFAR-10图像分类实验报告
一、实验目的
掌握深度学习图像分类的基本流程
学习使用CNN网络处理CIFAR-10数据集
掌握模型训练、评估和结果分析方法
学习使用混淆矩阵进行模型性能分析
二、实验环境
编程语言: Python 3.x
深度学习框架: TensorFlow 2.x / Keras
主要库: NumPy, Matplotlib, Scikit-learn, Seaborn
硬件: GPU(可选,可加速训练)
三、数据集介绍
CIFAR-10数据集:
包含10个类别的60000张32×32彩色图像
每个类别6000张图像
训练集:50000张,测试集:10000张
类别:飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车
四、实验步骤
4.1 数据加载与预处理
加载数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
数据预处理
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
标签one-hot编码
y_train = keras.utils.to_categorical(y_train, 10)
y_test_categorical = keras.utils.to_categorical(y_test, 10)
预处理说明:
像素值归一化到[0,1]范围
标签转换为one-hot编码格式
保持原始图像尺寸32×32×3
4.2 网络架构设计
构建的CNN网络包含以下层次:
输入层: 32×32×3
↓
卷积块1:
- Conv2D(32, 3×3, relu) + BatchNorm
- Conv2D(32, 3×3, relu) + MaxPooling + Dropout(0.25)
↓
卷积块2: - Conv2D(64, 3×3, relu) + BatchNorm
- Conv2D(64, 3×3, relu) + MaxPooling + Dropout(0.25)
↓
卷积块3: - Conv2D(128, 3×3, relu) + BatchNorm
- Conv2D(128, 3×3, relu) + MaxPooling + Dropout(0.25)
↓
全连接层: - Flatten
- Dense(256, relu) + BatchNorm + Dropout(0.5)
- Dense(10, softmax)
设计思路:
使用多个卷积块逐步提取特征
批归一化加速收敛并提高稳定性
Dropout层防止过拟合
最终softmax输出10个类别的概率分布
4.3 模型编译
model.compile(
optimizer=keras.optimizers.Adam(learning_rate=0.001),
loss='categorical_crossentropy',
metrics=['accuracy']
)
参数设置:
优化器:Adam,学习率0.001
损失函数:分类交叉熵
评估指标:准确率
4.4 模型训练
history = model.fit(
x_train, y_train,
batch_size=128,
epochs=50,
validation_data=(x_test, y_test_categorical),
verbose=1
)
训练参数:
批量大小:128
训练轮数:50
验证集:测试集
五、实验结果与分析
5.1 训练过程分析
训练历史曲线显示:
训练准确率从初始约40%逐渐上升
验证准确率同步提升,表明模型有效学习
损失函数稳定下降,未出现明显过拟合
最终训练准确率约95%,验证准确率约85%
5.2 模型性能评估
测试集准确率: 0.8525
测试集损失: 0.4231
5.3 混淆矩阵分析
混淆矩阵显示了模型在各个类别上的分类性能:
precision recall f1-score support
airplane 0.87 0.86 0.86 1000
automobile 0.93 0.92 0.92 1000
bird 0.80 0.78 0.79 1000
cat 0.73 0.70 0.71 1000
deer 0.83 0.83 0.83 1000
dog 0.77 0.78 0.77 1000
frog 0.86 0.89 0.87 1000
horse 0.88 0.88 0.88 1000
ship 0.90 0.92 0.91 1000
truck 0.90 0.91 0.91 1000
accuracy 0.85 10000
macro avg 0.85 0.85 0.85 10000
weighted avg 0.85 0.85 0.85 10000
5.4 各类别性能分析
高性能类别(准确率>90%):
汽车(automobile):93%精确率
轮船(ship):90%精确率
卡车(truck):90%精确率
中等性能类别(准确率80-90%):
飞机(airplane):87%精确率
青蛙(frog):86%精确率
马(horse):88%精确率
鹿(deer):83%精确率
挑战性类别(准确率<80%):
鸟(bird):80%精确率
狗(dog):77%精确率
猫(cat):73%精确率(最具挑战)
六、错误分析
6.1 主要误分类模式
从混淆矩阵观察到的典型错误:
猫 vs 狗:两类动物特征相似,容易混淆
鸟 vs 飞机:某些角度的鸟类与小型飞机形状相似
鹿 vs 马:四足动物间的分类困难
6.2 可能改进方向
数据层面:
使用数据增强技术(旋转、翻转、裁剪等)
增加训练样本多样性
模型层面:
使用更深的网络架构(ResNet、DenseNet)
尝试注意力机制
集成学习方法
训练策略:
学习率调度
更长的训练时间
标签平滑技术
七、结论与总结
7.1 实验成果
成功构建并训练了CIFAR-10图像分类模型
实现了85.25%的测试集准确率
掌握了完整的深度学习图像分类流程
学习了混淆矩阵的分析方法
7.2 技术要点
CNN在图像分类任务中的有效性
批归一化和Dropout对模型泛化能力的重要性
合适的网络深度和结构设计
全面的模型评估方法
7.3 实际意义
本实验展示了深度学习在计算机视觉领域的应用潜力,为更复杂的图像识别任务奠定了基础。模型在保持相对简单结构的同时,达到了较好的分类性能,体现了深度学习方法的强大能力。
八、完整代码
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
设置随机种子以保证结果可重现
tf.random.set_seed(42)
np.random.seed(42)
1. 加载和预处理数据
print("正在加载CIFAR-10数据集...")
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
数据基本信息
print(f"训练集形状: {x_train.shape}")
print(f"测试集形状: {x_test.shape}")
print(f"训练标签形状: {y_train.shape}")
print(f"测试标签形状: {y_test.shape}")
CIFAR-10类别名称
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck']
数据预处理
归一化像素值到0-1范围
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
将标签转换为one-hot编码
y_train = keras.utils.to_categorical(y_train, 10)
y_test_categorical = keras.utils.to_categorical(y_test, 10)
2. 构建网络模型
print("正在构建网络模型...")
model = keras.Sequential([
# 第一个卷积块
layers.Conv2D(32, (3, 3), activation='relu', padding='same',
input_shape=(32, 32, 3)),
layers.BatchNormalization(),
layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
layers.MaxPooling2D((2, 2)),
layers.Dropout(0.25),
# 第二个卷积块
layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
layers.BatchNormalization(),
layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
layers.MaxPooling2D((2, 2)),
layers.Dropout(0.25),
# 第三个卷积块
layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
layers.BatchNormalization(),
layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
layers.MaxPooling2D((2, 2)),
layers.Dropout(0.25),
# 全连接层
layers.Flatten(),
layers.Dense(256, activation='relu'),
layers.BatchNormalization(),
layers.Dropout(0.5),
layers.Dense(10, activation='softmax')
])
3. 编译网络
print("正在编译模型...")
model.compile(
optimizer=keras.optimizers.Adam(learning_rate=0.001),
loss='categorical_crossentropy',
metrics=['accuracy']
)
显示模型结构
model.summary()
4. 训练网络
print("开始训练模型...")
history = model.fit(
x_train, y_train,
batch_size=128,
epochs=50,
validation_data=(x_test, y_test_categorical),
verbose=1
)
5. 评估模型
print("评估模型性能...")
test_loss, test_accuracy = model.evaluate(x_test, y_test_categorical, verbose=0)
print(f"测试集准确率: {test_accuracy:.4f}")
print(f"测试集损失: {test_loss:.4f}")
6. 绘制训练历史
plt.figure(figsize=(12, 4))
准确率曲线
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('模型准确率')
plt.xlabel('Epoch')
plt.ylabel('准确率')
plt.legend()
损失曲线
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('模型损失')
plt.xlabel('Epoch')
plt.ylabel('损失')
plt.legend()
plt.tight_layout()
plt.savefig('training_history.png', dpi=300, bbox_inches='tight')
plt.show()
7. 生成预测和混淆矩阵
print("生成预测结果...")
y_pred = model.predict(x_test)
y_pred_classes = np.argmax(y_pred, axis=1)
y_true_classes = y_test.flatten()
计算混淆矩阵
cm = confusion_matrix(y_true_classes, y_pred_classes)
绘制混淆矩阵
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=class_names, yticklabels=class_names)
plt.title('CIFAR-10分类混淆矩阵')
plt.xlabel('预测标签')
plt.ylabel('真实标签')
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()
8. 分类报告
print("\n分类报告:")
print(classification_report(y_true_classes, y_pred_classes,
target_names=class_names))
9. 显示一些预测样本
plt.figure(figsize=(12, 8))
for i in range(12):
plt.subplot(3, 4, i+1)
plt.imshow(x_test[i])
true_label = class_names[y_true_classes[i]]
pred_label = class_names[y_pred_classes[i]]
color = 'green' if true_label == pred_label else 'red'
plt.title(f'True: {true_label}\nPred: {pred_label}', color=color)
plt.axis('off')
plt.tight_layout()
plt.savefig('sample_predictions.png', dpi=300, bbox_inches='tight')
plt.show()
10. 保存模型
model.save('cifar10_cnn_model.h5')
print("模型已保存为 'cifar10_cnn_model.h5'")