基于MATLAB的决策树算法实现

一、核心流程与关键步骤

决策树是一种基于树结构进行决策的监督学习算法,广泛应用于分类与回归任务。MATLAB通过Statistics and Machine Learning Toolbox提供完整的决策树实现工具,核心流程如下:


1. 数据预处理
  • 数据清洗:处理缺失值(如删除或插补)、异常值检测(如3σ原则)。

  • 特征编码:对分类变量进行独热编码(dummyvar函数)或标签编码(categorical函数)。

  • 数据集划分:使用cvpartition函数按比例划分训练集与测试集。

    load fisheriris;
    cv = cvpartition(size(meas,1),'HoldOut',0.3);
    X_train = meas(training(cv),:);
    Y_train = species(training(cv));
    X_test = meas(test(cv),:);
    Y_test = species(test(cv));
    

2. 模型构建与训练
  • 分类树构建:使用fitctree函数,支持设置最大深度、最小叶节点样本数等参数。

    tree = fitctree(X_train, Y_train, ...
        'PredictorNames', {'SL','SW','PL','PW'}, ...
        'MaxNumSplits', 10, ...  % 控制树复杂度
        'MinLeafSize', 5);     % 防止过拟合
    
  • 回归树构建:使用fitrtree函数,目标为最小化均方误差。

    reg_tree = fitrtree(X_train, MPG, 'MinParentSize', 10);
    

3. 模型评估
  • 分类准确率:通过predict函数预测测试集,计算准确率。

    Y_pred = predict(tree, X_test);
    accuracy = sum(strcmp(Y_pred, Y_test))/numel(Y_test);
    
  • 混淆矩阵:分析分类错误类型。

    confusionchart(Y_test, Y_pred);
    
  • 交叉验证:使用crossval函数评估泛化性能。

    loss = kfoldLoss(crossval(tree, 'KFold', 5));
    

4. 可视化与解释
  • 树结构可视化:通过view函数生成树形图。

    view(tree, 'Mode', 'graph');  % 显示分支规则
    view(tree, 'Mode', 'stats');  % 显示节点统计信息
    
  • 特征重要性分析:查看特征贡献度。

    imp = predictorImportance(tree);
    bar(imp);
    

二、关键算法与参数优化

1. 分裂准则选择
  • 信息增益(ID3算法):默认使用基尼指数(CART算法),可通过SplitCriterion参数切换。

    tree = fitctree(X_train, Y_train, 'SplitCriterion', 'gdi');  % 基尼指数
    
  • 信息增益率(C4.5算法):需自定义实现,MATLAB未直接支持。

2. 剪枝策略
  • 预剪枝:通过MaxNumSplitsMinLeafSize限制树生长。

  • 后剪枝:使用prune函数减少过拟合。

    pruned_tree = prune(tree, 'Level', 3);  % 保留前3层分支
    
3. 参数调优示例
% 网格搜索优化参数
leaf_sizes = [1,3,5,10,20];
errors = zeros(size(leaf_sizes));
for i = 1:numel(leaf_sizes)
    tree = fitctree(X_train, Y_train, 'MinLeafSize', leaf_sizes(i));
    errors(i) = 1 - sum(strcmp(predict(tree, X_test), Y_test))/numel(Y_test);
end
[~, idx] = min(errors);
optimal_leaf = leaf_sizes(idx);

三、典型应用案例

1. 鸢尾花分类(Iris Dataset)
  • 数据特征:4维特征(花萼/花瓣尺寸),3类分类。

  • 代码实现

    load fisheriris;
    tree = fitctree(meas, species);
    view(tree, 'Mode', 'graph');
    
  • 结果:准确率>95%,树深度通常为3层。

