• 博客园logo
  • 会员
  • 众包
  • 新闻
  • 博问
  • 闪存
  • 赞助商
  • HarmonyOS
  • Chat2DB
    • 搜索
      所有博客
    • 搜索
      当前博客
  • 写随笔 我的博客 短消息 简洁模式
    用户头像
    我的博客 我的园子 账号设置 会员中心 简洁模式 ... 退出登录
    注册 登录

LR233

  • 博客园
  • 联系
  • 订阅
  • 管理

公告

View Post

classification_report()评估报告

1、使用数据生成器后获得标签映射

方法一:

labels = [k for k in train_generator.class_indices]

方法二:

1 labels = [None] * len(test_generator.class_indices)
2 for k, v in test_generator.class_indices.items():
3     labels[v] = k

 

2、模型训练完成后怎样生成评估报告和混淆矩阵

导包

from sklearn.metrics import classification_report, confusion_matrix

2.1、评估报告

预先准备

1 test_gen = ImageDataGenerator(rotation_range=5)
2 test_generator = test_gen.flow_from_directory(test_dir,
3                                                       target_size=(img_size,img_size),
4                                                       shuffle=False,
5                                                       class_mode='categorical')  # 多分类

注意:这里的shuffle必须设置为False,class_mode选择多分类。

方法一:

1 y_test = test_generator.classes
2 y_pred = model.predict(test_generator)
3 y_pred = np.argmax(y_pred, axis=1)
4 print(classification_report(y_test,y_pred,target_names=labels))

结果显示:

    

 方法二:

1 test_generator.reset()
2 pred = model.predict_generator(test_generator, verbose=1)
3 # 输出每个图像的预测类别
4 predicted_class_indices = np.argmax(pred, axis=1)
5 print(classification_report(test_generator.classes,predicted_class_indices,target_names=labels))

结果和上面类似。

2.2、混淆矩阵

1 import matplotlib.pyplot as plt
2 import seaborn as sns
3 
4 plt.figure(figsize=(10,8))
5 sns.heatmap(confusion_matrix(y_test,y_pred),annot=True,fmt='.3g',xticklabels=labels,yticklabels=labels,cmap='viridis')
6 plt.show()

结果显示如下:

    

 

posted on 2023-02-25 12:47  LR233  阅读(443)  评论(0)    收藏  举报

刷新页面返回顶部
 
博客园  ©  2004-2025
浙公网安备 33010602011771号 浙ICP备2021040463号-3