如何画出你想要的混淆矩阵(热力图)
1. 什么是混淆矩阵?
答:混淆矩阵也称误差矩阵,是表示精度评价的一种标准格式,用n行n列的矩阵形式来表示。 具体评价指标有总体精度、制图精度、用户精度等,这些精度指标从不同的侧面反映了图像分类的精度。 在人工智能中,混淆矩阵(confusion matrix)是可视化工具,特别用于多分类的监督学习,在无监督学习一般叫做匹配矩阵。
2. 混淆矩阵的作用:
答:混淆矩阵的作用有以下几条:
- 性能评估: 混淆矩阵提供了一种直观的方式来评估分类模型的性能。通过查看 TP、FP、FN 和 TN 的数量,可以计算出一系列评估指标,如准确率、精确度、召回率、F1 分数等。
- 类别间的差异: 对于多类别分类问题,混淆矩阵能够清晰地展示不同类别之间的预测关系。这有助于识别模型在哪些类别上表现较好,哪些类别上表现较差。
- 调整模型参数: 通过分析混淆矩阵,可以帮助调整模型的阈值或参数,以优化模型在不同类别上的表现。
- 误判分析: 可以通过混淆矩阵分析模型在哪些类别上容易发生误判,从而改进模型的弱点。
3. 生成混淆矩阵的python代码
-
从PKL文件中读取构成混淆矩阵的真实值和预测值
import pickle
# 打开 PKL 文件以读取数据
with open('path/to/your.pkl', 'rb') as file:
loaded_data = pickle.load(file)
# 将真实值和预测值分别存入labels和predictions列表中
labels=loaded_data['labels']
predictions=loaded_data['predictions']
-
绘制混淆矩阵
# 导入相关包
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
# seaborn是一个基于 Matplotlib 的数据可视化库,专注于统计图形的绘制。
# 它提供了一些高层次的接口,能够使创建漂亮和有吸引力的图形变得更加简单。
import seaborn as sns
# 用于生成颜色
import colorsys
# 生成18种低饱和度颜色
def generate_distinct_low_saturation_colors(num_colors):
colors = []
# 分为两组亮度级别,以增加颜色的区分度
lightness_levels = [0.75, 0.85] # 较高亮度
saturation = 0.2 # 低饱和度
for i in range(num_colors):
hue = i / num_colors
lightness = lightness_levels[i % len(lightness_levels)]
rgb = colorsys.hls_to_rgb(hue, lightness, saturation)
colors.append(rgb)
return colors
# 假设cm是你的热力图数据
cm = confusion_matrix(labels, predictions)
# 归一化混淆矩阵
cm_normalized = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
# 获取18种低饱和度颜色
colors = generate_distinct_low_saturation_colors(18)
# 绘制热力图
annot = False
# 设置画布大小
plt.figure(figsize=(6, 5))
# 使用seaborn的heatmap函数绘制混淆矩阵
#### annot=False表示热力图中色块仅显示颜色,不显示数值
#### 在 seaborn.heatmap 函数中,fmt 参数用于设置热力图上显示数字的格式。
###### fmt=".1f" 指的是将数字格式化为小数点后一位的浮点数。
#### 在 seaborn.heatmap 函数中,cbar 参数是一个布尔值
###### 用于控制是否在热力图旁边绘制一个颜色条(colorbar)。
#### cmap="Blues" 定义了数据值到颜色的映射关系。
###### 常用cmap值:"viridis", "plasma", "inferno", "magma","Greens", "Reds", "Oranges", "Purples", "Greys", "YlOrBr"
###### 更多的cmap颜色请关注:https://matplotlib.org/stable/users/explain/colors/colormaps.html
sns.heatmap(cm_normalized,annot=False,fmt=".1f" if annot else "", cbar=True, xticklabels=False, yticklabels=False,cmap="Blues"
# 添加色块作为列标签
'''
# plt.Rectangle((idx, -0.5), 1, -0.3, color=color) 这里创建了一个 Rectangle 对象,表示一个矩形。
# 参数 (idx, -0.5) 表示矩形的左下角坐标,其中 idx 是 x 坐标,-0.5 是 y 坐标。
# 1 是矩形的宽度,-0.3 是矩形的高度。注意高度为负值,因为 y 轴的正方向是向上的,我们想要在坐标轴的下方添加矩形。
# color=color 指定了矩形的颜色,它从 colors 列表中获取。
# transform=plt.gca().transData._b 这个参数设置了矩形的坐标变换。
# plt.gca().transData 是当前坐标轴的数据坐标变换。
# _b 是变换的一个内部属性,用于确保矩形正确地放置在数据坐标中。
# clip_on=False 这个参数设置为 False 以确保矩形不会被剪裁,即使它超出了坐标轴的边界。
'''
for idx, color in enumerate(colors):
plt.gca().add_patch(plt.Rectangle((idx, -0.5), 1, -0.3, color=color, transform=plt.gca().transData._b, clip_on=False))
# 同理,添加色块作为行标签
for idx, color in enumerate(colors):
plt.gca().add_patch(plt.Rectangle((-0.5, idx), -0.3, 1, color=color, transform=plt.gca().transData._b, clip_on=False))
# 保存混淆矩阵图,图名自定义,dpi代表图片清晰度,数值越大越清晰
plt.savefig("path/to/confusion_matrix.png",dpi=1200)
plt.show()

浙公网安备 33010602011771号