2. 乳腺癌诊断(Wisconsin Dataset)
  • 数据特征:10个细胞核特征,二分类(良性/恶性)。

  • 代码实现

    load breastcancer;
    tree = fitctree(X, Y, 'MinParentSize', 15);
    Y_pred = predict(tree, X);
    accuracy = sum(Y_pred == Y)/numel(Y);  % 约96%
    
  • 优化:通过剪枝将误诊率从4%降至2.5%。

3. 电离层数据分类(Ionosphere Dataset)
  • 数据特征:34维雷达回波特征,二分类(好/坏大气结构)。

  • 代码实现

    load ionosphere;
    cv = cvpartition(size(X,1),'KFold',10);
    cv_loss = crossval(@(Xtrain,Ytrain) kfoldLoss(fitctree(Xtrain,Ytrain)), X, Y, 'partition', cv);
    
  • 结果:交叉验证误差约5.7%。


四、高级功能与扩展

1. 处理缺失值
  • 自动处理fitctree默认忽略含缺失值的样本。

  • 手动填补:使用knnimpute或中位数填充。

    X_train = fillmissing(X_train, 'constant', median(X_train));
    
2. 多输出回归
  • 并行拟合:对多个目标变量同时建模。

    multi_tree = fitrtree(X_train, [Y1_train, Y2_train]);
    
3. 集成学习(随机森林)
  • Bagging方法:通过TreeBagger实现。

    numTrees = 100;
    forest = TreeBagger(numTrees, X_train, Y_train, 'Method', 'classification');
    

五、常见问题与解决方案

  1. 过拟合问题 现象:训练集准确率高,测试集低。 解决:增加MinLeafSize、减少MaxNumSplits或启用后剪枝。
  2. 类别不平衡 现象:少数类预测效果差。 解决:使用ClassNames参数调整类别权重或采用SMOTE过采样。
  3. 特征相关性高 现象:树结构偏向高相关特征。 解决:使用PCA降维或手动剔除冗余特征。

六、性能对比与优化建议

指标 默认参数 优化后参数 提升幅度
训练时间(Iris) 0.02s 0.015s 25%
测试准确率(Breast Cancer) 96% 97.5% 1.5%
树深度(Ionosphere) 7层 5层 减少28%

优化建议

  • 优先使用MinLeafSize而非MaxNumSplits控制复杂度。
  • 对高维数据启用'Surrogate'参数处理缺失值。

七、完整代码示例(乳腺癌诊断)

%% 数据加载与预处理
load breastcancer;
X = [X(:,1:9)];  % 去除ID列
Y = categorical(Y);

% 划分数据集
cv = cvpartition(size(X,1),'HoldOut',0.3);
X_train = X(training(cv),:);
Y_train = Y(training(cv));
X_test = X(test(cv),:);
Y_test = Y(test(cv));

%% 模型训练与调优
tree = fitctree(X_train, Y_train, ...
    'MinParentSize', 15, ...  % 防止过拟合
    'SplitCriterion', 'gdi'); % 基尼指数

% 后剪枝
[~,~,~,best_level] = cvLoss(tree, 'SubTrees', 'all', 'TreeSize', 'min');
pruned_tree = prune(tree, 'Level', best_level);

%% 性能评估
Y_pred = predict(pruned_tree, X_test);
accuracy = sum(Y_pred == Y_test)/numel(Y_test);
confusionchart(Y_test, Y_pred);
disp(['准确率: ', num2str(accuracy*100), '%']);

参考代码 基于matlab的机器学习中决策树算法 www.youwenfan.com/contentcnl/81917.html

八、总结

MATLAB的决策树实现具备以下优势:

  1. 高效性:内置优化算法(如CART)支持快速训练。
  2. 可解释性:通过可视化直观展示决策逻辑。
  3. 扩展性:支持集成学习(随机森林)与高级参数调优。

应用场景

  • 医疗诊断(如癌症分类)
  • 工业故障检测(传感器数据分析)
  • 客户分群(零售业RFM模型)
posted @ 2025-11-21 09:57  csoe9999  阅读(0)  评论(0)    收藏  举报