ROC曲线绘制

1. 引入相关包

使用matplotlib包作为绘图库,故要引入相关的包

为了使画出的图更为符合期刊要求,这里引入SciencePlots。

它是一个基于Matplotlib的补充包,里面主要包含了一些以.mplstyle为后缀的图表样式的配置文件。这样,你画图的时候只需要通过调用这些配置文件,就能画出比较好看的数据可视化图表,也避免了你每次画图时都要从头开始手动配置图表的格式。

pip install SciencePlots

还要引入numpy对数据进行处理

要计算AUC,还应该引入sklearn中计算相关值的包。

import matplotlib.pyplot as plt
plt.style.use(['science'])
import numpy as np
from sklearn.metrics import roc_curve, auc

然后导入相关数据

# 真实值
y = np.load('npy\\y_test.npy')
# 各种预测值 并非0或1 而是概率
yp_ann = np.load('npy\\ann.npy')
yp_lstm = np.load('npy\\lstm.npy')
yp_lr = np.load('npy\\lr.npy')
yp_rf = np.load('npy\\rf.npy')
yp_xgb = np.load('npy\\xgb.npy')
yp_lgbm = np.load('npy\\lgbm.npy')
yp_catb = np.load('npy\\catb.npy')

2. 计算AUC值

AUC,即AUROC,指的是由TPRFPR围成的ROC曲线下的面积

将分类任务的实际值和预测值作为参数输入给roc_curve()方法可以得到FPR、TPR和对应的阈值。

auc()方法可以计算曲线下的面积,将FPR和TPR作为参数输入,即可获得AUC值。

fpr_1, tpr_1, threshold_1 = roc_curve(y, yp_ann)  # 计算FPR和TPR
auc_1 = auc(fpr_1, tpr_1)  # 计算AUC值

fpr_2, tpr_2, threshold_2 = roc_curve(y, yp_lstm)
auc_2 = auc(fpr_2, tpr_2)

fpr_3, tpr_3, threshold_3 = roc_curve(y, yp_lr)
auc_3 = auc(fpr_3, tpr_3)

fpr_4, tpr_4, threshold_4 = roc_curve(y, yp_rf)
auc_4 = auc(fpr_4, tpr_4)

fpr_5, tpr_5, threshold_5 = roc_curve(y, yp_xgb)
auc_5 = auc(fpr_5, tpr_5)

fpr_6, tpr_6, threshold_6 = roc_curve(y, yp_lgbm)
auc_6 = auc(fpr_6, tpr_6)

fpr_7, tpr_7, threshold_7 = roc_curve(y, yp_catb)
auc_7 = auc(fpr_7, tpr_7)

3. 绘制曲线

首先定义曲线的宽度和图的大小,如下所示。

line_width = 1  # 曲线的宽度
plt.figure(figsize=(16, 10))  # 图的大小

