本站原创文章,转载请说明来自《老饼讲解-BP神经网络》https://www.bbbdata.com/

LVQ神经网络是用于样本分类的一个常用算法,本文先简单回顾LVQ神经网络是什么,然后展示如何用matlab工具箱来训练一个LVQ神经网络

一. LVQ神经网络简介

本节回顾LVQ神经网络的思想和关键知识

1.1 LVQ神经网络是什么

LVQ用于解决分类问题,它先对每个类别都初始化一些类别判别中心点,然后通过训练来调整这些类别判别中心的位置,使它们能较好地识别训练样本

image.png
这样,来了新样本,只要判断新样本离哪个聚类中心点近,就判断样本属于该聚类中心点所代表的类别

1.2 LVQ神经网络的拓扑表示

LVQ一般用一个三层神经网络来表示,
它的拓扑结构如下:

image.png

其中,每个隐层节点代表着一个类别判别中心,它与输入层的权重\(W^{21}\)就是它的位置,
它的输出层的连接\(W^{32}\)代表着它是哪个类别的判别中心

例如,某个隐节点的输入权重为[0.3 0.5],输出权重为[0 1]
则代表它的位置为[0.3,0.5], 是类别1的判别中心

二. 如何使用matlab训练一个LVQ神经网络

本节讲解如何用matlab工具箱来训练一个LVQ神经网络

2.1 matlab工具箱实现LVQ的代码

下面以一个例子,讲述如何用matlab工具箱实现LVQ神经网络

代码如下:

%代码说明:matlab工具箱训练一个LVQ神经网络
%来自《老饼讲解神经网络》www.bbbdata.com ,matlab版本:2018a
%数据准备
clear all ;close all 
rand('seed',70)
P = [-3 -2 -2  0  0.5  -0.5  0 +2 +2 +3; ...
    0 +1 -1 +2 +1 -1 -2 +1 -1  0];                    % 输入数据
Tc = [1 1 1 2 2 2 2 1 1 1];                           % 输出类别
T = ind2vec(Tc);                                      % 将输出转为one-hot编码(代表类别的01向量)

%网络训练
net = newlvq(P,4,[0.5 ,0.5],0.01,'learnlv1');         % 建立一个LVQ神经网络,用lvq1规则训练
net = train(net,P,T);                                 % 训练神经网络
%预测
Y = sim(net,P);                                       % 预测(one-hot形式)
Yc = vec2ind(Y);                                      % 将one-hot编码形式转回类别编号形式
% 提取出各个类别的判别中心                            
c       = net.iw{1,1};                                % 中心
c_class = net.lw{2,1};                                % 中心所属类别
c       = [vec2ind(c_class)',c]                       % 添加中心的类别标签


% -------绘制结果-----------------
figure
% 绘制原始数据
subplot(2,1,1)
plot(P(1,Tc==1),P(2,Tc==1),'o','MarkerEdgeColor','k','MarkerFaceColor','b','MarkerSize',10)
hold on 
plot(P(1,Tc==2),P(2,Tc==2),'o','MarkerEdgeColor','k','MarkerFaceColor','g','MarkerSize',10)
legend('类别1','类别2')
title('原始数据类别')
% 绘制预测结果
subplot(2,1,2)
plot(P(1,Yc==1),P(2,Yc==1),'o','MarkerEdgeColor','k','MarkerFaceColor','b','MarkerSize',10)
hold on 
plot(P(1,Yc==2),P(2,Yc==2),'o','MarkerEdgeColor','k','MarkerFaceColor','g','MarkerSize',10)
hold on 
% 绘制网络的隐节点(类别判别中心)
plot(c(:,2),c(:,3),'o','MarkerEdgeColor','k','MarkerFaceColor','y','MarkerSize',10)
for i = 1: size(c,1)
text(c(i,2)-0.050,c(i,3)+0.02,num2str(c(i,1)))
end
title('LVQ预测类别')

2.2 关键代码解说

其中,核心代码为net= newlvq(P,4,[0.5 ,0.5],0.01,'learnlv1'); ,它用于构建一个LVQ神经网络,

各全参数的含义如下

👉 1. P是训练数据的输入

👉 2. 4代表我们使用4个隐节点
(也就使用4个类别判别中心)

👉 3. [0.5,0.5]代表上述4个隐节点的类别分配比例(也就是类别1、类别2的判别中心各2个)

👉 4. 0.01是学习率

👉 5. 'learnlv1'则指定了训练方法

2.3 训练结果

运行上述代码,得到结果如下

image.png
可见,训练后的LVQ神经网络的预测类别与真实样本一致,
它已经可以准确的对训练样本进行分类

关于类别判别中心的位置

上述代码的还打印了类别判别中心的信息,如下

image.png

其中,每一行代表一个判别中心,第一列表示是哪一个类别的判别中心,第2、3列表示判别中心的坐标              

相关链接:

《老饼讲解-机器学习》:老饼讲解-机器学习教程-通俗易懂
《老饼讲解-神经网络》:老饼讲解-matlab神经网络-通俗易懂
《老饼讲解-神经网络》:老饼讲解-深度学习-通俗易懂

posted on 2024-06-22 05:41  老饼讲解机器学习  阅读(172)  评论(0)    收藏  举报