基于MATLAB的贝叶斯网络鸢尾花分类实现
一、贝叶斯网络结构设计
%% Build the Bayesian network structure
% Topology: one discrete 'Class' node with states 1..3 and a directed edge
% from it to each of the four continuous feature nodes Feat1..Feat4.
net = bnt.BayesNet();
classNode = net.addNode('Class', 'discrete', [1 2 3]);
nFeatures = 4;
for k = 1:nFeatures
    childNode = net.addNode(sprintf('Feat%d', k), 'continuous');
    net.addEdge(classNode, childNode);   % Class -> Feat_k
end
% Display the resulting graph
view(net);
二、核心代码
%% Data loading and preprocessing
load fisheriris
X = meas;                 % feature matrix (150x4): sepal length, sepal width, petal length, petal width
Y = grp2idx(species);     % class labels (1: setosa, 2: versicolor, 3: virginica)
numClasses = max(Y);
featNames = arrayfun(@(j) sprintf('Feat%d', j), 1:size(X, 2), 'UniformOutput', false);

%% Train/test split
% Split BEFORE learning parameters so the held-out rows never influence the
% model; passing the labels to cvpartition stratifies the split by class.
rng(0);                                   % reproducible partition
cv = cvpartition(Y, 'HoldOut', 0.3);
X_train = X(training(cv), :);
Y_train = Y(training(cv));
X_test  = X(test(cv), :);
Y_test  = Y(test(cv));

%% Parameter learning
% Class prior: uniform over the three species. It must be attached to the
% network with addCPD, not merely created.
cpd_class = tabularCPD(net, 'Class', ones(1, numClasses) / numClasses);
net = addCPD(net, cpd_class);

% Class-conditional Gaussian CPD per feature: one mean/variance per Class
% state, estimated from the training split only. Pooled statistics over all
% classes would make every feature independent of Class, so the posterior
% would collapse to the prior.
% NOTE(review): assumes tabularCPD accepts a per-parent-state vector for
% 'Mean'/'Variance' - confirm against the toolbox API in use.
minVariance = 1e-6;                       % avoid a degenerate zero-variance Gaussian
for i = 1:numel(featNames)
    mu     = accumarray(Y_train, X_train(:, i), [numClasses 1], @mean).';
    sigma2 = accumarray(Y_train, X_train(:, i), [numClasses 1], @var).';
    cpd_feat = tabularCPD(net, featNames{i}, ...
        'Mean', mu, ...
        'Variance', max(sigma2, minVariance));
    net = addCPD(net, cpd_feat);
end

%% Model validation
assert(net.checkModel());                 % validate network structure/CPDs

%% Classification inference
inference = varElimInfEngine(net);
% Turn one feature row into an evidence struct with fields Feat1..Feat4.
toEvidence = @(x) cell2struct(num2cell(x(:)), featNames(:), 1);

% Training-set prediction (second output of inference.evidence = predicted label)
Y_pred_train = zeros(size(Y_train));
for i = 1:size(X_train, 1)
    [~, Y_pred_train(i)] = inference.evidence(toEvidence(X_train(i, :)));
end

% Test-set prediction on the held-out rows
Y_pred_test = zeros(size(Y_test));
for i = 1:size(X_test, 1)
    [~, Y_pred_test(i)] = inference.evidence(toEvidence(X_test(i, :)));
end

%% Performance evaluation
train_acc = mean(Y_pred_train == Y_train);
test_acc  = mean(Y_pred_test == Y_test);
fprintf('训练集准确率: %.2f%%\n', train_acc*100);
fprintf('测试集准确率: %.2f%%\n', test_acc*100);
三、优化
- 非高斯分布建模
% Non-Gaussian modelling: use a kernel density estimate for Feat1
cpd_feat = tabularCPD(net, 'Feat1', 'Kernel', 'Bandwidth', 0.5);
net = addCPD(net, cpd_feat);   % attach the new CPD to the network
- 特征间依赖关系
% Add a dependency between petal length (Feat3) and petal width (Feat4)
net = addEdge(net, net.getNode('Feat3'), net.getNode('Feat4'));
- 贝叶斯网络增强
% Add an environmental feature node (soil pH) as an extra parent of Feat1.
% NOTE: fisheriris provides only the 4 measurement columns in meas; this
% node needs external data or must remain unobserved during inference.
envNode = net.addNode('SoilpH', 'discrete', [4 5 6 7 8]);
net = addEdge(net, envNode, net.getNode('Feat1'));
四、可视化分析
%% Plot the conditional probability distribution of Feat1
figure;
plotCPD(net, 'Feat1');
title('花萼长度条件概率分布');   % Feat1 = column 1 of meas, i.e. sepal length

%% Plot decision regions in the Feat1/Feat2 (sepal length / sepal width) plane
% The grid spans the observed data range plus a margin so it matches the
% axes labelled below.
gridMargin = 0.5;
gridSteps  = 50;
[x1Grid, x2Grid] = meshgrid( ...
    linspace(min(X(:,1)) - gridMargin, max(X(:,1)) + gridMargin, gridSteps), ...
    linspace(min(X(:,2)) - gridMargin, max(X(:,2)) + gridMargin, gridSteps));
XGrid = [x1Grid(:) x2Grid(:)];
numClasses = max(Y);
posterior = zeros(size(XGrid, 1), numClasses);
for i = 1:size(XGrid, 1)
    % Only Feat1/Feat2 are observed. Feat3/Feat4 are left out of the evidence
    % so inference marginalizes them instead of clamping petal size to 0.
    evidence = struct('Feat1', XGrid(i,1), 'Feat2', XGrid(i,2));
    % NOTE(review): assumes the first output of inference.evidence is the
    % class posterior vector P(Class = 1..3 | evidence) - confirm.
    post = inference.evidence(evidence);
    posterior(i, :) = post(:).';
end
% Predicted class at each grid point = argmax of the posterior
[~, gridClass] = max(posterior, [], 2);

figure;
contourf(x1Grid, x2Grid, reshape(gridClass, size(x1Grid)), ...
    0.5:1:numClasses + 0.5, 'LineColor', 'none');   % one filled band per class
hold on;
gscatter(X(:,1), X(:,2), Y);
hold off;
title('贝叶斯网络分类决策边界');
xlabel('花萼长度'); ylabel('花萼宽度');
colorbar;
参考代码 利用贝叶斯网络对IRIS数据(鸢尾花)进行分类 www.youwenfan.com/contentcnl/78526.html
五、完整代码结构
iris_bn/
├── data_loader.m # 数据加载与预处理
├── network_model.m # 网络结构定义
├── parameter_learning.m # 参数学习
├── inference.m # 推理模块
└── visualization.m # 可视化工具
六、注意事项
- 数据分布假设:默认假设各特征在给定类别下服从高斯分布,可通过核密度估计改进
- 计算复杂度:精确推理的计算量随网络树宽(treewidth)呈指数增长,节点和边增多时建议保持树状或稀疏结构
- 连续特征处理:对非高斯分布特征建议使用混合高斯模型

浙公网安备 33010602011771号