Ensemble Tree Classifiers in MATLAB

A MATLAB implementation of ensemble classifiers built from multiple trees, including random forests and gradient-boosted trees.

Overview of Ensemble Tree Classifiers

Core Idea

Ensemble tree methods build a strong classifier by combining many weak learners (decision trees):

Method             Core idea                                                     Strengths
Random forest      Trains many decision trees in parallel; predicts by vote     Resists overfitting; handles high-dimensional data
Gradient boosting  Trains trees sequentially; each fits the previous residuals  High accuracy; flexible across data types
AdaBoost           Reweights samples to focus on hard-to-classify examples      Simple and computationally cheap (but sensitive to label noise)
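
For comparison, the Statistics and Machine Learning Toolbox ships built-in equivalents of all three methods. A minimal sketch, assuming the toolbox is installed, with X an n-by-p feature matrix and y a label vector:

% Random forest via bagged trees, with out-of-bag estimates enabled
rf = TreeBagger(100, X, y, 'OOBPrediction', 'on');

% Multiclass boosting ('AdaBoostM2' is the multiclass AdaBoost variant)
boosted = fitcensemble(X, y, 'Method', 'AdaBoostM2', 'NumLearningCycles', 100);

% A single CART decision tree
single_tree = fitctree(X, y, 'MinLeafSize', 5);

The classes below reimplement the same ideas from scratch, which makes the mechanics (bootstrapping, feature subsampling, residual fitting) explicit.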

Code

1. Random Forest Classifier

classdef RandomForest < handle
    % Random forest classifier

    properties
        numTrees            % number of trees
        minLeafSize         % minimum leaf size
        numFeaturesToSample % number of features sampled per tree
        trees               % cell array of trained decision trees
        featureImportance   % aggregated feature importance
        classes             % class labels
    end

    methods
        function obj = RandomForest(numTrees, minLeafSize, numFeaturesToSample)
            % Constructor with sensible defaults
            if nargin < 3
                numFeaturesToSample = 'sqrt';
            end
            if nargin < 2
                minLeafSize = 1;
            end
            if nargin < 1
                numTrees = 100;
            end

            obj.numTrees = numTrees;
            obj.minLeafSize = minLeafSize;
            obj.numFeaturesToSample = numFeaturesToSample;
            obj.trees = cell(numTrees, 1);
        end
        
        function fit(obj, X, y)
            % Train the random forest
            % X: feature matrix (n_samples × n_features)
            % y: label vector (n_samples × 1)

            [n_samples, n_features] = size(X);
            obj.classes = unique(y);

            % Determine how many features each tree sees
            if ischar(obj.numFeaturesToSample)
                if strcmp(obj.numFeaturesToSample, 'sqrt')
                    n_features_sample = round(sqrt(n_features));
                else
                    n_features_sample = round(log2(n_features));
                end
            else
                n_features_sample = min(obj.numFeaturesToSample, n_features);
            end

            fprintf('Training random forest (%d trees, %d features per tree)...\n', ...
                    obj.numTrees, n_features_sample);

            % Train the trees in parallel. parfor cannot write into an
            % object property, so collect the trees in a local cell array.
            nT = obj.numTrees;
            minLeaf = obj.minLeafSize;
            trees_local = cell(nT, 1);
            parfor i = 1:nT
                % Bootstrap sampling (with replacement)
                bootstrap_indices = randsample(n_samples, n_samples, true);
                X_bootstrap = X(bootstrap_indices, :);
                y_bootstrap = y(bootstrap_indices);

                % Random feature subsampling
                feature_indices = randperm(n_features, n_features_sample);
                X_bootstrap_sampled = X_bootstrap(:, feature_indices);

                % Train one decision tree on the sampled data
                tree = DecisionTree('min_leaf_size', minLeaf);
                tree.fit(X_bootstrap_sampled, y_bootstrap, feature_indices);

                trees_local{i} = tree;
            end
            obj.trees = trees_local;
            fprintf('Finished training %d trees\n', nT);

            % Aggregate feature importance
            obj.calculateFeatureImportance(X, y, n_features);
        end
        
        function predictions = predict(obj, X)
            % Predict by majority vote over all trees
            n_samples = size(X, 1);

            % Collect every tree's predictions. Each tree was trained on
            % a column subset, so select those same columns before predicting.
            nT = obj.numTrees;
            trees_local = obj.trees;
            tree_predictions = zeros(n_samples, nT);
            parfor i = 1:nT
                t = trees_local{i};
                tree_predictions(:, i) = t.predict(X(:, t.feature_indices));
            end

            % Majority vote
            predictions = mode(tree_predictions, 2);
        end
        
        function probabilities = predict_proba(obj, X)
            % Class probabilities as the fraction of trees voting for each class
            n_samples = size(X, 1);
            n_classes = length(obj.classes);

            nT = obj.numTrees;
            trees_local = obj.trees;
            tree_predictions = zeros(n_samples, nT);
            parfor i = 1:nT
                t = trees_local{i};
                tree_predictions(:, i) = t.predict(X(:, t.feature_indices));
            end

            % Vote fraction per class
            probabilities = zeros(n_samples, n_classes);
            for j = 1:n_classes
                probabilities(:, j) = sum(tree_predictions == obj.classes(j), 2) / nT;
            end
        end
        
        function calculateFeatureImportance(obj, ~, ~, n_features)
            % Aggregate Gini-based importance across all trees
            % (the X and y arguments are unused: importance comes from
            % the per-tree Gini gains recorded during training)
            obj.featureImportance = zeros(1, n_features);

            for i = 1:obj.numTrees
                tree = obj.trees{i};
                for j = 1:length(tree.feature_importance)
                    feature_idx = tree.feature_indices(j);
                    obj.featureImportance(feature_idx) = ...
                        obj.featureImportance(feature_idx) + tree.feature_importance(j);
                end
            end

            % Normalize so the importances sum to 1
            obj.featureImportance = obj.featureImportance / sum(obj.featureImportance);
        end
        
        function plotFeatureImportance(obj, feature_names)
            % Plot the top-ranked feature importances into the current axes
            if nargin < 2
                feature_names = strcat('Feature_', string(1:length(obj.featureImportance)));
            end

            [~, sorted_idx] = sort(obj.featureImportance, 'descend');
            k = min(15, length(sorted_idx));

            barh(obj.featureImportance(sorted_idx(1:k)));
            set(gca, 'YTick', 1:k, 'YTickLabel', feature_names(sorted_idx(1:k)));
            xlabel('Feature importance');
            title('Random Forest Feature Importance (Top 15)');
            grid on;
        end
    end
end
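
Both fit and predict use parfor. Without the Parallel Computing Toolbox (or with no pool open) these loops simply run serially, so the class still works; to get actual parallel speedup, open a pool first. A minimal sketch, where the pool size of 4 is just an example:

% Open a local parallel pool once per session (if none is open yet)
if isempty(gcp('nocreate'))
    parpool('local', 4);
end

rf = RandomForest(200, 5, 'sqrt');
rf.fit(X_train, y_train);   % the parfor loops now run on the pool workers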

2. Decision Tree Base Class

classdef DecisionTree < handle
    % Decision tree classifier (CART with Gini impurity)

    properties
        min_leaf_size      % minimum leaf size
        max_depth          % maximum tree depth
        root               % root node
        feature_indices    % original indices of the features this tree uses
        feature_importance % accumulated Gini gain per (local) feature
    end

    methods
        function obj = DecisionTree(varargin)
            % Constructor with name-value parameters
            p = inputParser;
            addParameter(p, 'min_leaf_size', 1);
            addParameter(p, 'max_depth', 20);
            parse(p, varargin{:});

            obj.min_leaf_size = p.Results.min_leaf_size;
            obj.max_depth = p.Results.max_depth;
        end

        function fit(obj, X, y, feature_indices)
            % Train the decision tree; feature_indices maps the columns of X
            % back to the original feature space (used by the forest)
            if nargin < 4
                feature_indices = 1:size(X, 2);
            end
            obj.feature_indices = feature_indices;
            obj.feature_importance = zeros(1, length(feature_indices));

            obj.root = obj.build_tree(X, y, 0);
        end
        
        function node = build_tree(obj, X, y, depth)
            % Recursively grow the tree
            n_samples = size(X, 1);

            % Stopping criteria: depth limit, leaf-size limit, or pure node
            if depth >= obj.max_depth || n_samples <= obj.min_leaf_size || all(y == y(1))
                node = TreeNode();
                node.is_leaf = true;
                node.prediction = mode(y);
                node.probability = sum(y == node.prediction) / n_samples;
                return;
            end

            % Find the best split
            [best_feature, best_value, best_gain] = obj.find_best_split(X, y);

            if best_gain == 0
                node = TreeNode();
                node.is_leaf = true;
                node.prediction = mode(y);
                node.probability = sum(y == node.prediction) / n_samples;
                return;
            end

            % Create an internal node
            node = TreeNode();
            node.is_leaf = false;
            node.feature_index = best_feature;
            node.threshold = best_value;
            node.feature_importance = best_gain * n_samples;

            % Accumulate the sample-weighted Gini gain as importance
            if best_feature > 0
                obj.feature_importance(best_feature) = ...
                    obj.feature_importance(best_feature) + node.feature_importance;
            end

            % Split the data
            left_mask = X(:, best_feature) <= best_value;
            right_mask = ~left_mask;

            % Recurse into the children
            node.left = obj.build_tree(X(left_mask, :), y(left_mask), depth + 1);
            node.right = obj.build_tree(X(right_mask, :), y(right_mask), depth + 1);
        end
        
        function [best_feature, best_value, best_gain] = find_best_split(obj, X, y)
            % Exhaustively search for the split with the largest Gini gain
            n_samples = size(X, 1);
            n_features = size(X, 2);

            best_gain = 0;
            best_feature = 0;
            best_value = 0;

            current_gini = obj.gini_impurity(y);

            for feature = 1:n_features
                % Try every distinct value of this feature as a threshold
                feature_values = unique(X(:, feature));

                for i = 1:length(feature_values)
                    value = feature_values(i);

                    left_mask = X(:, feature) <= value;
                    right_mask = ~left_mask;

                    if sum(left_mask) < obj.min_leaf_size || sum(right_mask) < obj.min_leaf_size
                        continue;
                    end

                    % Gini gain = parent impurity - weighted child impurity
                    gini_gain = current_gini - ...
                        (sum(left_mask)/n_samples * obj.gini_impurity(y(left_mask)) + ...
                         sum(right_mask)/n_samples * obj.gini_impurity(y(right_mask)));

                    if gini_gain > best_gain
                        best_gain = gini_gain;
                        best_feature = feature;
                        best_value = value;
                    end
                end
            end
        end
        
        function gini = gini_impurity(~, y)
            % Gini impurity: 1 - sum_k p_k^2
            if isempty(y)
                gini = 0;
                return;
            end

            classes = unique(y);
            n = length(y);
            gini = 1;

            for i = 1:length(classes)
                p = sum(y == classes(i)) / n;
                gini = gini - p^2;
            end
        end
        
        function predictions = predict(obj, X)
            % Route each sample from the root to a leaf
            n_samples = size(X, 1);
            predictions = zeros(n_samples, 1);

            for i = 1:n_samples
                node = obj.root;
                while ~node.is_leaf
                    if X(i, node.feature_index) <= node.threshold
                        node = node.left;
                    else
                        node = node.right;
                    end
                end
                predictions(i) = node.prediction;
            end
        end
    end
end

classdef TreeNode < handle
    % A single decision tree node
    properties
        is_leaf            % true for leaf nodes
        feature_index      % split feature index
        threshold          % split threshold
        left               % left subtree
        right              % right subtree
        prediction         % predicted class (leaf nodes)
        probability        % class probability at the leaf
        feature_importance % sample-weighted Gini gain of this split
    end
end
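
Before wiring the tree into an ensemble, it is worth sanity-checking it on a tiny dataset where the right answer is obvious. A minimal sketch, assuming the DecisionTree and TreeNode classes above are on the path:

% Two well-separated 1-D Gaussian classes: one split near x = 0 suffices
rng(0);
X_toy = [randn(50, 1) - 3; randn(50, 1) + 3];
y_toy = [ones(50, 1); 2 * ones(50, 1)];

tree = DecisionTree('min_leaf_size', 2, 'max_depth', 3);
tree.fit(X_toy, y_toy);

acc = mean(tree.predict(X_toy) == y_toy);
fprintf('Toy training accuracy: %.2f (should be close to 1.00)\n', acc);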

3. Gradient Boosted Trees (GBDT)

classdef GradientBoostingClassifier < handle
    % Gradient boosting classifier (multiclass, softmax loss)

    properties
        n_estimators       % number of boosting rounds
        learning_rate      % shrinkage applied to each tree's output
        max_depth          % maximum depth of each regression tree
        trees              % cell array; one group of per-class trees per round
        initial_prediction % initial log-odds score per class
        classes            % class labels
    end

    methods
        function obj = GradientBoostingClassifier(n_estimators, learning_rate, max_depth)
            % Constructor with sensible defaults
            if nargin < 3
                max_depth = 3;
            end
            if nargin < 2
                learning_rate = 0.1;
            end
            if nargin < 1
                n_estimators = 100;
            end

            obj.n_estimators = n_estimators;
            obj.learning_rate = learning_rate;
            obj.max_depth = max_depth;
            obj.trees = cell(n_estimators, 1);
        end
        
        function fit(obj, X, y)
            % Train the gradient-boosted trees
            obj.classes = unique(y);
            n_samples = size(X, 1);
            n_classes = length(obj.classes);

            % One-hot encode the labels
            y_onehot = zeros(n_samples, n_classes);
            for i = 1:n_classes
                y_onehot(:, i) = (y == obj.classes(i));
            end

            % Initial prediction: per-class log odds
            obj.initial_prediction = log(mean(y_onehot) ./ (1 - mean(y_onehot)));
            F = repmat(obj.initial_prediction, n_samples, 1);

            fprintf('Training gradient-boosted trees (%d rounds)...\n', obj.n_estimators);

            for t = 1:obj.n_estimators
                % Negative gradient of the softmax loss (the "residuals")
                probabilities = obj.softmax(F);
                residuals = y_onehot - probabilities;

                % Fit one regression tree per class
                tree_group = cell(n_classes, 1);

                for k = 1:n_classes
                    tree = RegressionTree('max_depth', obj.max_depth);
                    tree.fit(X, residuals(:, k));
                    tree_group{k} = tree;
                end

                obj.trees{t} = tree_group;

                % Update the scores: F <- F + learning_rate * h_t
                for k = 1:n_classes
                    F(:, k) = F(:, k) + obj.learning_rate * tree_group{k}.predict(X);
                end

                if mod(t, 10) == 0
                    current_prob = obj.softmax(F);
                    [~, idx] = max(current_prob, [], 2);
                    current_pred = obj.classes(idx);
                    accuracy = mean(current_pred == y);
                    fprintf('Round %d/%d, training accuracy: %.4f\n', t, obj.n_estimators, accuracy);
                end
            end
        end
        
        function probabilities = softmax(~, X)
            % Numerically stable softmax (subtract the row max first)
            exp_X = exp(X - max(X, [], 2));
            probabilities = exp_X ./ sum(exp_X, 2);
        end

        function predictions = predict(obj, X)
            % Predict the class with the highest probability
            probabilities = obj.predict_proba(X);
            [~, class_idx] = max(probabilities, [], 2);
            predictions = obj.classes(class_idx);
        end

        function probabilities = predict_proba(obj, X)
            % Accumulate all rounds' scores, then apply softmax
            n_samples = size(X, 1);
            n_classes = length(obj.classes);

            F = repmat(obj.initial_prediction, n_samples, 1);

            for t = 1:obj.n_estimators
                tree_group = obj.trees{t};
                for k = 1:n_classes
                    F(:, k) = F(:, k) + obj.learning_rate * tree_group{k}.predict(X);
                end
            end

            probabilities = obj.softmax(F);
        end
end

classdef RegressionTree < handle
    % Regression tree (fits GBDT residuals)

    properties
        max_depth         % maximum depth
        min_samples_split % minimum samples required to split
        root              % root node
    end

    methods
        function obj = RegressionTree(varargin)
            p = inputParser;
            addParameter(p, 'max_depth', 3);
            addParameter(p, 'min_samples_split', 2);
            parse(p, varargin{:});

            obj.max_depth = p.Results.max_depth;
            obj.min_samples_split = p.Results.min_samples_split;
        end
        
        function fit(obj, X, y)
            obj.root = obj.build_tree(X, y, 0);
        end

        function node = build_tree(obj, X, y, depth)
            % Recursively grow the tree; every node stores the mean response
            n_samples = size(X, 1);

            node = RegTreeNode();
            node.prediction = mean(y);

            % Stopping criteria: depth, sample count, or (near-)constant y
            if depth >= obj.max_depth || n_samples <= obj.min_samples_split || var(y) < 1e-6
                return;
            end

            % Find the split with the largest variance reduction
            [best_feature, best_value, best_reduction] = obj.find_best_split(X, y);

            if best_reduction < 1e-6
                return;
            end

            % Split the data
            left_mask = X(:, best_feature) <= best_value;
            right_mask = ~left_mask;

            if sum(left_mask) == 0 || sum(right_mask) == 0
                return;
            end

            node.feature_index = best_feature;
            node.threshold = best_value;
            node.left = obj.build_tree(X(left_mask, :), y(left_mask), depth + 1);
            node.right = obj.build_tree(X(right_mask, :), y(right_mask), depth + 1);
        end
        
        function [best_feature, best_value, best_reduction] = find_best_split(~, X, y)
            % Exhaustive search for the split with the largest variance reduction
            n_samples = size(X, 1);
            n_features = size(X, 2);

            best_reduction = 0;
            best_feature = 0;
            best_value = 0;

            current_variance = var(y);

            for feature = 1:n_features
                feature_values = unique(X(:, feature));

                for i = 1:length(feature_values)
                    value = feature_values(i);
                    left_mask = X(:, feature) <= value;
                    right_mask = ~left_mask;

                    % Require at least 2 samples per side so var() is meaningful
                    if sum(left_mask) < 2 || sum(right_mask) < 2
                        continue;
                    end

                    % Reduction = parent variance - weighted child variance
                    variance_reduction = current_variance - ...
                        (sum(left_mask)/n_samples * var(y(left_mask)) + ...
                         sum(right_mask)/n_samples * var(y(right_mask)));

                    if variance_reduction > best_reduction
                        best_reduction = variance_reduction;
                        best_feature = feature;
                        best_value = value;
                    end
                end
            end
        end
        
        function predictions = predict(obj, X)
            % Route each sample to a leaf (leaves have an empty feature_index)
            n_samples = size(X, 1);
            predictions = zeros(n_samples, 1);

            for i = 1:n_samples
                node = obj.root;
                while ~isempty(node.feature_index)
                    if X(i, node.feature_index) <= node.threshold
                        node = node.left;
                    else
                        node = node.right;
                    end
                end
                predictions(i) = node.prediction;
            end
        end
    end
end

classdef RegTreeNode < handle
    % Regression tree node; leaves keep feature_index empty
    properties
        feature_index % split feature (empty for leaves)
        threshold     % split threshold
        left          % left subtree
        right         % right subtree
        prediction    % mean response at this node
    end
end
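
Putting the pieces together: each boosting round updates the per-class score matrix as F <- F + learning_rate * h_t, and predict_proba replays all rounds before applying softmax. A short usage sketch, assuming the classes above plus the generate_sample_data helper from the demo below; the parameters are illustrative:

[X, y] = generate_sample_data();

% Fewer, shallower trees keep this quick
gbdt = GradientBoostingClassifier(50, 0.1, 3);
gbdt.fit(X, y);

probs = gbdt.predict_proba(X(1:5, :));  % each row sums to 1 (softmax)
labels = gbdt.predict(X(1:5, :));
disp([probs, labels]);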

4. Complete Ensemble Classifier Demo

function ensemble_classifier_demo()
    % Ensemble tree classifier demo

    % Generate example data
    [X, y] = generate_sample_data();

    % Split into training and test sets
    rng(42);
    cv = cvpartition(y, 'HoldOut', 0.3);
    X_train = X(training(cv), :);
    y_train = y(training(cv));
    X_test = X(test(cv), :);
    y_test = y(test(cv));

    fprintf('Data: %d samples, %d features, %d classes\n', ...
            size(X, 1), size(X, 2), length(unique(y)));
    fprintf('Training set: %d samples, test set: %d samples\n', ...
            length(y_train), length(y_test));

    % Classifiers to compare
    classifiers = {
        struct('name', 'Random Forest', 'obj', RandomForest(100, 5, 'sqrt')), ...
        struct('name', 'Gradient Boosting', 'obj', GradientBoostingClassifier(100, 0.1, 3)), ...
        struct('name', 'Decision Tree', 'obj', DecisionTree('min_leaf_size', 5, 'max_depth', 10))
    };

    results = struct();

    figure('Position', [100, 100, 1200, 800]);

    for i = 1:length(classifiers)
        fprintf('\n=== Training %s ===\n', classifiers{i}.name);

        % Train
        tic;
        classifiers{i}.obj.fit(X_train, y_train);
        training_time = toc;

        % Predict
        y_pred = classifiers{i}.obj.predict(X_test);

        % Evaluation metrics (per class, macro-averaged below)
        accuracy = mean(y_pred == y_test);
        cm = confusionmat(y_test, y_pred);
        precision = diag(cm) ./ sum(cm, 1)';  % column sums = predicted counts
        recall = diag(cm) ./ sum(cm, 2);      % row sums = true counts
        f1_score = 2 * (precision .* recall) ./ (precision + recall);

        % Store the results
        results(i).name = classifiers{i}.name;
        results(i).accuracy = accuracy;
        results(i).precision = mean(precision, 'omitnan');
        results(i).recall = mean(recall, 'omitnan');
        results(i).f1_score = mean(f1_score, 'omitnan');
        results(i).training_time = training_time;
        results(i).confusion_matrix = cm;

        % Print the results
        fprintf('Accuracy: %.4f\n', accuracy);
        fprintf('Precision: %.4f\n', results(i).precision);
        fprintf('Recall: %.4f\n', results(i).recall);
        fprintf('F1 score: %.4f\n', results(i).f1_score);
        fprintf('Training time: %.2f s\n', training_time);

        % Confusion matrix
        subplot(2, 3, i);
        plot_confusion_matrix(cm, unique(y_test), classifiers{i}.name);

        % Feature importance (random forest only)
        if i == 1 && isa(classifiers{i}.obj, 'RandomForest')
            subplot(2, 3, 4);
            classifiers{i}.obj.plotFeatureImportance();
        end
    end

    % Performance comparison
    subplot(2, 3, 5);
    metrics = [results.accuracy; results.precision; results.recall; results.f1_score];
    bar(metrics');
    set(gca, 'XTickLabel', {results.name});
    ylabel('Score');
    title('Classifier Performance');
    legend('Accuracy', 'Precision', 'Recall', 'F1', 'Location', 'best');
    grid on;

    % Training-time comparison
    subplot(2, 3, 6);
    bar([results.training_time]);
    set(gca, 'XTickLabel', {results.name});
    ylabel('Time (s)');
    title('Training Time');
    grid on;

    % Report the best classifier
    [~, best_idx] = max([results.accuracy]);
    fprintf('\nBest classifier: %s (accuracy: %.4f)\n', ...
            results(best_idx).name, results(best_idx).accuracy);
end

function [X, y] = generate_sample_data()
    % Generate a synthetic classification dataset
    rng(42);

    n_samples = 1000;
    n_features = 20;
    n_classes = 3;

    X = zeros(n_samples, n_features);
    y = zeros(n_samples, 1);

    % Each class occupies a contiguous block of rows with its own mean
    % vector; label exactly the same block so every sample gets a class
    for i = 1:n_classes
        class_start = floor((i-1) * n_samples / n_classes) + 1;
        class_end = floor(i * n_samples / n_classes);
        n_class_samples = class_end - class_start + 1;

        mean_vals = linspace(-2, 2, n_features) * i;
        X(class_start:class_end, :) = randn(n_class_samples, n_features) + mean_vals;
        y(class_start:class_end) = i;
    end

    % Overwrite the last three features with pure noise
    X(:, end-2:end) = randn(n_samples, 3);

    % Shuffle the rows
    shuffle_idx = randperm(n_samples);
    X = X(shuffle_idx, :);
    y = y(shuffle_idx);
end

function plot_confusion_matrix(cm, class_labels, title_str)
    % Draw a confusion matrix as a heatmap with count labels
    imagesc(cm);
    colorbar;

    % Overlay the counts; use white text on dark cells
    [n_classes, ~] = size(cm);
    for i = 1:n_classes
        for j = 1:n_classes
            text(j, i, num2str(cm(i, j)), ...
                 'HorizontalAlignment', 'center', ...
                 'Color', ifelse(cm(i, j) > max(cm(:))/2, 'white', 'black'));
        end
    end

    set(gca, 'XTick', 1:n_classes, 'XTickLabel', class_labels);
    set(gca, 'YTick', 1:n_classes, 'YTickLabel', class_labels);
    xlabel('Predicted label');
    ylabel('True label');
    title(title_str);
end

function result = ifelse(condition, true_val, false_val)
    % Simple inline conditional helper
    if condition
        result = true_val;
    else
        result = false_val;
    end
end

5. Interactive Classifier GUI Tool

function ensemble_classifier_gui()
    % Interactive GUI for the ensemble classifiers

    fig = figure('Name', 'Ensemble Tree Classifier Tool', ...
                'NumberTitle', 'off', ...
                'Position', [100, 100, 1400, 800]);

    % Control panel
    create_control_panel(fig);

    % Result display areas
    create_display_areas(fig);

    % Initialize application data
    setappdata(fig, 'classifiers', []);
    setappdata(fig, 'results', []);
end

function create_control_panel(fig)
    % Build the control panel

    % Classifier selection
    uicontrol('Parent', fig, ...
             'Style', 'text', ...
             'String', 'Select classifiers:', ...
             'Position', [30, 700, 100, 20], ...
             'FontWeight', 'bold');

    uicontrol('Parent', fig, ...
             'Style', 'listbox', ...
             'String', {'Random Forest', 'Gradient Boosting', 'Decision Tree'}, ...
             'Position', [30, 600, 150, 100], ...
             'Max', 3, ...
             'Tag', 'classifier_list');

    % Parameter settings
    uicontrol('Parent', fig, ...
             'Style', 'text', ...
             'String', 'Number of trees:', ...
             'Position', [30, 550, 100, 20]);

    uicontrol('Parent', fig, ...
             'Style', 'edit', ...
             'String', '100', ...
             'Position', [140, 550, 50, 20], ...
             'Tag', 'n_trees');

    uicontrol('Parent', fig, ...
             'Style', 'text', ...
             'String', 'Learning rate (GBDT):', ...
             'Position', [30, 520, 100, 20]);

    uicontrol('Parent', fig, ...
             'Style', 'edit', ...
             'String', '0.1', ...
             'Position', [140, 520, 50, 20], ...
             'Tag', 'learning_rate');

    % Data generation button
    uicontrol('Parent', fig, ...
             'Style', 'pushbutton', ...
             'String', 'Generate sample data', ...
             'Position', [30, 450, 170, 30], ...
             'Callback', @generate_data_callback);

    % Training button
    uicontrol('Parent', fig, ...
             'Style', 'pushbutton', ...
             'String', 'Train classifiers', ...
             'Position', [30, 400, 170, 30], ...
             'Callback', @train_classifiers_callback);

    % Result list
    uicontrol('Parent', fig, ...
             'Style', 'text', ...
             'String', 'Training results:', ...
             'Position', [30, 350, 100, 20], ...
             'FontWeight', 'bold');

    uicontrol('Parent', fig, ...
             'Style', 'listbox', ...
             'String', {}, ...
             'Position', [30, 150, 170, 200], ...
             'Tag', 'result_display');
end

function create_display_areas(fig)
    % Build the result display axes

    % Confusion matrix
    axes('Parent', fig, ...
        'Position', [0.25, 0.55, 0.2, 0.3], ...
        'Tag', 'confusion_axes');
    title('Confusion Matrix');

    % Feature importance
    axes('Parent', fig, ...
        'Position', [0.5, 0.55, 0.2, 0.3], ...
        'Tag', 'importance_axes');
    title('Feature Importance');

    % Performance comparison
    axes('Parent', fig, ...
        'Position', [0.75, 0.55, 0.2, 0.3], ...
        'Tag', 'performance_axes');
    title('Performance Comparison');

    % ROC curves (multiclass)
    axes('Parent', fig, ...
        'Position', [0.25, 0.15, 0.2, 0.3], ...
        'Tag', 'roc_axes');
    title('ROC Curves');

    % Learning curves
    axes('Parent', fig, ...
        'Position', [0.5, 0.15, 0.2, 0.3], ...
        'Tag', 'learning_axes');
    title('Learning Curves');

    % Prediction distribution
    axes('Parent', fig, ...
        'Position', [0.75, 0.15, 0.2, 0.3], ...
        'Tag', 'prediction_axes');
    title('Prediction Distribution');
end

function generate_data_callback(~, ~)
    % Callback: generate sample data and stash it in the figure
    [X, y] = generate_sample_data();

    setappdata(gcf, 'X_data', X);
    setappdata(gcf, 'y_data', y);

    msgbox(sprintf('Generated %d samples, %d features, %d classes', ...
                  size(X, 1), size(X, 2), length(unique(y))), 'Data generated');
end

function train_classifiers_callback(~, ~)
    % Callback: train the selected classifiers

    X = getappdata(gcf, 'X_data');
    y = getappdata(gcf, 'y_data');

    if isempty(X)
        errordlg('Please generate data first!', 'Error');
        return;
    end

    % Selected classifiers
    classifier_list = findobj(gcf, 'Tag', 'classifier_list');
    selected_classifiers = classifier_list.Value;
    classifier_names = classifier_list.String;

    if isempty(selected_classifiers)
        errordlg('Please select at least one classifier!', 'Error');
        return;
    end

    % Read the parameters (fetch each handle first; MATLAB does not allow
    % dot-indexing directly into a function call result)
    n_trees_edit = findobj(gcf, 'Tag', 'n_trees');
    n_trees = str2double(n_trees_edit.String);
    learning_rate_edit = findobj(gcf, 'Tag', 'learning_rate');
    learning_rate = str2double(learning_rate_edit.String);

    % Train/test split
    cv = cvpartition(y, 'HoldOut', 0.3);
    X_train = X(training(cv), :);
    y_train = y(training(cv));
    X_test = X(test(cv), :);
    y_test = y(test(cv));

    results = [];

    for i = 1:length(selected_classifiers)
        classifier_idx = selected_classifiers(i);
        classifier_name = classifier_names{classifier_idx};

        fprintf('Training %s...\n', classifier_name);

        % Instantiate the classifier (names match the listbox entries)
        switch classifier_name
            case 'Random Forest'
                classifier = RandomForest(n_trees, 5, 'sqrt');
            case 'Gradient Boosting'
                classifier = GradientBoostingClassifier(n_trees, learning_rate, 3);
            case 'Decision Tree'
                classifier = DecisionTree('min_leaf_size', 5, 'max_depth', 10);
        end

        % Train and predict
        tic;
        classifier.fit(X_train, y_train);
        training_time = toc;

        y_pred = classifier.predict(X_test);
        accuracy = mean(y_pred == y_test);

        % Store the results
        results(i).name = classifier_name;
        results(i).classifier = classifier;
        results(i).accuracy = accuracy;
        results(i).training_time = training_time;
        results(i).predictions = y_pred;

        % Update the result list (wrap the new line in a cell so it
        % concatenates with the existing cell array)
        result_display = findobj(gcf, 'Tag', 'result_display');
        current_string = result_display.String;
        new_result = sprintf('%s: accuracy=%.4f, time=%.2fs', ...
                            classifier_name, accuracy, training_time);

        if isempty(current_string)
            result_display.String = {new_result};
        else
            result_display.String = [current_string; {new_result}];
        end
    end

    setappdata(gcf, 'results', results);
    setappdata(gcf, 'X_test', X_test);
    setappdata(gcf, 'y_test', y_test);

    % Visualize the results
    visualize_results(results, X_test, y_test);
end

function visualize_results(results, ~, ~)
    % Visualize training results in the GUI's tagged axes

    if isempty(results)
        return;
    end

    accuracies = [results.accuracy];
    times = [results.training_time];

    % Accuracy comparison (drawn into the tagged axes; calling subplot
    % here would destroy the GUI layout)
    performance_axes = findobj(gcf, 'Tag', 'performance_axes');
    axes(performance_axes);
    cla;
    bar(accuracies);
    set(gca, 'XTickLabel', {results.name});
    ylabel('Accuracy');
    title('Classifier Accuracy');
    grid on;

    % Training-time comparison (reuses the learning-curve axes)
    learning_axes = findobj(gcf, 'Tag', 'learning_axes');
    axes(learning_axes);
    cla;
    bar(times);
    set(gca, 'XTickLabel', {results.name});
    ylabel('Training time (s)');
    title('Training Time');
    grid on;

    % Report the best classifier
    [best_accuracy, best_idx] = max(accuracies);
    fprintf('\nBest classifier: %s (accuracy: %.4f)\n', results(best_idx).name, best_accuracy);
end

Usage

Basic usage

% Run the full demo
ensemble_classifier_demo();

% Or use the random forest on its own
[X, y] = generate_sample_data();

% Train/test split
cv = cvpartition(y, 'HoldOut', 0.3);
X_train = X(training(cv), :);
y_train = y(training(cv));
X_test = X(test(cv), :);
y_test = y(test(cv));

% Train the random forest
rf = RandomForest(100, 5, 'sqrt');
rf.fit(X_train, y_train);

% Predict
y_pred = rf.predict(X_test);
accuracy = mean(y_pred == y_test);
fprintf('Random forest accuracy: %.4f\n', accuracy);

% Inspect feature importance (in a fresh figure)
figure;
rf.plotFeatureImportance();
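
A single hold-out split can give a noisy estimate; the same classes drop straight into k-fold cross-validation for a steadier number. A minimal sketch, assuming X and y from above:

k = 5;
cv = cvpartition(y, 'KFold', k);
accs = zeros(k, 1);

for fold = 1:k
    rf = RandomForest(100, 5, 'sqrt');
    rf.fit(X(training(cv, fold), :), y(training(cv, fold)));
    y_hat = rf.predict(X(test(cv, fold), :));
    accs(fold) = mean(y_hat == y(test(cv, fold)));
end

fprintf('%d-fold CV accuracy: %.4f (std %.4f)\n', k, mean(accs), std(accs));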

Launching the GUI tool

% Start the interactive classifier tool
ensemble_classifier_gui();

Reference code (a classifier that trains and predicts with multiple trees): www.3dddown.com/cna/63814.html

Key Features

  1. Multiple ensemble methods: random forest, gradient boosting, single decision tree
  2. Full evaluation metrics: accuracy, precision, recall, F1 score
  3. Feature-importance analysis: identifies the most informative features
  4. Visual analysis: confusion matrices, performance comparison, learning curves
  5. Parallel computing: parfor-based training acceleration in MATLAB