计算和可视化准确度、F值、混淆矩阵和ROC曲线
基于MATLAB的分类模型评估,包含准确度、F值、混淆矩阵和ROC曲线的计算与可视化,结合多分类与二分类场景:
一、基础指标计算与混淆矩阵可视化
%% 输入数据准备
trueLabels = [1 0 1 1 0 1 0 0 1 0]; % 真实标签(示例数据)
predictedLabels = [1 0 0 1 0 1 1 0 1 0]; % 预测标签
scores = [0.92, 0.15, 0.45, 0.88, 0.22, 0.73, 0.31, 0.66, 0.89, 0.18]; % 预测得分(二分类概率)
%% 1. 准确度计算
accuracy = sum(trueLabels == predictedLabels) / numel(trueLabels);
fprintf('准确度: %.2f%%
', accuracy*100);
%% 2. F值计算(二分类)
C = confusionmat(trueLabels, predictedLabels);
TP = C(2,2); TN = C(1,1);
FP = C(1,2); FN = C(2,1);
precision = TP/(TP+FP);
recall = TP/(TP+FN);
F1 = 2*(precision*recall)/(precision+recall);
fprintf('F1值: %.3f
', F1);
%% 3. 混淆矩阵可视化(支持多分类)
figure;
cm = confusionmat(trueLabels, predictedLabels);
cmChart = confusionchart(cm);
cmChart.Title = '混淆矩阵';
cmChart.XLabel = '预测类别';
cmChart.YLabel = '真实类别';
colormap([0.2 0.7 0.4; 0.9 0.1 0.1]); % 自定义颜色(绿/红)
二、ROC曲线与AUC计算(二分类)
%% ROC曲线绘制与AUC计算
[~, ~, ~, auc] = perfcurve(trueLabels, scores, 1); % 自动计算阈值点
figure;
plot(perfcurve(trueLabels, scores, 1)); % 自动生成FPR-TPR曲线
hold on;
plot([0 1], [0 1], 'r--'); % 随机猜测参考线
xlabel('假阳性率 (FPR)');
ylabel('真阳性率 (TPR)');
title(sprintf('ROC曲线 (AUC = %.3f)', auc));
grid on;
%% 手动计算ROC曲线(教学示例)
thresholds = unique(scores);
[TPR, FPR] = deal(zeros(length(thresholds),1));
for i = 1:length(thresholds)
pred = scores >= thresholds(i);
TP = sum(pred & (trueLabels==1));
FP = sum(pred & (trueLabels==0));
TN = sum(~pred & (trueLabels==0));
FN = sum(~pred & (trueLabels==1));
TPR(i) = TP/(TP+FN);
FPR(i) = FP/(FP+TN);
end
figure;
plot(FPR, TPR, 'b-o', 'LineWidth',2);
hold on;
plot([0 1], [0 1], 'r--');
xlabel('FPR'); ylabel('TPR');
title('手动计算ROC曲线');
grid on;
三、多分类扩展方案
%% 多分类混淆矩阵(以3分类为例)
trueLabels = [1 2 3 1 2 3 1 2 3 1];
predictedLabels = [1 2 3 2 2 3 1 1 3 1];
cm = confusionmat(trueLabels, predictedLabels);
figure;
cmPlot = confusionchart(cm);
cmPlot.ClassNames = {'Class1','Class2','Class3'};
cmPlot.Title = '多分类混淆矩阵';
cmPlot.XLabel = '预测类别';
cmPlot.YLabel = '真实类别';
cmPlot.Colors = parula(3); % 使用渐变色
%% 多分类ROC曲线(One-vs-All方法)
numClasses = 3;
fpr = zeros(numClasses, 100);
tpr = zeros(numClasses, 100);
for cls = 1:numClasses
binaryLabels = (trueLabels == cls);
[~, ~, ~, auc] = perfcurve(binaryLabels, scores, 1);
[X,Y] = perfcurve(binaryLabels, scores, 1);
fpr(cls,:) = X;
tpr(cls,:) = Y;
end
figure;
hold on;
for cls = 1:numClasses
plot(fpr(cls,:), tpr(cls,:), 'DisplayName', sprintf('Class%d', cls));
end
plot([0 1], [0 1], 'k--');
xlabel('FPR'); ylabel('TPR');
title('多分类ROC曲线(OvA)');
legend;
grid on;
参考代码 计算和可视化准确度、F值、混淆矩阵和ROC曲线的matlabd代码 www.youwenfan.com/contentcnj/84935.html
四、关键参数说明
- 输入数据要求:
trueLabels:真实标签(数值型或分类变量)predictedLabels:模型预测的类别标签scores:模型输出的概率/得分(需为数值型向量)
- 性能优化建议:
- 使用
crossval进行交叉验证避免过拟合 - 通过
fitcsvm/fitctree等函数调整模型超参数 - 对不平衡数据使用
classifCost设置类别权重
- 使用
五、结果解读
| 指标 | 理想值 | 判断标准 |
|---|---|---|
| 准确度 | 100% | 高准确度需结合业务场景验证 |
| F1值 | 1.0 | 平衡精确率与召回率的综合指标 |
| AUC值 | 1.0 | >0.9优秀,>0.8良好 |
| 混淆矩阵 | 对角线高 | 对角线元素代表正确分类数 |
六、应用场景示例
- 医疗诊断:通过AUC评估癌症筛查模型的整体性能
- 欺诈检测:用F1值平衡欺诈识别率与误报率
- 自动驾驶:混淆矩阵分析信号灯识别错误类型
完整代码可通过替换示例数据适配实际项目,建议结合MATLAB的Classification Learner工具箱进行模型迭代优化。

浙公网安备 33010602011771号