• 博客园logo
  • 会员
  • 周边
  • 新闻
  • 博问
  • 闪存
  • 众包
  • 赞助商
  • Chat2DB
    • 搜索
      所有博客
    • 搜索
      当前博客
  • 写随笔 我的博客 短消息 简洁模式
    用户头像
    我的博客 我的园子 账号设置 会员中心 简洁模式 ... 退出登录
    注册 登录
华东 博客
目前在某大模型创业公司工作,研究方向大模型、智能体 新浪博客: http://blog.sina.com.cn/u/2463286753
博客园    首页    新随笔    联系   管理    订阅  订阅
python 混淆矩阵可视化

 

 

#!/usr/bin/env python
# coding: utf-8

# In[28]:
from __future__ import division
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np


def plot_confusion_matrix(cm, classes,
                          normalize=True,
                          title='Confusion matrix',
                          cmap=plt.cm.hot):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest')
    # plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    # fmt = '.2f' if normalize else 'd'
    # thresh = cm.max() / 2.
    # for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    #     plt.text(j, i, format(cm[i, j], fmt),
    #              horizontalalignment="center",
    #              color="white" if cm[i, j] > thresh else "black")


data = pd.read_csv(r"experiment\61476-2000\confusion_matrix.csv")

all_categories = data.iloc[:,0]  
confusion = np.array(data.iloc[:,1:])
# print(confusion)
# Normalize by dividing every row by its sum
# for i in range(len(all_categories)):
#     for j in range(len(all_categories)):
#         confusion[i][j] = confusion[i][j] / confusion[i].sum()

#Set up plot

fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(confusion)
# fig.colorbar(cax)

# Set up axes
ax.set_xticklabels([''] + all_categories, rotation=90)
ax.set_yticklabels([''] + all_categories)

# Force label at every tick
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
plot_confusion_matrix(confusion, classes=list(all_categories), title="Confusion matrix")
# sphinx_gallery_thumbnail_number = 2
plt.show()

  

posted on 2021-05-13 23:37  华东博客  阅读(230)  评论(0)    收藏  举报
刷新页面返回顶部
博客园  ©  2004-2026
浙公网安备 33010602011771号 浙ICP备2021040463号-3