计算和可视化准确度、F值、混淆矩阵和ROC曲线

基于MATLAB的分类模型评估,包含准确度、F值、混淆矩阵和ROC曲线的计算与可视化,结合多分类与二分类场景:


一、基础指标计算与混淆矩阵可视化

%% 输入数据准备
trueLabels = [1 0 1 1 0 1 0 0 1 0]; % 真实标签(示例数据)
predictedLabels = [1 0 0 1 0 1 1 0 1 0]; % 预测标签
scores = [0.92, 0.15, 0.45, 0.88, 0.22, 0.73, 0.31, 0.66, 0.89, 0.18]; % 预测得分(二分类概率)

%% 1. 准确度计算
accuracy = sum(trueLabels == predictedLabels) / numel(trueLabels);
fprintf('准确度: %.2f%%
', accuracy*100);

%% 2. F值计算(二分类)
C = confusionmat(trueLabels, predictedLabels);
TP = C(2,2); TN = C(1,1);
FP = C(1,2); FN = C(2,1);
precision = TP/(TP+FP);
recall = TP/(TP+FN);
F1 = 2*(precision*recall)/(precision+recall);
fprintf('F1值: %.3f
', F1);

%% 3. 混淆矩阵可视化(支持多分类)
figure;
cm = confusionmat(trueLabels, predictedLabels);
cmChart = confusionchart(cm);
cmChart.Title = '混淆矩阵';
cmChart.XLabel = '预测类别';
cmChart.YLabel = '真实类别';
colormap([0.2 0.7 0.4; 0.9 0.1 0.1]); % 自定义颜色(绿/红)

二、ROC曲线与AUC计算(二分类)

%% ROC曲线绘制与AUC计算
[~, ~, ~, auc] = perfcurve(trueLabels, scores, 1); % 自动计算阈值点
figure;
plot(perfcurve(trueLabels, scores, 1)); % 自动生成FPR-TPR曲线
hold on;
plot([0 1], [0 1], 'r--'); % 随机猜测参考线
xlabel('假阳性率 (FPR)');
ylabel('真阳性率 (TPR)');
title(sprintf('ROC曲线 (AUC = %.3f)', auc));
grid on;

%% 手动计算ROC曲线(教学示例)
thresholds = unique(scores);
[TPR, FPR] = deal(zeros(length(thresholds),1));
for i = 1:length(thresholds)
    pred = scores >= thresholds(i);
    TP = sum(pred & (trueLabels==1));
    FP = sum(pred & (trueLabels==0));
    TN = sum(~pred & (trueLabels==0));
    FN = sum(~pred & (trueLabels==1));
    TPR(i) = TP/(TP+FN);
    FPR(i) = FP/(FP+TN);
end
figure;
plot(FPR, TPR, 'b-o', 'LineWidth',2);
hold on;
plot([0 1], [0 1], 'r--');
xlabel('FPR'); ylabel('TPR');
title('手动计算ROC曲线');
grid on;

三、多分类扩展方案

%% 多分类混淆矩阵(以3分类为例)
trueLabels = [1 2 3 1 2 3 1 2 3 1];
predictedLabels = [1 2 3 2 2 3 1 1 3 1];
cm = confusionmat(trueLabels, predictedLabels);
figure;
cmPlot = confusionchart(cm);
cmPlot.ClassNames = {'Class1','Class2','Class3'};
cmPlot.Title = '多分类混淆矩阵';
cmPlot.XLabel = '预测类别';
cmPlot.YLabel = '真实类别';
cmPlot.Colors = parula(3); % 使用渐变色

%% 多分类ROC曲线(One-vs-All方法)
numClasses = 3;
fpr = zeros(numClasses, 100);
tpr = zeros(numClasses, 100);
for cls = 1:numClasses
    binaryLabels = (trueLabels == cls);
    [~, ~, ~, auc] = perfcurve(binaryLabels, scores, 1);
    [X,Y] = perfcurve(binaryLabels, scores, 1);
    fpr(cls,:) = X;
    tpr(cls,:) = Y;
end
figure;
hold on;
for cls = 1:numClasses
    plot(fpr(cls,:), tpr(cls,:), 'DisplayName', sprintf('Class%d', cls));
end
plot([0 1], [0 1], 'k--');
xlabel('FPR'); ylabel('TPR');
title('多分类ROC曲线(OvA)');
legend;
grid on;

参考代码 计算和可视化准确度、F值、混淆矩阵和ROC曲线的matlabd代码 www.youwenfan.com/contentcnj/84935.html

四、关键参数说明

  1. 输入数据要求
    • trueLabels:真实标签(数值型或分类变量)
    • predictedLabels:模型预测的类别标签
    • scores:模型输出的概率/得分(需为数值型向量)
  2. 性能优化建议
    • 使用crossval进行交叉验证避免过拟合
    • 通过fitcsvm/fitctree等函数调整模型超参数
    • 对不平衡数据使用classifCost设置类别权重

五、结果解读

指标 理想值 判断标准
准确度 100% 高准确度需结合业务场景验证
F1值 1.0 平衡精确率与召回率的综合指标
AUC值 1.0 >0.9优秀,>0.8良好
混淆矩阵 对角线高 对角线元素代表正确分类数

六、应用场景示例

  1. 医疗诊断:通过AUC评估癌症筛查模型的整体性能
  2. 欺诈检测:用F1值平衡欺诈识别率与误报率
  3. 自动驾驶:混淆矩阵分析信号灯识别错误类型

完整代码可通过替换示例数据适配实际项目,建议结合MATLAB的Classification Learner工具箱进行模型迭代优化。

posted @ 2025-10-21 17:54  lingxingqi  阅读(4)  评论(0)    收藏  举报