Multi-Tree Ensemble Classifiers in MATLAB
A MATLAB implementation of classifiers that train and predict with multiple trees, including random forests and gradient-boosted trees.
Overview of Ensemble Tree Classifiers
Core Idea
Ensemble tree methods combine many weak learners (decision trees) into a strong classifier:
| Method | Core idea | Strengths |
|---|---|---|
| Random forest | Trains many decision trees in parallel on bootstrap samples; predicts by majority vote | Resists overfitting; handles high-dimensional data |
| Gradient boosting | Trains trees sequentially, each fitting the residuals of the trees before it | High accuracy; flexible across data types |
| AdaBoost | Reweights samples to focus on hard-to-classify examples | Computationally efficient (though sensitive to label noise) |
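To make the two combination rules concrete before the full implementations, here is a minimal sketch with made-up numbers (all values hypothetical, not from the classes below): bagging-style methods aggregate independent tree outputs by majority vote, while boosting-style methods sum small stage-wise corrections.

% Toy illustration (hypothetical values) of the two combination rules
votes = [1 2 1; 2 2 1; 1 1 1];      % rows = samples, columns = tree votes
bagged_pred = mode(votes, 2);        % random-forest style: majority vote

stage_out = [0.5 0.2; -0.1 0.3; 0.4 -0.2];  % per-round regression outputs
eta = 0.1;                           % learning rate (shrinkage)
F = zeros(3, 1);                     % boosting: additive stage-wise score
for t = 1:size(stage_out, 2)
    F = F + eta * stage_out(:, t);
end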
Code
1. Random Forest Classifier
classdef RandomForest < handle
    % Random forest classifier
    properties
        numTrees            % number of trees
        minLeafSize         % minimum leaf size
        numFeaturesToSample % number of features sampled per tree
        trees               % cell array of decision trees
        featureImportance   % aggregated feature importances
        classes             % class labels
    end
    methods
        function obj = RandomForest(numTrees, minLeafSize, numFeaturesToSample)
            % Constructor
            if nargin < 3
                numFeaturesToSample = 'sqrt';
            end
            if nargin < 2
                minLeafSize = 1;
            end
            if nargin < 1
                numTrees = 100;
            end
            obj.numTrees = numTrees;
            obj.minLeafSize = minLeafSize;
            obj.numFeaturesToSample = numFeaturesToSample;
            obj.trees = cell(numTrees, 1);
        end

        function fit(obj, X, y)
            % Train the random forest
            % X: feature matrix (n_samples x n_features)
            % y: label vector (n_samples x 1)
            [n_samples, n_features] = size(X);
            obj.classes = unique(y);
            % Determine the number of features each tree uses
            if ischar(obj.numFeaturesToSample)
                if strcmp(obj.numFeaturesToSample, 'sqrt')
                    n_features_sample = round(sqrt(n_features));
                else
                    n_features_sample = round(log2(n_features));
                end
            else
                n_features_sample = min(obj.numFeaturesToSample, n_features);
            end
            fprintf('Training random forest (%d trees, %d features per tree)...\n', ...
                obj.numTrees, n_features_sample);
            % Train trees in parallel. parfor cannot assign into an object
            % property, so collect the trees in a local sliced cell array;
            % per-iteration progress messages are omitted because parfor
            % iterations complete out of order.
            nT = obj.numTrees;
            minLeaf = obj.minLeafSize;
            local_trees = cell(nT, 1);
            parfor i = 1:nT
                % Bootstrap sampling
                bootstrap_indices = randsample(n_samples, n_samples, true);
                X_bootstrap = X(bootstrap_indices, :);
                y_bootstrap = y(bootstrap_indices);
                % Feature sampling
                feature_indices = randperm(n_features, n_features_sample);
                X_bootstrap_sampled = X_bootstrap(:, feature_indices);
                % Train a decision tree on the sampled columns
                tree = DecisionTree('min_leaf_size', minLeaf);
                tree.fit(X_bootstrap_sampled, y_bootstrap, feature_indices);
                local_trees{i} = tree;
            end
            obj.trees = local_trees;
            fprintf('Finished training %d trees\n', nT);
            % Compute feature importances
            obj.calculateFeatureImportance(n_features);
        end

        function predictions = predict(obj, X)
            % Predict class labels by majority vote across trees
            n_samples = size(X, 1);
            nT = obj.numTrees;
            local_trees = obj.trees;
            tree_predictions = zeros(n_samples, nT);
            parfor i = 1:nT
                tree_predictions(:, i) = local_trees{i}.predict(X);
            end
            % Majority vote
            predictions = mode(tree_predictions, 2);
        end

        function probabilities = predict_proba(obj, X)
            % Predict class probabilities as vote fractions
            n_samples = size(X, 1);
            n_classes = length(obj.classes);
            nT = obj.numTrees;
            local_trees = obj.trees;
            probabilities = zeros(n_samples, n_classes);
            tree_predictions = zeros(n_samples, nT);
            parfor i = 1:nT
                tree_predictions(:, i) = local_trees{i}.predict(X);
            end
            % Fraction of trees voting for each class
            for j = 1:n_classes
                probabilities(:, j) = sum(tree_predictions == obj.classes(j), 2) / nT;
            end
        end

        function calculateFeatureImportance(obj, n_features)
            % Aggregate per-tree (Gini-based) importances into global scores,
            % mapping each tree's local column indices back to the original
            % feature space via tree.feature_indices
            obj.featureImportance = zeros(1, n_features);
            for i = 1:obj.numTrees
                tree = obj.trees{i};
                for j = 1:length(tree.feature_importance)
                    feature_idx = tree.feature_indices(j);
                    obj.featureImportance(feature_idx) = ...
                        obj.featureImportance(feature_idx) + tree.feature_importance(j);
                end
            end
            % Normalize to sum to one
            obj.featureImportance = obj.featureImportance / sum(obj.featureImportance);
        end

        function plotFeatureImportance(obj, feature_names)
            % Plot feature importances into the current axes (no new figure,
            % so this composes with subplot layouts in the demo)
            if nargin < 2
                feature_names = strcat('Feature_', string(1:length(obj.featureImportance)));
            end
            [~, sorted_idx] = sort(obj.featureImportance, 'descend');
            top_k = min(15, length(sorted_idx));
            barh(obj.featureImportance(sorted_idx(1:top_k)));
            set(gca, 'YTick', 1:top_k, 'YTickLabel', feature_names(sorted_idx(1:top_k)));
            xlabel('Feature importance');
            title('Random forest feature importance (Top 15)');
            grid on;
        end
    end
end
2. Decision Tree Base Class
classdef DecisionTree < handle
    % Decision tree classifier (CART algorithm)
    properties
        min_leaf_size      % minimum leaf size
        max_depth          % maximum depth
        root               % root node
        feature_indices    % global indices of the features this tree uses
        feature_importance % per-feature importance (local column indexing)
    end
    methods
        function obj = DecisionTree(varargin)
            % Constructor
            p = inputParser;
            addParameter(p, 'min_leaf_size', 1);
            addParameter(p, 'max_depth', 20);
            parse(p, varargin{:});
            obj.min_leaf_size = p.Results.min_leaf_size;
            obj.max_depth = p.Results.max_depth;
        end

        function fit(obj, X, y, feature_indices)
            % Train the decision tree. feature_indices maps the columns of X
            % back to the original feature space (used by RandomForest)
            if nargin < 4
                feature_indices = 1:size(X, 2);
            end
            obj.feature_indices = feature_indices;
            obj.feature_importance = zeros(1, length(feature_indices));
            obj.root = obj.build_tree(X, y, 0);
        end

        function node = build_tree(obj, X, y, depth)
            % Recursively build the tree
            n_samples = size(X, 1);
            % Stopping conditions
            if depth >= obj.max_depth || n_samples <= obj.min_leaf_size || all(y == y(1))
                node = TreeNode();
                node.is_leaf = true;
                node.prediction = mode(y);
                node.probability = sum(y == node.prediction) / n_samples;
                return;
            end
            % Find the best split
            [best_feature, best_value, best_gain] = obj.find_best_split(X, y);
            if best_gain == 0
                node = TreeNode();
                node.is_leaf = true;
                node.prediction = mode(y);
                node.probability = sum(y == node.prediction) / n_samples;
                return;
            end
            % Create an internal node
            node = TreeNode();
            node.is_leaf = false;
            node.feature_index = best_feature;
            node.threshold = best_value;
            node.feature_importance = best_gain * n_samples;
            % Accumulate feature importance (local column index)
            obj.feature_importance(best_feature) = ...
                obj.feature_importance(best_feature) + node.feature_importance;
            % Split the data
            left_mask = X(:, best_feature) <= best_value;
            right_mask = ~left_mask;
            % Recurse into the children
            node.left = obj.build_tree(X(left_mask, :), y(left_mask), depth + 1);
            node.right = obj.build_tree(X(right_mask, :), y(right_mask), depth + 1);
        end

        function [best_feature, best_value, best_gain] = find_best_split(obj, X, y)
            % Search all features and unique values for the best Gini split
            n_samples = size(X, 1);
            n_features = size(X, 2);
            best_gain = 0;
            best_feature = 0;
            best_value = 0;
            current_gini = obj.gini_impurity(y);
            for feature = 1:n_features
                % Try every unique value of the feature as a threshold
                feature_values = unique(X(:, feature));
                for i = 1:length(feature_values)
                    value = feature_values(i);
                    left_mask = X(:, feature) <= value;
                    right_mask = ~left_mask;
                    if sum(left_mask) < obj.min_leaf_size || sum(right_mask) < obj.min_leaf_size
                        continue;
                    end
                    % Gini gain = parent impurity minus weighted child impurity
                    gini_gain = current_gini - ...
                        (sum(left_mask)/n_samples * obj.gini_impurity(y(left_mask)) + ...
                         sum(right_mask)/n_samples * obj.gini_impurity(y(right_mask)));
                    if gini_gain > best_gain
                        best_gain = gini_gain;
                        best_feature = feature;
                        best_value = value;
                    end
                end
            end
        end

        function gini = gini_impurity(~, y)
            % Gini impurity: 1 - sum_k p_k^2
            if isempty(y)
                gini = 0;
                return;
            end
            classes = unique(y);
            n = length(y);
            gini = 1;
            for i = 1:length(classes)
                p = sum(y == classes(i)) / n;
                gini = gini - p^2;
            end
        end

        function predictions = predict(obj, X)
            % Predict labels. X is in the ORIGINAL feature space; node indices
            % are mapped through obj.feature_indices, so the tree works both
            % standalone and inside RandomForest, where it was trained on a
            % column subset of X.
            n_samples = size(X, 1);
            predictions = zeros(n_samples, 1);
            for i = 1:n_samples
                node = obj.root;
                while ~node.is_leaf
                    if X(i, obj.feature_indices(node.feature_index)) <= node.threshold
                        node = node.left;
                    else
                        node = node.right;
                    end
                end
                predictions(i) = node.prediction;
            end
        end
    end
end
classdef TreeNode < handle
    % Decision tree node
    properties
        is_leaf            % whether this is a leaf node
        feature_index      % split feature index (local to the tree's columns)
        threshold          % split threshold
        left               % left subtree
        right              % right subtree
        prediction         % predicted label (leaf nodes)
        probability        % prediction probability
        feature_importance % importance contribution of this split
    end
end
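As a quick worked example of the splitting criterion above (toy labels, not from the original article): for y = [1 1 2 3], the class proportions are 0.5, 0.25, 0.25, so the Gini impurity is 1 - (0.5^2 + 0.25^2 + 0.25^2) = 0.625. The snippet below checks the method against that hand computation.

% Hand-check of gini_impurity on toy labels (hypothetical example)
dt = DecisionTree();
assert(abs(dt.gini_impurity([1; 1; 2; 3]) - 0.625) < 1e-12);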
3. Gradient-Boosted Trees (GBDT)
classdef GradientBoostingClassifier < handle
    % Gradient-boosted tree classifier (multi-class, softmax loss)
    properties
        n_estimators       % number of boosting rounds
        learning_rate      % learning rate (shrinkage)
        max_depth          % maximum tree depth
        min_samples_split  % minimum samples to split
        trees              % cell array: one group of trees per round
        initial_prediction % initial log-odds prediction
        classes            % class labels
    end
    methods
        function obj = GradientBoostingClassifier(n_estimators, learning_rate, max_depth)
            % Constructor
            if nargin < 3
                max_depth = 3;
            end
            if nargin < 2
                learning_rate = 0.1;
            end
            if nargin < 1
                n_estimators = 100;
            end
            obj.n_estimators = n_estimators;
            obj.learning_rate = learning_rate;
            obj.max_depth = max_depth;
            obj.trees = cell(n_estimators, 1);
        end

        function fit(obj, X, y)
            % Train the gradient-boosted ensemble
            obj.classes = unique(y);
            n_samples = size(X, 1);
            n_classes = length(obj.classes);
            % One-hot encode the labels
            y_onehot = zeros(n_samples, n_classes);
            for i = 1:n_classes
                y_onehot(:, i) = (y == obj.classes(i));
            end
            % Initial prediction: per-class log-odds of the class priors
            obj.initial_prediction = log(mean(y_onehot) ./ (1 - mean(y_onehot)));
            F = repmat(obj.initial_prediction, n_samples, 1);
            fprintf('Training gradient-boosted trees (%d rounds)...\n', obj.n_estimators);
            for t = 1:obj.n_estimators
                % Negative gradient of the softmax loss (pseudo-residuals)
                probabilities = obj.softmax(F);
                residuals = y_onehot - probabilities;
                % Fit one regression tree per class to its residual column
                tree_group = cell(n_classes, 1);
                for k = 1:n_classes
                    tree = RegressionTree('max_depth', obj.max_depth);
                    tree.fit(X, residuals(:, k));
                    tree_group{k} = tree;
                end
                obj.trees{t} = tree_group;
                % Update the raw scores
                for k = 1:n_classes
                    F(:, k) = F(:, k) + obj.learning_rate * tree_group{k}.predict(X);
                end
                if mod(t, 10) == 0
                    current_prob = obj.softmax(F);
                    [~, class_idx] = max(current_prob, [], 2);
                    current_pred = obj.classes(class_idx);
                    accuracy = mean(current_pred == y);
                    fprintf('Round %d/%d, training accuracy: %.4f\n', t, obj.n_estimators, accuracy);
                end
            end
        end

        function probabilities = softmax(~, X)
            % Row-wise softmax (max-subtracted for numerical stability)
            exp_X = exp(X - max(X, [], 2));
            probabilities = exp_X ./ sum(exp_X, 2);
        end

        function predictions = predict(obj, X)
            % Predict the most probable class
            probabilities = obj.predict_proba(X);
            [~, class_idx] = max(probabilities, [], 2);
            predictions = obj.classes(class_idx);
        end

        function probabilities = predict_proba(obj, X)
            % Predict class probabilities
            n_samples = size(X, 1);
            n_classes = length(obj.classes);
            F = repmat(obj.initial_prediction, n_samples, 1);
            for t = 1:obj.n_estimators
                tree_group = obj.trees{t};
                for k = 1:n_classes
                    F(:, k) = F(:, k) + obj.learning_rate * tree_group{k}.predict(X);
                end
            end
            probabilities = obj.softmax(F);
        end
    end
end
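A quick sanity check of the softmax used above (toy inputs; the 4-sample, 3-class shape is an arbitrary assumption): every row of the output should be a valid probability distribution, non-negative and summing to one.

% Softmax sanity check on random scores (hypothetical 4 samples, 3 classes)
gb = GradientBoostingClassifier(5, 0.1, 2);
P = gb.softmax(randn(4, 3));
assert(all(abs(sum(P, 2) - 1) < 1e-12) && all(P(:) >= 0));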
classdef RegressionTree < handle
    % Regression tree (used by GBDT to fit residuals)
    properties
        max_depth
        min_samples_split
        root
    end
    methods
        function obj = RegressionTree(varargin)
            p = inputParser;
            addParameter(p, 'max_depth', 3);
            addParameter(p, 'min_samples_split', 2);
            parse(p, varargin{:});
            obj.max_depth = p.Results.max_depth;
            obj.min_samples_split = p.Results.min_samples_split;
        end

        function fit(obj, X, y)
            obj.root = obj.build_tree(X, y, 0);
        end

        function node = build_tree(obj, X, y, depth)
            n_samples = size(X, 1);
            node = RegTreeNode();
            node.prediction = mean(y);
            % Stopping conditions
            if depth >= obj.max_depth || n_samples <= obj.min_samples_split || var(y) < 1e-6
                return;
            end
            % Find the best split
            [best_feature, best_value, best_reduction] = obj.find_best_split(X, y);
            if best_reduction < 1e-6
                return;
            end
            % Split the data
            left_mask = X(:, best_feature) <= best_value;
            right_mask = ~left_mask;
            if sum(left_mask) == 0 || sum(right_mask) == 0
                return;
            end
            node.feature_index = best_feature;
            node.threshold = best_value;
            node.left = obj.build_tree(X(left_mask, :), y(left_mask), depth + 1);
            node.right = obj.build_tree(X(right_mask, :), y(right_mask), depth + 1);
        end

        function [best_feature, best_value, best_reduction] = find_best_split(obj, X, y)
            % Choose the split that maximizes variance reduction
            n_samples = size(X, 1);
            n_features = size(X, 2);
            best_reduction = 0;
            best_feature = 0;
            best_value = 0;
            current_variance = var(y);
            for feature = 1:n_features
                feature_values = unique(X(:, feature));
                for i = 1:length(feature_values)
                    value = feature_values(i);
                    left_mask = X(:, feature) <= value;
                    right_mask = ~left_mask;
                    if sum(left_mask) < 2 || sum(right_mask) < 2
                        continue;
                    end
                    variance_reduction = current_variance - ...
                        (sum(left_mask)/n_samples * var(y(left_mask)) + ...
                         sum(right_mask)/n_samples * var(y(right_mask)));
                    if variance_reduction > best_reduction
                        best_reduction = variance_reduction;
                        best_feature = feature;
                        best_value = value;
                    end
                end
            end
        end

        function predictions = predict(obj, X)
            n_samples = size(X, 1);
            predictions = zeros(n_samples, 1);
            for i = 1:n_samples
                node = obj.root;
                % Leaf nodes leave feature_index empty
                while ~isempty(node.feature_index)
                    if X(i, node.feature_index) <= node.threshold
                        node = node.left;
                    else
                        node = node.right;
                    end
                end
                predictions(i) = node.prediction;
            end
        end
    end
end
classdef RegTreeNode < handle
    % Regression tree node (an empty feature_index marks a leaf)
    properties
        feature_index
        threshold
        left
        right
        prediction
    end
end
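A minimal smoke test for the regression tree (toy data, not from the original article): fitted with depth 1 to a step function, it should split at the step and predict the two level means.

% Depth-1 regression tree on a step function (toy data)
Xs = (1:10)';                  % single feature, 10 samples
ys = [zeros(5, 1); ones(5, 1)];
rt = RegressionTree('max_depth', 1);
rt.fit(Xs, ys);
disp(rt.predict([3; 8]));      % expected: [0; 1]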
4. Complete Ensemble Classifier Demo
function ensemble_classifier_demo()
    % Demo comparing the ensemble tree classifiers
    % Generate example data
    [X, y] = generate_sample_data();
    % Split into training and test sets
    rng(42);
    cv = cvpartition(y, 'HoldOut', 0.3);
    X_train = X(training(cv), :);
    y_train = y(training(cv));
    X_test = X(test(cv), :);
    y_test = y(test(cv));
    fprintf('Data: %d samples, %d features, %d classes\n', ...
        size(X, 1), size(X, 2), length(unique(y)));
    fprintf('Training set: %d samples, test set: %d samples\n', ...
        length(y_train), length(y_test));
    % Classifiers to compare
    classifiers = {
        struct('name', 'Random Forest', 'obj', RandomForest(100, 5, 'sqrt')), ...
        struct('name', 'Gradient Boosting', 'obj', GradientBoostingClassifier(100, 0.1, 3)), ...
        struct('name', 'Decision Tree', 'obj', DecisionTree('min_leaf_size', 5, 'max_depth', 10))
    };
    results = struct();
    figure('Position', [100, 100, 1200, 800]);
    for i = 1:length(classifiers)
        fprintf('\n=== Training %s ===\n', classifiers{i}.name);
        % Train
        tic;
        classifiers{i}.obj.fit(X_train, y_train);
        training_time = toc;
        % Predict
        y_pred = classifiers{i}.obj.predict(X_test);
        % Performance metrics
        accuracy = mean(y_pred == y_test);
        cm = confusionmat(y_test, y_pred);
        precision = diag(cm) ./ sum(cm, 1)';  % per-class precision (columns = predicted)
        recall = diag(cm) ./ sum(cm, 2);      % per-class recall (rows = true)
        f1_score = 2 * (precision .* recall) ./ (precision + recall);
        % Store results
        results(i).name = classifiers{i}.name;
        results(i).accuracy = accuracy;
        results(i).precision = mean(precision, 'omitnan');
        results(i).recall = mean(recall, 'omitnan');
        results(i).f1_score = mean(f1_score, 'omitnan');
        results(i).training_time = training_time;
        results(i).confusion_matrix = cm;
        % Report
        fprintf('Accuracy: %.4f\n', accuracy);
        fprintf('Precision: %.4f\n', results(i).precision);
        fprintf('Recall: %.4f\n', results(i).recall);
        fprintf('F1 score: %.4f\n', results(i).f1_score);
        fprintf('Training time: %.2f s\n', training_time);
        % Confusion matrix
        subplot(2, 3, i);
        plot_confusion_matrix(cm, unique(y_test), classifiers{i}.name);
        % Feature importance for the random forest
        % (plotFeatureImportance draws into the current axes)
        if i == 1 && isa(classifiers{i}.obj, 'RandomForest')
            subplot(2, 3, 4);
            classifiers{i}.obj.plotFeatureImportance();
        end
    end
    % Performance comparison
    subplot(2, 3, 5);
    metrics = [results.accuracy; results.precision; results.recall; results.f1_score];
    bar(metrics');
    set(gca, 'XTickLabel', {results.name});
    ylabel('Score');
    title('Classifier performance');
    legend('Accuracy', 'Precision', 'Recall', 'F1', 'Location', 'best');
    grid on;
    % Training time comparison
    subplot(2, 3, 6);
    bar([results.training_time]);
    set(gca, 'XTickLabel', {results.name});
    ylabel('Time (s)');
    title('Training time');
    grid on;
    % Report the best classifier
    [~, best_idx] = max([results.accuracy]);
    fprintf('\n🎉 Best classifier: %s (accuracy: %.4f)\n', ...
        results(best_idx).name, results(best_idx).accuracy);
end
function [X, y] = generate_sample_data()
    % Generate synthetic classification data
    rng(42);
    n_samples = 1000;
    n_features = 20;
    n_classes = 3;
    X = zeros(n_samples, n_features);
    y = zeros(n_samples, 1);
    % Give each class its own mean vector; assign labels in the same loop so
    % the label boundaries match the feature distributions exactly (the
    % original separate loop left the last sample labeled 0)
    for i = 1:n_classes
        class_start = floor((i-1) * n_samples / n_classes) + 1;
        class_end = floor(i * n_samples / n_classes);
        n_class_samples = class_end - class_start + 1;
        mean_vals = linspace(-2, 2, n_features) * i;
        X(class_start:class_end, :) = randn(n_class_samples, n_features) + mean_vals;
        y(class_start:class_end) = i;
    end
    % Replace the last three features with pure noise
    X(:, end-2:end) = randn(n_samples, 3);
    % Shuffle
    shuffle_idx = randperm(n_samples);
    X = X(shuffle_idx, :);
    y = y(shuffle_idx);
end
function plot_confusion_matrix(cm, class_labels, title_str)
    % Draw a confusion matrix as a heatmap with count labels
    imagesc(cm);
    colorbar;
    [n_classes, ~] = size(cm);
    for i = 1:n_classes
        for j = 1:n_classes
            % White text on dark cells, black on light cells
            text(j, i, num2str(cm(i, j)), ...
                'HorizontalAlignment', 'center', ...
                'Color', ifelse(cm(i, j) > max(cm(:))/2, 'white', 'black'));
        end
    end
    set(gca, 'XTick', 1:n_classes, 'XTickLabel', class_labels);
    set(gca, 'YTick', 1:n_classes, 'YTickLabel', class_labels);
    xlabel('Predicted label');
    ylabel('True label');
    title(title_str);
end

function result = ifelse(condition, true_val, false_val)
    % Simple ternary helper
    if condition
        result = true_val;
    else
        result = false_val;
    end
end
5. Interactive Classifier GUI Tool
function ensemble_classifier_gui()
    % Interactive GUI for the ensemble classifiers
    fig = figure('Name', 'Ensemble Tree Classifier Tool', ...
        'NumberTitle', 'off', ...
        'Position', [100, 100, 1400, 800]);
    % Control panel
    create_control_panel(fig);
    % Result display areas
    create_display_areas(fig);
    % Initialize app data
    setappdata(fig, 'classifiers', []);
    setappdata(fig, 'results', []);
end
function create_control_panel(fig)
    % Build the control panel
    % Classifier selection
    uicontrol('Parent', fig, ...
        'Style', 'text', ...
        'String', 'Classifiers:', ...
        'Position', [30, 700, 100, 20], ...
        'FontWeight', 'bold');
    uicontrol('Parent', fig, ...
        'Style', 'listbox', ...
        'String', {'Random Forest', 'Gradient Boosting', 'Decision Tree'}, ...
        'Position', [30, 600, 150, 100], ...
        'Max', 3, ...   % Max > Min enables multiple selection
        'Tag', 'classifier_list');
    % Parameters
    uicontrol('Parent', fig, ...
        'Style', 'text', ...
        'String', 'Number of trees:', ...
        'Position', [30, 550, 100, 20]);
    uicontrol('Parent', fig, ...
        'Style', 'edit', ...
        'String', '100', ...
        'Position', [140, 550, 50, 20], ...
        'Tag', 'n_trees');
    uicontrol('Parent', fig, ...
        'Style', 'text', ...
        'String', 'Learning rate (GBDT):', ...
        'Position', [30, 520, 100, 20]);
    uicontrol('Parent', fig, ...
        'Style', 'edit', ...
        'String', '0.1', ...
        'Position', [140, 520, 50, 20], ...
        'Tag', 'learning_rate');
    % Data generation button
    uicontrol('Parent', fig, ...
        'Style', 'pushbutton', ...
        'String', 'Generate sample data', ...
        'Position', [30, 450, 170, 30], ...
        'Callback', @generate_data_callback);
    % Training button
    uicontrol('Parent', fig, ...
        'Style', 'pushbutton', ...
        'String', 'Train classifiers', ...
        'Position', [30, 400, 170, 30], ...
        'Callback', @train_classifiers_callback);
    % Result list
    uicontrol('Parent', fig, ...
        'Style', 'text', ...
        'String', 'Results:', ...
        'Position', [30, 350, 100, 20], ...
        'FontWeight', 'bold');
    uicontrol('Parent', fig, ...
        'Style', 'listbox', ...
        'String', {}, ...
        'Position', [30, 150, 170, 200], ...
        'Tag', 'result_display');
end
function create_display_areas(fig)
    % Build the plotting areas
    % Confusion matrix
    axes('Parent', fig, 'Position', [0.25, 0.55, 0.2, 0.3], 'Tag', 'confusion_axes');
    title('Confusion matrix');
    % Feature importance
    axes('Parent', fig, 'Position', [0.5, 0.55, 0.2, 0.3], 'Tag', 'importance_axes');
    title('Feature importance');
    % Performance comparison
    axes('Parent', fig, 'Position', [0.75, 0.55, 0.2, 0.3], 'Tag', 'performance_axes');
    title('Performance comparison');
    % ROC curves (multi-class)
    axes('Parent', fig, 'Position', [0.25, 0.15, 0.2, 0.3], 'Tag', 'roc_axes');
    title('ROC curves');
    % Learning curves
    axes('Parent', fig, 'Position', [0.5, 0.15, 0.2, 0.3], 'Tag', 'learning_axes');
    title('Learning curves');
    % Prediction distribution
    axes('Parent', fig, 'Position', [0.75, 0.15, 0.2, 0.3], 'Tag', 'prediction_axes');
    title('Prediction distribution');
end
function generate_data_callback(~, ~)
    % Generate sample data and store it in the figure's app data
    % (assumes generate_sample_data.m from the demo is on the path)
    [X, y] = generate_sample_data();
    setappdata(gcf, 'X_data', X);
    setappdata(gcf, 'y_data', y);
    msgbox(sprintf('Generated %d samples with %d features and %d classes', ...
        size(X, 1), size(X, 2), length(unique(y))), 'Data generated');
end
function train_classifiers_callback(~, ~)
    % Train the selected classifiers
    X = getappdata(gcf, 'X_data');
    y = getappdata(gcf, 'y_data');
    if isempty(X)
        errordlg('Generate data first!', 'Error');
        return;
    end
    % Selected classifiers
    classifier_list = findobj(gcf, 'Tag', 'classifier_list');
    selected_classifiers = classifier_list.Value;
    classifier_names = classifier_list.String;
    if isempty(selected_classifiers)
        errordlg('Select at least one classifier!', 'Error');
        return;
    end
    % Parameters (use get() rather than dot-indexing the findobj result)
    n_trees = str2double(get(findobj(gcf, 'Tag', 'n_trees'), 'String'));
    learning_rate = str2double(get(findobj(gcf, 'Tag', 'learning_rate'), 'String'));
    % Train/test split
    cv = cvpartition(y, 'HoldOut', 0.3);
    X_train = X(training(cv), :);
    y_train = y(training(cv));
    X_test = X(test(cv), :);
    y_test = y(test(cv));
    results = [];
    for i = 1:length(selected_classifiers)
        classifier_idx = selected_classifiers(i);
        classifier_name = classifier_names{classifier_idx};
        fprintf('Training %s...\n', classifier_name);
        % Instantiate the classifier
        switch classifier_name
            case 'Random Forest'
                classifier = RandomForest(n_trees, 5, 'sqrt');
            case 'Gradient Boosting'
                classifier = GradientBoostingClassifier(n_trees, learning_rate, 3);
            case 'Decision Tree'
                classifier = DecisionTree('min_leaf_size', 5, 'max_depth', 10);
        end
        % Train and predict
        tic;
        classifier.fit(X_train, y_train);
        training_time = toc;
        y_pred = classifier.predict(X_test);
        accuracy = mean(y_pred == y_test);
        % Store results
        results(i).name = classifier_name;
        results(i).classifier = classifier;
        results(i).accuracy = accuracy;
        results(i).training_time = training_time;
        results(i).predictions = y_pred;
        % Append to the result list (cell-with-cell concatenation; the
        % original concatenated a cell array with a char and errored)
        result_display = findobj(gcf, 'Tag', 'result_display');
        current_string = result_display.String;
        new_result = sprintf('%s: accuracy=%.4f, time=%.2fs', ...
            classifier_name, accuracy, training_time);
        if isempty(current_string)
            result_display.String = {new_result};
        else
            result_display.String = [current_string; {new_result}];
        end
    end
    setappdata(gcf, 'results', results);
    setappdata(gcf, 'X_test', X_test);
    setappdata(gcf, 'y_test', y_test);
    % Visualize the results
    visualize_results(results, X_test, y_test);
end
function visualize_results(results, X_test, y_test)
    % Visualize training results. Draw into the tagged axes reserved by
    % create_display_areas; calling subplot here (as the original did) would
    % create new axes and destroy the GUI layout.
    if isempty(results)
        return;
    end
    accuracies = [results.accuracy];
    times = [results.training_time];
    % Accuracy comparison
    performance_axes = findobj(gcf, 'Tag', 'performance_axes');
    cla(performance_axes);
    bar(performance_axes, accuracies);
    set(performance_axes, 'XTick', 1:length(results), 'XTickLabel', {results.name});
    ylabel(performance_axes, 'Accuracy');
    title(performance_axes, 'Classifier accuracy');
    grid(performance_axes, 'on');
    % Training time comparison (drawn into the learning-curve axes slot)
    learning_axes = findobj(gcf, 'Tag', 'learning_axes');
    cla(learning_axes);
    bar(learning_axes, times);
    set(learning_axes, 'XTick', 1:length(results), 'XTickLabel', {results.name});
    ylabel(learning_axes, 'Training time (s)');
    title(learning_axes, 'Training time');
    grid(learning_axes, 'on');
    % Report the best classifier
    [best_accuracy, best_idx] = max(accuracies);
    fprintf('\nBest classifier: %s (accuracy: %.4f)\n', results(best_idx).name, best_accuracy);
end
Usage
Basic usage
% Run the full demo
ensemble_classifier_demo();

% Or use the random forest on its own
[X, y] = generate_sample_data();
% Train/test split
cv = cvpartition(y, 'HoldOut', 0.3);
X_train = X(training(cv), :);
y_train = y(training(cv));
X_test = X(test(cv), :);
y_test = y(test(cv));
% Train the random forest
rf = RandomForest(100, 5, 'sqrt');
rf.fit(X_train, y_train);
% Predict
y_pred = rf.predict(X_test);
accuracy = mean(y_pred == y_test);
fprintf('Random forest accuracy: %.4f\n', accuracy);
% Inspect feature importances (plots into the current axes)
figure;
rf.plotFeatureImportance();
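The gradient-boosting class exposes the same fit/predict interface, so it can be swapped in with the same split. A short sketch reusing the variables above:

% Gradient boosting with the same train/test split
gbdt = GradientBoostingClassifier(100, 0.1, 3);
gbdt.fit(X_train, y_train);
y_pred_gb = gbdt.predict(X_test);
fprintf('GBDT accuracy: %.4f\n', mean(y_pred_gb == y_test));
% Class probabilities are also available:
proba = gbdt.predict_proba(X_test);   % n_samples-by-n_classes, rows sum to 1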
Launching the GUI tool
% Launch the interactive classifier tool
ensemble_classifier_gui();
Reference code: a classifier that trains and predicts with multiple trees — www.3dddown.com/cna/63814.html
Key Features
- Multiple ensemble methods: random forest, gradient-boosted trees, single decision tree
- Full evaluation metrics: accuracy, precision, recall, F1 score
- Feature importance analysis to identify the key features
- Visual analysis: confusion matrices, performance comparison, learning curves
- Parallel training: parfor-based acceleration in MATLAB (see the pool-setup note below)
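Note on the parallel-computing point: the parfor loops in RandomForest require the Parallel Computing Toolbox and fall back to serial execution when no pool is running. A minimal setup sketch (assuming the toolbox is installed):

% Start a parallel pool once per session if none is running
if isempty(gcp('nocreate'))
    parpool;   % requires Parallel Computing Toolbox
end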