ROC曲线绘制

ROC曲线绘制调用
源码:
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize

'''
当报错显示:predict_proba is not available when  probability=False时,
则为分类器probability参数被设置为False,导致不能计算预测概率,一些分类器(如SVM)中,predict_proba默认是禁用的
分类器中启用probability=True,例如SVM(probability=True)
'''


def plot_multiclass_roc(clf, x_test, y_test, n_classes):
    """
    用于绘制多类分类算法的 ROC 曲线,包括 micro-averaging 和 macro-averaging

    参数:
    clf: 已训练的分类器
    x_test: 测试数据
    y_test: 测试标签
    n_classes: 类别数量
    """
    # 获得分类器的预测概率
    y_score = clf.predict_proba(x_test)

    # 将类别标签进行二进制编码,以便用于 ROC 曲线计算
    y_test_bin = label_binarize(y_test, classes=np.arange(n_classes))

    # 计算每个类别的 ROC 曲线
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # 计算 micro 和 macro 平均 AUC 值
    fpr["micro"], tpr["micro"], _ = roc_curve(y_test_bin.ravel(), y_score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # 将所有假正类率汇总为一个集合
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

    # 对所有 ROC 曲线进行插值,然后求平均
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= n_classes

    # 计算 macro 平均值
    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    # 绘制 ROC 曲线
    plt.figure()
    lw = 2
    colors = cycle(['deeppink', 'aqua', 'darkorange', 'red', 'navy', 'magenta', 'green', 'cyan', 'cornflowerblue'])
    for i, color in zip(range(n_classes), colors):
        plt.plot(fpr[i], tpr[i], color=color, lw=lw,
                 label='ROC curve of class {0} (area = {1:0.2f})'
                       ''.format(i, roc_auc[i]))

    plt.plot(fpr["micro"], tpr["micro"],
             label='micro-average ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc["micro"]),
             color='deeppink', linestyle=':', linewidth=4)

    plt.plot(fpr["macro"], tpr["macro"],
             label='macro-average ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc["macro"]),
             color='navy', linestyle=':', linewidth=4)

    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic for multi-class')
    plt.legend(loc="lower right")
    plt.show()
View Code

 

这里需要注意的是:

当报错显示:predict_proba is not available when  probability=False时,
则为分类器probability参数被设置为False,导致不能计算预测概率,一些分类器(如SVM)中,predict_proba默认是禁用的
分类器中启用probability=True,例如SVM(probability=True)
调用:
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from roc import plot_multiclass_roc

# 加载葡萄酒数据集
data = load_wine()
X = data.data
y = data.target

# 分割数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=421)
# 创建决策树分类器
clf = SVC(gamma='scale',kernel='linear',random_state=421,C=0.32,probability=True)
#probability=True与此调用有关
clf.fit(X_train, y_train)

plot_multiclass_roc(clf, X_test, y_test, n_classes=3)
View Code

 

posted @ 2024-12-22 18:37  一眉师傅  阅读(36)  评论(0)    收藏  举报