RNN 前向和梯度 matlab实现
转:https://www.cnblogs.com/YiXiaoZhou/p/6058890.html
BPTT,Back Propagation Through Time.
首先来看看怎么处理RNN。
RNN展开网络如下图

RNN展开结构.jpg

RNN节点结构.jpg
现令第t时刻的输入表示为
,隐层节点的输出为
,输出层的预测值
,输入到隐层的权重矩阵
,隐层自循环的权重矩阵
,隐层到输出层的权重矩阵
,对应的偏执向量分别表示为
,输入层的某一个节点使用i标识,如
,类似的隐层和输出层某一节点表示为
。这里我们仅以三层网络为例。
那么首先正向计算

其中
分别表示激活前对应的加权和,
表示激活函数。
然后看看误差如何传递。
假设真实的输出应该是
,那么误差可以定义为
,
是训练样本的index。整个网络的误差
我们将RNN再放大一些,看看细节

RNN节点内部连接.jpg
令
则

矩阵向量化表示

所以梯度为:

其中
是点乘符号,即对应元素乘。
代码实现:
我们可以注意到在计算梯度时需要用到的之前计算过的量,即需要保存的量包括,前向计算时隐层节点和输出节点的输出值,以及由
时刻累积的
。
人人都能用Python写出LSTM-RNN的代码![你的神经网络学习最佳起步]这篇文章里使用python实现了基本的RNN过程。代码功能是模拟二进制相加过程中的依次进位过程,代码很容易明白。
这里改写成matlab代码
function error = binaryRNN( )
largestNumber=256;
T=8;
dic=dec2bin(0:largestNumber-1)-'0';% 将uint8表示成二进制数组,这是一个查找表
%% 初始化参数
eta=0.1;% 学习步长
inputDim=2;% 输入维度
hiddenDim=16; %隐层节点个数
outputDim=1; % 输出层节点个数
W=rand(hiddenDim,outputDim)*2-1;% (-1,1)参数矩阵
U=rand(hiddenDim,hiddenDim)*2-1;% (-1,1)参数矩阵
V=rand(inputDim,hiddenDim)*2-1; % (-1,1)参数矩阵
delta_W=zeros(hiddenDim,outputDim); % 时刻间中间变量
delta_U=zeros(hiddenDim,hiddenDim);
delta_V=zeros(inputDim,hiddenDim);
error=0;
for p=1:10000
aInt=randi(largestNumber/2);
bInt=randi(largestNumber/2);
a=dic(aInt+1,:);
b=dic(bInt+1,:);
cInt=aInt+bInt;
c=dic(cInt+1,:);
y=zeros(1,T);
preh=zeros(1,hiddenDim);
hDic=zeros(T,hiddenDim);
%% 前向计算
for t=T:-1:1 % 注意应该从最低位计算,也就是二进制数组最右端开始计算
x=[a(t),b(t)];
h=sigmoid(x*V+preh*U);
y(t)=sigmoid(h*W);
hDic(t,:)=h;
preh=h;
end
err=y-c;
error=error+norm(err,2)/2;
next_delta_h=zeros(1,hiddenDim);
%% 反馈
for t=1:T
delta_y = err(t).*sigmoidOutput2d(y(t));
delta_h=(delta_y*W'+next_delta_h*U').*sigmoidOutput2d(hDic(t,:));
delta_W=delta_W+hDic(t,:)'*delta_y;
if t<T
delta_U=delta_U+hDic(t+1,:)'*delta_h;
end
delta_V=delta_V+[a(t),b(t)]'*delta_h;
next_delta_h=delta_h;
end
% 梯度下降
W=W-eta*delta_W;
U=U-eta*delta_U;
V=V-eta*delta_V;
delta_W=zeros(hiddenDim,outputDim);
delta_U=zeros(hiddenDim,hiddenDim);
delta_V=zeros(inputDim,hiddenDim);
if mod(p,1000)==0
fprintf('Samples:%d\n',p);
fprintf('True:%d\n',cInt);
fprintf('Predict:%d\n',bin2dec(int2str(round(y))));
fprintf('Error:%f\n',norm(err,2)/2);
end
end
end
function sx=sigmoid(x)
sx=1./(1+exp(-x));
end
function dx=sigmoidOutput2d(output)
dx=output.*(1-output);
end

浙公网安备 33010602011771号