Matlab 分类算法


一、分类算法核心概念

分类是监督学习任务,目标是将数据分配到预定义的类别中。关键步骤包括:

  1. 特征工程:提取/选择区分性强的特征
  2. 模型训练:学习特征与类别的映射关系
  3. 评估指标:准确率、精确率、召回率、F1分数、混淆矩阵

二、常用分类算法解析

1. K近邻(KNN)

原理:基于距离度量,将样本分配给其k个最近邻中最常见的类别
公式
欧氏距离: \(d(\mathbf{x}_i, \mathbf{x}_j) = \sqrt{\sum_{k=1}^{n}(x_{ik} - x_{jk})^2}\)
优点:简单、无需训练、适用于多分类
缺点:计算开销大、对高维数据敏感

MATLAB代码

% 生成示例数据
rng(1);
data = [randn(100,2)*0.5+1; randn(100,2)*0.5-1];
labels = [ones(100,1); 2*ones(100,1)];

% 划分训练测试集
cv = cvpartition(length(labels), 'HoldOut', 0.3);
trainData = data(training(cv),:);
trainLabels = labels(training(cv));
testData = data(test(cv),:);
testLabels = labels(test(cv));

% KNN模型训练与预测
Mdl = fitcknn(trainData, trainLabels, 'NumNeighbors', 5);
predicted = predict(Mdl, testData);

% 评估
accuracy = sum(predicted == testLabels)/numel(testLabels);
confMat = confusionmat(testLabels, predicted);
disp(['Accuracy: ', num2str(accuracy*100), '%']);
disp('Confusion Matrix:');
disp(confMat);

2. 支持向量机(SVM)

原理:寻找最优超平面最大化类别间隔
公式
优化目标: \(\min_{\mathbf{w},b} \frac{1}{2}\|\mathbf{w}\|^2 + C\sum_{i=1}^{n}\xi_i\)
约束: \(y_i(\mathbf{w}\cdot\mathbf{x}_i + b) \geq 1 - \xi_i\)
优点:高维有效、泛化能力强
缺点:计算复杂、需参数调优

MATLAB代码

% 使用相同数据
% 训练SVM模型(线性核)
SVMModel = fitcsvm(trainData, trainLabels, 'KernelFunction', 'linear', ...
                  'BoxConstraint', 1, 'Standardize', true);

% 预测与评估
predicted = predict(SVMModel, testData);
accuracy = sum(predicted == testLabels)/numel(testLabels);
disp(['SVM Accuracy: ', num2str(accuracy*100), '%']);

% 可视化决策边界
figure;
hgscatter = gscatter(trainData(:,1), trainData(:,2), trainLabels);
hold on;
h = gca;
lim = [h.XLim h.YLim];
[xx,yy] = meshgrid(linspace(lim(1),lim(2),100), linspace(lim(3),lim(4),100));
XGrid = [xx(:), yy(:)];
predGrid = predict(SVMModel, XGrid);
gscatter(xx(:), yy(:), predGrid, [0 0.5 0.1; 0.1 0.5 0]);
title('SVM Decision Boundary');
hold off;

3.决策树算法(Decision Tree)

核心原理

通过递归分割构建树状结构,每个节点根据特征阈值进行二元决策:

  • 分裂准则:基尼不纯度(Gini Index)或信息增益(Information Gain)
    % 基尼系数公式 (MATLAB实现)
    function gini = gini_index(labels)
        classes = unique(labels);
        prob = histcounts(labels, [classes; max(classes)+1])/length(labels);
        gini = 1 - sum(prob.^2);
    end
    
  • 停止条件:最大深度/最小样本数/纯度阈值

MATLAB 代码实现

% 训练决策树模型
treeModel = fitctree(irisInputs, irisTargets, ...
    'MaxDepth', 5, ...
    'MinLeafSize', 10, ...
    'SplitCriterion', 'gdi');  % 基尼系数

% 可视化决策树
view(treeModel, 'Mode', 'graph');

% 预测与评估
predicted = predict(treeModel, testInputs);
accuracy = sum(predicted == testTargets)/numel(testTargets);
disp(['决策树准确率: ', num2str(accuracy*100), '%']);

% 特征重要性分析
imp = predictorImportance(treeModel);
bar(imp);
xlabel('特征');
ylabel('重要性得分');
title('特征重要性排序');

完整示例:

% 设置随机种子保证可重复性
rng(42);

% 1. 加载鸢尾花数据集
load fisheriris; % 数据集存储在变量meas(150x4)和species(150x1)中

% 将类别标签转换为类别数组(便于后续处理)
species = categorical(species);

% 2. 划分训练集和测试集(70%训练,30%测试)
cv = cvpartition(species, 'HoldOut', 0.3);
idxTrain = training(cv);
idxTest = test(cv);

XTrain = meas(idxTrain, :);
yTrain = species(idxTrain);
XTest = meas(idxTest, :);
yTest = species(idxTest);

