基于MATLAB的贝叶斯网络鸢尾花分类实现

一、贝叶斯网络结构设计

%% 构建贝叶斯网络结构
net = bnt.BayesNet();  % 创建贝叶斯网络对象

% 添加节点
classNode = net.addNode('Class', 'discrete', [1 2 3]);  % 类别节点(3种可能值)
for i = 1:4
    featureNode = net.addNode(sprintf('Feat%d',i), 'continuous');  % 连续特征节点
    net.addEdge(classNode, featureNode);  % 类别节点指向特征节点
end

% 可视化网络结构
view(net);

二、核心代码

%% 数据加载与预处理
load fisheriris
X = meas;  % 特征矩阵 (150x4)
Y = grp2idx(species);  % 类别标签 (1: setosa, 2: versicolor, 3: virginica)

%% 参数学习
% 定义条件概率分布(CPD)
cpd_class = tabularCPD(net, 'Class', [0.3333 0.3333 0.3334]);  % 均匀先验

% 定义高斯基函数CPD
for i = 1:4
    featureName = sprintf('Feat%d',i);
    mu = mean(X(:,i));
    sigma = std(X(:,i));
    cpd_feat = tabularCPD(net, featureName, ...
        'Mean', mu, ...
        'Variance', sigma^2);
    net = addCPD(net, cpd_feat);
end

%% 模型验证
assert(net.checkModel());  % 验证网络结构

%% 分类推理
inference = varElimInfEngine(net);

% 训练集预测
Y_pred_train = zeros(size(Y));
for i = 1:size(X,1)
    evidence = struct();
    for j = 1:4
        evidence.(sprintf('Feat%d',j)) = X(i,j);
    end
    [~, Y_pred_train(i)] = inference.evidence(evidence);
end

% 测试集预测
cv = cvpartition(size(X,1),'HoldOut',0.3);
X_test = X(cv.test,:);
Y_test = Y(cv.test,:);
Y_pred_test = zeros(size(X_test,1),1);

for i = 1:size(X_test,1)
    evidence = struct();
    for j = 1:4
        evidence.(sprintf('Feat%d',j)) = X_test(i,j);
    end
    [~, Y_pred_test(i)] = inference.evidence(evidence);
end

%% 性能评估
train_acc = sum(Y_pred_train == Y)/length(Y);
test_acc = sum(Y_pred_test == Y_test)/length(Y_test);

fprintf('训练集准确率: %.2f%%
', train_acc*100);
fprintf('测试集准确率: %.2f%%
', test_acc*100);

三、优化

  1. 非高斯分布建模

    % 使用核密度估计
    cpd_feat = tabularCPD(net, 'Feat1', 'Kernel', 'Bandwidth', 0.5);
    
  2. 特征间依赖关系

    % 添加花瓣长度与花瓣宽度的依赖关系
    net = addEdge(net, net.getNode('Feat2'), net.getNode('Feat3'));
    
  3. 贝叶斯网络增强

    % 添加环境特征节点
    envNode = net.addNode('SoilpH', 'discrete', [4 5 6 7 8]);
    net = addEdge(net, envNode, net.getNode('Feat1'));
    

四、可视化分析

%% 绘制条件概率分布
figure;
plotCPD(net, 'Feat1');
title('花瓣长度条件概率分布');

%% 绘制决策边界
[x1Grid, x2Grid] = meshgrid(linspace(4,8,50), linspace(1.5,7,50));
XGrid = [x1Grid(:) x2Grid(:) zeros(size(XGrid))];
posterior = zeros(size(XGrid,1),3);

for i = 1:size(XGrid,1)
    evidence = struct();
    evidence.Feat1 = XGrid(i,1);
    evidence.Feat2 = XGrid(i,2);
    evidence.Feat3 = 0;
    evidence.Feat4 = 0;
    [~,posterior(i)] = inference.evidence(evidence);
end

figure;
contourf(x1Grid, x2Grid, reshape(posterior(:,2),size(x1Grid)), 'LineColor', 'none');
hold on;
gscatter(X(:,1), X(:,2), Y);
title('贝叶斯网络分类决策边界');
xlabel('花萼长度'); ylabel('花萼宽度');
colorbar;

参考代码 利用贝叶斯网络对IRIS数据(鸢尾花)进行分类 www.youwenfan.com/contentcnl/78526.html

五、完整代码结构

iris_bn/
├── data_loader.m      # 数据加载与预处理
├── network_model.m    # 网络结构定义
├── parameter_learning.m # 参数学习
├── inference.m        # 推理模块
└── visualization.m    # 可视化工具

六、注意事项

  1. 数据分布假设 默认假设特征服从高斯分布,可通过核密度估计改进
  2. 计算复杂度 节点数增加时计算量呈指数增长,建议使用树结构
  3. 连续特征处理 对非高斯分布特征建议使用混合高斯模型
posted @ 2025-11-14 10:35  u95900090  阅读(10)  评论(0)    收藏  举报