Matlab 分类算法
一、分类算法核心概念
分类是监督学习任务,目标是将数据分配到预定义的类别中。关键步骤包括:
- 特征工程:提取/选择区分性强的特征
- 模型训练:学习特征与类别的映射关系
- 评估指标:准确率、精确率、召回率、F1分数、混淆矩阵
二、常用分类算法解析
1. K近邻(KNN)
原理:基于距离度量,将样本分配给其k个最近邻中最常见的类别
公式:
欧氏距离: \(d(\mathbf{x}_i, \mathbf{x}_j) = \sqrt{\sum_{k=1}^{n}(x_{ik} - x_{jk})^2}\)
优点:简单、无需训练、适用于多分类
缺点:计算开销大、对高维数据敏感
MATLAB代码:
% 生成示例数据
rng(1);
data = [randn(100,2)*0.5+1; randn(100,2)*0.5-1];
labels = [ones(100,1); 2*ones(100,1)];
% 划分训练测试集
cv = cvpartition(length(labels), 'HoldOut', 0.3);
trainData = data(training(cv),:);
trainLabels = labels(training(cv));
testData = data(test(cv),:);
testLabels = labels(test(cv));
% KNN模型训练与预测
Mdl = fitcknn(trainData, trainLabels, 'NumNeighbors', 5);
predicted = predict(Mdl, testData);
% 评估
accuracy = sum(predicted == testLabels)/numel(testLabels);
confMat = confusionmat(testLabels, predicted);
disp(['Accuracy: ', num2str(accuracy*100), '%']);
disp('Confusion Matrix:');
disp(confMat);
2. 支持向量机(SVM)
原理:寻找最优超平面最大化类别间隔
公式:
优化目标: \(\min_{\mathbf{w},b} \frac{1}{2}\|\mathbf{w}\|^2 + C\sum_{i=1}^{n}\xi_i\)
约束: \(y_i(\mathbf{w}\cdot\mathbf{x}_i + b) \geq 1 - \xi_i\)
优点:高维有效、泛化能力强
缺点:计算复杂、需参数调优
MATLAB代码:
% 使用相同数据
% 训练SVM模型(线性核)
SVMModel = fitcsvm(trainData, trainLabels, 'KernelFunction', 'linear', ...
'BoxConstraint', 1, 'Standardize', true);
% 预测与评估
predicted = predict(SVMModel, testData);
accuracy = sum(predicted == testLabels)/numel(testLabels);
disp(['SVM Accuracy: ', num2str(accuracy*100), '%']);
% 可视化决策边界
figure;
hgscatter = gscatter(trainData(:,1), trainData(:,2), trainLabels);
hold on;
h = gca;
lim = [h.XLim h.YLim];
[xx,yy] = meshgrid(linspace(lim(1),lim(2),100), linspace(lim(3),lim(4),100));
XGrid = [xx(:), yy(:)];
predGrid = predict(SVMModel, XGrid);
gscatter(xx(:), yy(:), predGrid, [0 0.5 0.1; 0.1 0.5 0]);
title('SVM Decision Boundary');
hold off;
3.决策树算法(Decision Tree)
核心原理
通过递归分割构建树状结构,每个节点根据特征阈值进行二元决策:
- 分裂准则:基尼不纯度(Gini Index)或信息增益(Information Gain)
% 基尼系数公式 (MATLAB实现) function gini = gini_index(labels) classes = unique(labels); prob = histcounts(labels, [classes; max(classes)+1])/length(labels); gini = 1 - sum(prob.^2); end - 停止条件:最大深度/最小样本数/纯度阈值
MATLAB 代码实现
% 训练决策树模型
treeModel = fitctree(irisInputs, irisTargets, ...
'MaxDepth', 5, ...
'MinLeafSize', 10, ...
'SplitCriterion', 'gdi'); % 基尼系数
% 可视化决策树
view(treeModel, 'Mode', 'graph');
% 预测与评估
predicted = predict(treeModel, testInputs);
accuracy = sum(predicted == testTargets)/numel(testTargets);
disp(['决策树准确率: ', num2str(accuracy*100), '%']);
% 特征重要性分析
imp = predictorImportance(treeModel);
bar(imp);
xlabel('特征');
ylabel('重要性得分');
title('特征重要性排序');
完整示例:
% 设置随机种子保证可重复性
rng(42);
% 1. 加载鸢尾花数据集
load fisheriris; % 数据集存储在变量meas(150x4)和species(150x1)中
% 将类别标签转换为类别数组(便于后续处理)
species = categorical(species);
% 2. 划分训练集和测试集(70%训练,30%测试)
cv = cvpartition(species, 'HoldOut', 0.3);
idxTrain = training(cv);
idxTest = test(cv);
XTrain = meas(idxTrain, :);
yTrain = species(idxTrain);
XTest = meas(idxTest, :);
yTest = species(idxTest);
% 3. 训练决策树模型(使用基尼指数作为分裂准则)
treeModel = fitctree(XTrain, yTrain, 'SplitCriterion', 'gdi');
% 4. 可视化决策树(生成图形化树结构)
view(treeModel, 'Mode', 'graph');
% 5. 预测与评估
yPred = predict(treeModel, XTest);
accuracy = sum(yPred == yTest) / numel(yTest);
fprintf('测试准确率: %.2f%%\n', accuracy * 100);
% 输出混淆矩阵
C = confusionmat(yTest, yPred);
% 注意:confusionchart需要深度学习工具箱(Deep Learning Toolbox)
if exist('confusionchart', 'file')
figure;
confusionchart(yTest, yPred);
title('决策树混淆矩阵');
else
disp('混淆矩阵:');
disp(C);
end
% 6. 特征重要性(基于节点分裂时特征被选择的次数加权计算)
imp = predictorImportance(treeModel);
featureNames = {'SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth'};
% 绘制特征重要性条形图
figure;
bar(imp);
title('决策树特征重要性');
set(gca, 'XTickLabel', featureNames, 'XTick', 1:numel(featureNames));
ylabel('重要性得分');
% 重要提示:决策树可能会过拟合,可通过剪枝优化
% 计算剪枝水平(交叉验证)
[~, ~, ~, bestLevel] = cvLoss(treeModel, 'SubTrees', 'all', 'TreeSize', 'min');
% 剪枝到最佳水平
prunedTree = prune(treeModel, 'Level', bestLevel);
% 评估剪枝后的树
yPredPruned = predict(prunedTree, XTest);
accuracyPruned = sum(yPredPruned == yTest) / numel(yTest);
fprintf('剪枝后测试准确率: %.2f%%\n', accuracyPruned * 100);
% 可视化剪枝后的树(可选)
% view(prunedTree, 'Mode', 'graph');