% 3. 训练决策树模型(使用基尼指数作为分裂准则)
treeModel = fitctree(XTrain, yTrain, 'SplitCriterion', 'gdi');

% 4. 可视化决策树(生成图形化树结构)
view(treeModel, 'Mode', 'graph');

% 5. 预测与评估
yPred = predict(treeModel, XTest);
accuracy = sum(yPred == yTest) / numel(yTest);
fprintf('测试准确率: %.2f%%\n', accuracy * 100);

% 输出混淆矩阵
C = confusionmat(yTest, yPred);
% 注意:confusionchart需要深度学习工具箱(Deep Learning Toolbox)
if exist('confusionchart', 'file')
    figure;
    confusionchart(yTest, yPred);
    title('决策树混淆矩阵');
else
    disp('混淆矩阵:');
    disp(C);
end

% 6. 特征重要性(基于节点分裂时特征被选择的次数加权计算)
imp = predictorImportance(treeModel);
featureNames = {'SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth'};

% 绘制特征重要性条形图
figure;
bar(imp);
title('决策树特征重要性');
set(gca, 'XTickLabel', featureNames, 'XTick', 1:numel(featureNames));
ylabel('重要性得分');

% 重要提示:决策树可能会过拟合,可通过剪枝优化
% 计算剪枝水平(交叉验证)
[~, ~, ~, bestLevel] = cvLoss(treeModel, 'SubTrees', 'all', 'TreeSize', 'min');
% 剪枝到最佳水平
prunedTree = prune(treeModel, 'Level', bestLevel);
% 评估剪枝后的树
yPredPruned = predict(prunedTree, XTest);
accuracyPruned = sum(yPredPruned == yTest) / numel(yTest);
fprintf('剪枝后测试准确率: %.2f%%\n', accuracyPruned * 100);

% 可视化剪枝后的树(可选)
% view(prunedTree, 'Mode', 'graph');

image

image

image

image

决策树特点

优势 劣势
✅ 模型可解释性强 ❌ 容易过拟合
✅ 处理混合类型特征 ❌ 对数据波动敏感
✅ 无需特征缩放 ❌ 边界只能是轴对齐

4.神经网络算法(Neural Network)

核心原理(以多层感知机MLP为例)
  • 前向传播\(z^{(l)} = W^{(l)}a^{(l-1)} + b^{(l)}\)
    \(a^{(l)} = \sigma(z^{(l)})\)
  • 激活函数:ReLU \(\sigma(x) = \max(0,x)\) (隐藏层),Softmax(输出层)
  • 损失函数:交叉熵 \(L = -\sum y_i \log(\hat{y}_i)\)

MATLAB深度学习工具箱实现

% 数据准备
[XTrain, YTrain, XTest, YTest] = prepareData(); % 自定义数据预处理

% 构建网络结构
layers = [
    featureInputLayer(size(XTrain,2)) % 输入层
    fullyConnectedLayer(128)          % 全连接层
    batchNormalizationLayer           % 批标准化
    reluLayer                         % ReLU激活
    dropoutLayer(0.3)                % Dropout正则化
    fullyConnectedLayer(64)
    reluLayer
    fullyConnectedLayer(numClasses)   % 输出层
    softmaxLayer
    classificationLayer];

% 训练配置
options = trainingOptions('adam', ...
    'MaxEpochs', 100, ...
    'MiniBatchSize', 64, ...
    'ValidationData', {XTest, YTest}, ...
    'Plots', 'training-progress', ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropFactor', 0.5, ...
    'LearnRateDropPeriod', 20);

% 训练网络
net = trainNetwork(XTrain, categorical(YTrain), layers, options);

% 测试评估
predicted = classify(net, XTest);
accuracy = sum(predicted == categorical(YTest))/numel(YTest);
confusionchart(YTest, double(predicted));

三、算法对比矩阵

特性 决策树 神经网络 KNN SVM
训练速度 ⚡️⚡️⚡️ ⚡️ ⚡️⚡️ ⚡️⚡️
预测速度 ⚡️⚡️⚡️ ⚡️⚡️ ⚡️ ⚡️⚡️⚡️
可解释性 ✅✅✅ ✅✅
处理高维 ✅✅✅ ✅✅
抗噪声 ✅✅ ✅✅
特征工程 无需 自动提取 需缩放 需缩放

四、关键问题解决方案

决策树过拟合处理

% 后剪枝策略
prunedTree = prune(treeModel, 'Level', 5);  % 层级剪枝
cvTree = crossval(treeModel, 'KFold', 5);   % 交叉验证剪枝
loss = kfoldLoss(cvTree);

神经网络梯度消失

  1. 使用ReLU代替Sigmoid
  2. 添加Batch Normalization层
  3. 残差连接(ResNet结构)
