1. 绘制脚本
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
# Plot a training metric against epoch for one model.
def plot_metrics(ax, metric_col_name, y_label, color, modelname, csv_path=None):
    """Plot one metric column against the ``epoch`` column for a single model.

    Reads the model's results CSV, strips surrounding whitespace from the
    column names, and draws ``metric_col_name`` versus ``epoch`` on ``ax``.
    Read failures for one model are reported with ``print`` and skipped, so
    the remaining models can still be plotted. Such failures are a missing
    file, an unparsable or empty CSV, or a missing column.

    Args:
        ax: Matplotlib axes to draw on.
        metric_col_name: Column to plot, e.g. ``'metrics/mAP50(B)'``. It is
            matched after whitespace is stripped from the CSV header.
        y_label: Not used here; the caller sets the axis label. Kept so
            existing call sites keep working.
        color: Line color passed to ``ax.plot`` (``None`` uses matplotlib's
            color cycle).
        modelname: Legend label for the curve. When ``csv_path`` is not
            given, it is also the key used to look up the CSV path in the
            module-level ``pr_csv_dict``.
        csv_path: Optional explicit path (or file-like object) for the
            results CSV. When ``None``, ``pr_csv_dict[modelname]`` is used.
    """
    if csv_path is None:
        # Fall back to the module-level mapping populated by plot_all_metrics().
        csv_path = pr_csv_dict[modelname]
    try:
        data = pd.read_csv(csv_path)
        # Column names may carry surrounding whitespace; strip it so lookups
        # such as 'epoch' match.
        data.columns = data.columns.str.strip()
        epochs = data['epoch'].to_numpy()
        metric_data = data[metric_col_name].to_numpy()
    except (OSError, KeyError, ValueError) as e:
        # Missing file (OSError), missing column (KeyError), or an
        # unparsable/empty CSV (pandas ParserError/EmptyDataError are
        # ValueErrors). Report and skip this model.
        print(f"Error reading {modelname}: {e}")
        return
    # Both arrays come from the same DataFrame, so one length check suffices.
    if len(epochs) == 0:
        print(f"No data for {modelname}")
        return
    ax.plot(epochs, metric_data, label=modelname, color=color, linewidth=2)
def _read_total_loss(csv_path):
    # Load a results CSV and return (epochs, box_loss + cls_loss + dfl_loss).
    data = pd.read_csv(csv_path)
    data.columns = data.columns.str.strip()  # header names may be space-padded
    total_loss = (data['train/box_loss'] + data['train/cls_loss']
                  + data['train/dfl_loss'])
    return data['epoch'].to_numpy(), total_loss.to_numpy()


def _format_axis(ax, ylabel, ylim_top, legend_loc, legend_anchor):
    # Apply the label/limit/legend/tick styling shared by all three panels.
    ax.set_xlabel('Epoch', fontsize=16)
    ax.set_ylabel(ylabel, fontsize=16)
    ax.set_xlim(0, None)
    ax.set_ylim(0, ylim_top)
    # Only draw a legend when a curve was plotted; this avoids matplotlib's
    # "No artists with labels found" warning.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(bbox_to_anchor=legend_anchor, loc=legend_loc)
    ax.tick_params(width=2, labelsize=16)  # thicker tick marks, larger tick labels


def plot_all_metrics(*, csv_paths=None, colors=None,
                     save_path='/images/aa/yolo11-map2.png', show=True):
    """Plot mAP@0.5, mAP@0.5-0.95 and total training loss for several models.

    Draws a 1 x 3 figure with one curve per model in each panel. It then saves
    the figure at 250 dpi and optionally shows it.

    Args:
        csv_paths: Mapping of model name -> path to that model's results CSV.
            Defaults to the YOLO11 n/s/m/l/x runs listed below. It is copied
            into the module-level ``pr_csv_dict``, because ``plot_metrics``
            looks paths up there.
        colors: Mapping of model name -> matplotlib color. Models missing
            from it fall back to matplotlib's default color cycle.
        save_path: Output image path. Its parent directory is created if needed.
        show: Whether to call ``plt.show()`` after saving.
    """
    global pr_csv_dict
    if csv_paths is None:
        csv_paths = {
            'YOLO11n': r'/yolo11/yolo11-1/runs/train/kitti-yolo11n/results.csv',
            'YOLO11s': r'/runs/train/kitti-yolo11s/results.csv',
            'YOLO11m': r'/runs/train/kitti-yolo11m/results.csv',
            'YOLO11l': r'/runs/train/kitti-yolo11l/results.csv',
            'YOLO11x': r'/runs/train/kitti-yolo11x/results.csv',
        }
    if colors is None:
        colors = {
            'YOLO11n': '#8470FF',
            'YOLO11s': 'orange',
            'YOLO11m': '#BCEE68',
            'YOLO11l': '#FF6A6A',
            'YOLO11x': '#00BFFF',
        }
    pr_csv_dict = dict(csv_paths)

    # rc_context scopes the larger default font to this figure instead of
    # permanently changing the global rcParams.
    with plt.rc_context({'font.size': 16}):
        # 1 row x 3 columns of subplots; tight_layout keeps labels from overlapping.
        fig, axs = plt.subplots(1, 3, figsize=(24, 8), tight_layout=True)

        # Panel 0: mAP@0.5
        for modelname in pr_csv_dict:
            plot_metrics(axs[0], 'metrics/mAP50(B)', 'mAP@0.5',
                         colors.get(modelname), modelname)
        _format_axis(axs[0], 'mAP@0.5', 1, 'lower right', (1, 0))

        # Panel 1: mAP@0.5-0.95
        for modelname in pr_csv_dict:
            plot_metrics(axs[1], 'metrics/mAP50-95(B)', 'mAP@0.5-0.95',
                         colors.get(modelname), modelname)
        _format_axis(axs[1], 'mAP@0.5-0.95', 1, 'lower right', (1, 0))

        # Panel 2: total training loss = box + cls + dfl
        for modelname, csv_path in pr_csv_dict.items():
            try:
                epochs, total_loss = _read_total_loss(csv_path)
            except (OSError, KeyError, ValueError) as e:
                # Missing file, missing column, or unparsable/empty CSV:
                # report and skip this model.
                print(f"Error reading {modelname}: {e}")
                continue
            if len(epochs) == 0:
                print(f"No data for {modelname}")
                continue
            axs[2].plot(epochs, total_loss, label=modelname,
                        color=colors.get(modelname), linewidth=2)
        _format_axis(axs[2], 'Total loss', None, 'upper right', (1, 1))

        # No subplots_adjust(wspace=...) here: with tight_layout enabled the
        # layout engine recomputes subplot spacing at draw time and overrides it.
        # For extra horizontal spacing, pass tight_layout={'w_pad': ...} to subplots.
        out_path = Path(save_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=250)
        if show:
            plt.show()
# Run the plotting only when executed as a script or notebook cell, not when
# this module is imported (it saves a file and may block on plt.show()).
if __name__ == "__main__":
    plot_all_metrics()
- 使用时只需修改：`pr_csv_dict` 中的模型名称及对应的 results.csv 路径、`colors` 中对应的模型名称（及颜色），以及图像的保存路径。
2. 绘制结果
![YOLO11 各模型的 mAP@0.5、mAP@0.5-0.95 与训练总损失曲线](/images/aa/yolo11-map2.png)