决策树特点
| 优势 | 劣势 |
|---|---|
| ✅ 模型可解释性强 | ❌ 容易过拟合 |
| ✅ 处理混合类型特征 | ❌ 对数据波动敏感 |
| ✅ 无需特征缩放 | ❌ 边界只能是轴对齐 |
4.神经网络算法(Neural Network)
核心原理(以多层感知机MLP为例)
- 前向传播:\(z^{(l)} = W^{(l)}a^{(l-1)} + b^{(l)}\)
\(a^{(l)} = \sigma(z^{(l)})\) - 激活函数:ReLU \(\sigma(x) = \max(0,x)\) (隐藏层),Softmax(输出层)
- 损失函数:交叉熵 \(L = -\sum y_i \log(\hat{y}_i)\)
MATLAB深度学习工具箱实现
% 数据准备
[XTrain, YTrain, XTest, YTest] = prepareData(); % 自定义数据预处理
% 构建网络结构
layers = [
featureInputLayer(size(XTrain,2)) % 输入层
fullyConnectedLayer(128) % 全连接层
batchNormalizationLayer % 批标准化
reluLayer % ReLU激活
dropoutLayer(0.3) % Dropout正则化
fullyConnectedLayer(64)
reluLayer
fullyConnectedLayer(numClasses) % 输出层
softmaxLayer
classificationLayer];
% 训练配置
options = trainingOptions('adam', ...
'MaxEpochs', 100, ...
'MiniBatchSize', 64, ...
'ValidationData', {XTest, YTest}, ...
'Plots', 'training-progress', ...
'LearnRateSchedule', 'piecewise', ...
'LearnRateDropFactor', 0.5, ...
'LearnRateDropPeriod', 20);
% 训练网络
net = trainNetwork(XTrain, categorical(YTrain), layers, options);
% 测试评估
predicted = classify(net, XTest);
accuracy = sum(predicted == categorical(YTest))/numel(YTest);
confusionchart(YTest, double(predicted));
三、算法对比矩阵
| 特性 | 决策树 | 神经网络 | KNN | SVM |
|---|---|---|---|---|
| 训练速度 | ⚡️⚡️⚡️ | ⚡️ | ⚡️⚡️ | ⚡️⚡️ |
| 预测速度 | ⚡️⚡️⚡️ | ⚡️⚡️ | ⚡️ | ⚡️⚡️⚡️ |
| 可解释性 | ✅✅✅ | ❌ | ✅✅ | ✅ |
| 处理高维 | ❌ | ✅✅✅ | ❌ | ✅✅ |
| 抗噪声 | ❌ | ✅✅ | ❌ | ✅✅ |
| 特征工程 | 无需 | 自动提取 | 需缩放 | 需缩放 |
四、关键问题解决方案
决策树过拟合处理
% 后剪枝策略
prunedTree = prune(treeModel, 'Level', 5); % 层级剪枝
cvTree = crossval(treeModel, 'KFold', 5); % 交叉验证剪枝
loss = kfoldLoss(cvTree);
神经网络梯度消失
- 使用ReLU代替Sigmoid
- 添加Batch Normalization层
- 残差连接(ResNet结构)
% 残差块示例
function lgraph = addResBlock(lgraph, blockName, numFilters)
layers = [
convolution2dLayer(3, numFilters, 'Padding','same', 'Name',[blockName '_conv1'])
batchNormalizationLayer('Name',[blockName '_bn1'])
reluLayer('Name',[blockName '_relu1'])
convolution2dLayer(3, numFilters, 'Padding','same', 'Name',[blockName '_conv2'])
batchNormalizationLayer('Name',[blockName '_bn2'])
additionLayer(2,'Name',[blockName '_add'])];
lgraph = addLayers(lgraph, layers);
lgraph = connectLayers(lgraph, 'input', [blockName '_conv1']);
lgraph = connectLayers(lgraph, [blockName '_relu1'], [blockName '_conv2']);
lgraph = connectLayers(lgraph, 'input', [blockName '_add/in2']);
end
五、算法选择指南
- 中小型结构化数据 → 决策树(可解释性优先)或SVM(精度优先)
- 图像/语音/文本数据 → 神经网络(CNN/RNN)
- 实时预测场景 → 决策树(毫秒级响应)
- 缺乏ML经验 → KNN(参数简单)或预训练神经网络
graph TD
A[数据类型] --> B{结构化?}
B -->|是| C{需要解释模型?}
C -->|是| D[决策树]
C -->|否| E[SVM/神经网络]
B -->|否| F{时序/空间特征?}
F -->|是| G[CNN/RNN]
F -->|否| H[全连接网络]
实践经验:从决策树基准开始,逐渐尝试更复杂模型。对于表格数据,LightGBM/XGBoost(基于决策树的集成方法)通常优于单一模型,MATLAB可通过调用Python库实现:
pyrun('import lightgbm as lgb') model = pyrun('lgb.LGBMClassifier', [], boosting_type='gbdt', num_leaves=31);
六、分类算法性能对比表
| 算法 | 训练速度 | 预测速度 | 内存需求 | 适用场景 |
|---|---|---|---|---|
| KNN | 快 | 慢 | 高 | 小规模数据 |
| SVM | 慢 | 快 | 低 | 高维数据 |
| 决策树 | 快 | 快 | 低 | 可解释性要求高 |
| 神经网络 | 很慢 | 中等 | 高 | 复杂模式识别 |
七、关键注意事项
- 数据预处理:
% 标准化处理 [trainData, mu, sigma] = zscore(trainData); testData = (testData - mu)./sigma; - 类别不平衡处理:
% 使用代价敏感学习 SVMModel = fitcsvm(trainData, trainLabels, 'Cost', [0 2; 1 0]); - 参数调优(以SVM为例):
% 交叉验证选择最佳参数 opts = struct('Optimizer','bayesopt', 'ShowPlots', true, ... 'CVPartition', cvpartition(trainLabels,'KFold',5)); params = hyperparameters('fitcsvm', trainData, trainLabels); SVMModel = fitcsvm(trainData, trainLabels, 'OptimizeHyperparameters','auto', ... 'HyperparameterOptimizationOptions', opts);
八、进阶技巧
- 多分类问题:
- 使用
fitcecoc进行错误校正输出编码
ECOMModel = fitcecoc(trainData, trainLabels); - 使用
- 特征选择:
% 使用最小冗余最大相关算法 idx = fscmrmr(trainData, trainLabels); selectedData = trainData(:, idx(1:10)); - 模型融合:
% 创建投票分类器 knnModel = fitcknn(trainData, trainLabels); treeModel = fitctree(trainData, trainLabels); ensemble = fitcensemble(trainData, trainLabels, 'Method', 'Subspace');
九、完整工作流示例
% 1. 数据准备
data = readtable('classification_data.csv');
predictors = data(:,1:end-1);
response = data(:,end);
% 2. 特征工程
predictors = fillmissing(predictors, 'constant', 0);
predictors = normalize(predictors);
% 3. 训练/验证集划分
cv = cvpartition(height(response), 'HoldOut', 0.2);
trainPredictors = predictors(training(cv),:);
trainResponse = response(training(cv),:);
valPredictors = predictors(test(cv),:);
valResponse = response(test(cv),:);
% 4. 模型训练与调优
ensemble = fitcensemble(trainPredictors, trainResponse, ...
'OptimizeHyperparameters','all', ...
'HyperparameterOptimizationOptions', ...
struct('AcquisitionFunctionName','expected-improvement-plus'));
% 5. 模型评估
predicted = predict(ensemble, valPredictors);
confusionchart(table2array(valResponse), predicted);
fprintf('F1 Score: %.2f\n', f1score(table2array(valResponse), predicted));
重要提示:实际应用中需根据数据特性选择算法。对于大型数据集推荐使用SVM或集成方法,对于需要解释性的场景可选择决策树,实时系统可考虑KNN或朴素贝叶斯。

浙公网安备 33010602011771号