使用plt的plot()方法可以绘制曲线,通常可以传入的参数有以下几种:

  • x轴的数据
  • y轴的数据
  • lw:线条的宽度
  • label:曲线的标签(曲线标签甚至支持LaTex公式,例如$K_{d,1}$
  • color:曲线的颜色(如果不指定,plt会自动选择)
  • linestyle:线型,包括“-”代表实线,“--”代表虚线,“-.”代表中间有点的虚线,“:”点型虚线
plt.plot(fpr_1, tpr_1, lw=line_width, label='Ann (AUC = %0.4f)' % auc_1,)
plt.plot(fpr_2, tpr_2, lw=line_width, label='Lstm (AUC = %0.4f)' % auc_2,)
plt.plot(fpr_3, tpr_3, lw=line_width, label='LogisticRegression (AUC = %0.4f)' % auc_3,)
plt.plot(fpr_4, tpr_4, lw=line_width, label='RandomForest (AUC = %0.4f)' % auc_4,)
plt.plot(fpr_5, tpr_5, lw=line_width, label='XGboost (AUC = %0.4f)' % auc_5,)
plt.plot(fpr_6, tpr_6, lw=line_width, label='LightGBM (AUC = %0.4f)' % auc_6,)
plt.plot(fpr_7, tpr_7, lw=line_width, label='Catboost (AUC = %0.4f)' % auc_7,)

4. 坐标轴范围和标题

限定x轴和y轴的范围,如下所示。

plt.xlim([0.0, 1.0])  # 限定x轴的范围
plt.ylim([0.0, 1.0])  # 限定y轴的范围

也可以通过xticks()和yticks()直接调整坐标轴的刻度,如下所示。

# plt.xticks(range(0, 10, 1)) # 修改x轴的刻度
# plt.yticks(range(0, 10, 1)) # 修改y轴的刻度

指定坐标轴的标题,如下所示。

plt.xlabel('False Positive Rate')  # x坐标轴标题
plt.ylabel('True Positive Rate')  # y坐标轴标题

使用grid()方法在图中添加网格,如下所示。

plt.grid()  # 在图中添加网格

显示图例并指定图例位置,常见位置包括{upper,center,lower} {left,center,right},如下所示。

plt.legend(loc="lower right")  # 显示图例并指定图例位置

5. 中文处理问题

如果在坐标轴、标题等地方出现了中文,plt会显示乱码,添加以下两条语句可以解决中文处理问题。

plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

6. 展示图片和保存

TIFF格式(Tag Image File Format,TIFF)是常见的论文图片投稿格式,TIFF格式能够制作质量非常高的图像,多数出版社(如Springer、Elsevier)都接受并推荐使用dpi=300的TIFF格式的插图。

plt.savefig('AUC.tif', dpi=300)

使用plt的show方法展示曲线,如下所示。

plt.show()

7. 示例代码

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import precision_recall_curve

y = np.load('npy\\y_test.npy')
yp_ann = np.load('npy\\ann.npy')
yp_lstm = np.load('npy\\lstm.npy')
yp_lr = np.load('npy\\lr.npy')
yp_rf = np.load('npy\\rf.npy')
yp_xgb = np.load('npy\\xgb.npy')
yp_lgbm = np.load('npy\\lgbm.npy')
yp_catb = np.load('npy\\catb.npy')

fpr_1, tpr_1, threshold_1 = roc_curve(y, yp_ann)  # 计算FPR和TPR
auc_1 = auc(fpr_1, tpr_1)  # 计算AUC值

fpr_2, tpr_2, threshold_2 = roc_curve(y, yp_lstm)
auc_2 = auc(fpr_2, tpr_2)

fpr_3, tpr_3, threshold_3 = roc_curve(y, yp_lr)
auc_3 = auc(fpr_3, tpr_3)

fpr_4, tpr_4, threshold_4 = roc_curve(y, yp_rf)
auc_4 = auc(fpr_4, tpr_4)

fpr_5, tpr_5, threshold_5 = roc_curve(y, yp_xgb)
auc_5 = auc(fpr_5, tpr_5)

fpr_6, tpr_6, threshold_6 = roc_curve(y, yp_lgbm)
auc_6 = auc(fpr_6, tpr_6)

fpr_7, tpr_7, threshold_7 = roc_curve(y, yp_catb)
auc_7 = auc(fpr_7, tpr_7)

plt.style.use(['science'])
line_width = 2  # 曲线的宽度
plt.figure(figsize=(8, 5))  # 图的大小

plt.plot(fpr_1, tpr_1, lw=line_width, label='Ann (AUC = %0.4f)' % auc_1,)
plt.plot(fpr_2, tpr_2, lw=line_width, label='Lstm (AUC = %0.4f)' % auc_2,)
plt.plot(fpr_3, tpr_3, lw=line_width, label='LogisticRegression (AUC = %0.4f)' % auc_3,)
plt.plot(fpr_4, tpr_4, lw=line_width, label='RandomForest (AUC = %0.4f)' % auc_4,)
plt.plot(fpr_5, tpr_5, lw=line_width, label='XGboost (AUC = %0.4f)' % auc_5,)
plt.plot(fpr_6, tpr_6, lw=line_width, label='LightGBM (AUC = %0.4f)' % auc_6,)
plt.plot(fpr_7, tpr_7, lw=line_width, label='Catboost (AUC = %0.4f)' % auc_7,)


plt.xlim([0.0, 1.0])  # 限定x轴的范围
plt.ylim([0.0, 1.0])  # 限定y轴的范围
plt.xlabel('False Positive Rate', fontsize=16)  # x坐标轴标题
plt.ylabel('True Positive Rate', fontsize=16)  # y坐标轴标题
plt.title('ROC', fontsize=16)  # 标题
plt.grid()  # 在图中添加网格
plt.legend(loc="lower right", fontsize=16)  # 显示图例并指定图例位置

plt.savefig('ROC.tif', dpi=300)
plt.show()

image-20221026202136784

posted @ 2022-10-26 20:24  王陸  阅读(364)  评论(0编辑  收藏  举报