% 残差块示例
function lgraph = addResBlock(lgraph, blockName, numFilters)
    layers = [
        convolution2dLayer(3, numFilters, 'Padding','same', 'Name',[blockName '_conv1'])
        batchNormalizationLayer('Name',[blockName '_bn1'])
        reluLayer('Name',[blockName '_relu1'])
        convolution2dLayer(3, numFilters, 'Padding','same', 'Name',[blockName '_conv2'])
        batchNormalizationLayer('Name',[blockName '_bn2'])
        additionLayer(2,'Name',[blockName '_add'])];
    
    lgraph = addLayers(lgraph, layers);
    lgraph = connectLayers(lgraph, 'input', [blockName '_conv1']);
    lgraph = connectLayers(lgraph, [blockName '_relu1'], [blockName '_conv2']);
    lgraph = connectLayers(lgraph, 'input', [blockName '_add/in2']);
end

五、算法选择指南

  1. 中小型结构化数据 → 决策树(可解释性优先)或SVM(精度优先)
  2. 图像/语音/文本数据 → 神经网络(CNN/RNN)
  3. 实时预测场景 → 决策树(毫秒级响应)
  4. 缺乏ML经验 → KNN(参数简单)或预训练神经网络
graph TD A[数据类型] --> B{结构化?} B -->|是| C{需要解释模型?} C -->|是| D[决策树] C -->|否| E[SVM/神经网络] B -->|否| F{时序/空间特征?} F -->|是| G[CNN/RNN] F -->|否| H[全连接网络]

实践经验:从决策树基准开始,逐渐尝试更复杂模型。对于表格数据,LightGBM/XGBoost(基于决策树的集成方法)通常优于单一模型,MATLAB可通过调用Python库实现:

pyrun('import lightgbm as lgb')
model = pyrun('lgb.LGBMClassifier', [], boosting_type='gbdt', num_leaves=31);

六、分类算法性能对比表

算法 训练速度 预测速度 内存需求 适用场景
KNN 小规模数据
SVM 高维数据
决策树 可解释性要求高
神经网络 很慢 中等 复杂模式识别

七、关键注意事项

  1. 数据预处理
    % 标准化处理
    [trainData, mu, sigma] = zscore(trainData);
    testData = (testData - mu)./sigma;
    
  2. 类别不平衡处理
    % 使用代价敏感学习
    SVMModel = fitcsvm(trainData, trainLabels, 'Cost', [0 2; 1 0]);
    
  3. 参数调优(以SVM为例):
    % 交叉验证选择最佳参数
    opts = struct('Optimizer','bayesopt', 'ShowPlots', true, ...
                  'CVPartition', cvpartition(trainLabels,'KFold',5));
    params = hyperparameters('fitcsvm', trainData, trainLabels);
    SVMModel = fitcsvm(trainData, trainLabels, 'OptimizeHyperparameters','auto', ...
                      'HyperparameterOptimizationOptions', opts);
    

八、进阶技巧

  1. 多分类问题
    • 使用fitcecoc进行错误校正输出编码
    ECOMModel = fitcecoc(trainData, trainLabels);
    
  2. 特征选择
    % 使用最小冗余最大相关算法
    idx = fscmrmr(trainData, trainLabels);
    selectedData = trainData(:, idx(1:10));
    
  3. 模型融合
    % 创建投票分类器
    knnModel = fitcknn(trainData, trainLabels);
    treeModel = fitctree(trainData, trainLabels);
    ensemble = fitcensemble(trainData, trainLabels, 'Method', 'Subspace');
    

九、完整工作流示例

% 1. 数据准备
data = readtable('classification_data.csv');
predictors = data(:,1:end-1);
response = data(:,end);

% 2. 特征工程
predictors = fillmissing(predictors, 'constant', 0);
predictors = normalize(predictors);

% 3. 训练/验证集划分
cv = cvpartition(height(response), 'HoldOut', 0.2);
trainPredictors = predictors(training(cv),:);
trainResponse = response(training(cv),:);
valPredictors = predictors(test(cv),:);
valResponse = response(test(cv),:);

% 4. 模型训练与调优
ensemble = fitcensemble(trainPredictors, trainResponse, ...
                        'OptimizeHyperparameters','all', ...
                        'HyperparameterOptimizationOptions', ...
                        struct('AcquisitionFunctionName','expected-improvement-plus'));

% 5. 模型评估
predicted = predict(ensemble, valPredictors);
confusionchart(table2array(valResponse), predicted);
fprintf('F1 Score: %.2f\n', f1score(table2array(valResponse), predicted));

重要提示:实际应用中需根据数据特性选择算法。对于大型数据集推荐使用SVM或集成方法,对于需要解释性的场景可选择决策树,实时系统可考虑KNN或朴素贝叶斯。

posted @ 2025-07-15 17:44  屈臣  阅读(115)  评论(0)    收藏  举报