mthoutai

  博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

Neural Turing Machines-NTM系列(一)简述

Neural Turing Machines-NTM系列(一)简述

NTM是一种使用Neural Network为基础来实现传统图灵机的理论计算模型。利用该模型。能够通过训练的方式让系统“学会”具有时序关联的任务流。


论文:http://arxiv.org/abs/1410.5401
中文翻译:http://www.dengfanxin.cn/?p=60
ppt:http://llcao.net/cu-deeplearning15/presentation/NeuralTuringMachines.pdf
基于Theano的python语言实现1:https://github.com/shawntan/neural-turing-machines
基于Theano的python语言实现2:https://github.com/snipsco/ntm-lasagne
基于Torch的实现:https://github.com/kaishengtai/torch-ntm
基于Tensor Flow的实现:https://github.com/carpedm20/NTM-tensorflow
基于C#的实现:https://github.com/JanTkacik/NTM
基于JS语言的实现:https://github.com/gcgibson/NTM
GO语言实现:https://github.com/fumin/ntm
相关博客1:https://blog.wtf.sg/category/neural-turing-machines/
相关博客2 :https://medium.com/snips-ai/ntm-lasagne-a-library-for-neural-turing-machines-in-lasagne-2cdce6837315#.twrvqnda9
百度贴吧:http://tieba.baidu.com/p/3404779569
知乎中关于强人工智能的一些介绍:http://www.zhihu.com/question/34393952

1.图灵机

首先,我们来复习一下大学的知识,什么是图灵机呢?图灵机并非一个实体的计算机,而是一个理论的计算模型,由计算机技术先驱Turing在1936年提出(百度知道)
它包括例如以下基本元素:
TAPE:磁带。即记忆体(Memory)
HEAD:读写头,read or write TAPE上的内容
TABLE:一套控制规则,也叫控制器(Controller),依据机器当前状态和HEAD当前所读取的内容来决定下一步的操作。在NTM中,TABLE事实上模拟了大脑的工作记忆
register:状态寄存器。存储机器当前的状态。
例如以下图:
这里写图片描写叙述

2. 神经图灵机(NTM)

所谓的NTM,事实上就是使用NN来实现图灵机计算模型中的读写操作。其模型中的组件与图灵机同样。那么,NTM中是怎么实现的呢?

2.1 Reading 操作

假设t时刻的内存数据为Mt<script type="math/tex" id="MathJax-Element-1">M_t</script>,Mt<script type="math/tex" id="MathJax-Element-2">M_t</script>为一矩阵,大小为N×M<script type="math/tex" id="MathJax-Element-3">N \times M</script>,当中N为内存地址的数目,M为每一个内存地址向量的长度。

wt<script type="math/tex" id="MathJax-Element-4">w_t</script>为t时刻加于N个内存地址上的权值,wt<script type="math/tex" id="MathJax-Element-5">w_t</script>为一N维向量。且每一个分量wt(i)<script type="math/tex" id="MathJax-Element-6">w_t(i)</script>满足:
iwt(i)=1,i,0wt(i)1<script type="math/tex" id="MathJax-Element-7">\sum\limits_i w_t(i)=1,且\forall i,0≤w_t(i)≤1</script>
定义读取向量为rt<script type="math/tex" id="MathJax-Element-8">r_t</script>(即t时刻Read Head读取出来的内容),大小为M,且满足:
rt=iwt(i)Mt(i)<script type="math/tex" id="MathJax-Element-9">r_t=\sum\limits_i w_t(i)M_t(i)</script>
显然,rt<script type="math/tex" id="MathJax-Element-10">r_t</script>是Mt(i)<script type="math/tex" id="MathJax-Element-11">M_t(i)</script>的凸组合。

2.2 Writing 操作

写操作分解为顺序运行的两步:
1.擦除(erase)
2.加入(add)
wt<script type="math/tex" id="MathJax-Element-12">w_t</script>为Write Head发出的权值向量,et<script type="math/tex" id="MathJax-Element-13">e_t</script>为擦除向量,它们的全部分量值都在0,1之间,前一个时刻的Memory改动量为:
M˜t(i)=Mt1(i)[1wt(i)et]<script type="math/tex" id="MathJax-Element-14">\widetilde{M}_{t}(i)=M_{t-1}(i)\circ[1-w_t(i)e_t]</script>
式中的空心圆圈表示向量按元素逐个相乘(point-wise),显然,这里的et<script type="math/tex" id="MathJax-Element-15">e_t</script>指出了每一个分量将被擦除的量。举个简单的样例:
假设N=2,M=3<script type="math/tex" id="MathJax-Element-16">N=2,M=3</script>
Mt1=(142536)<script type="math/tex" id="MathJax-Element-17">M_{t-1}=\begin{equation} %開始数学环境 \left( %左括号 \begin{array}{ccc} %该矩阵一共3列。每一列都居中放置 1 & 2 & 3\\ %第一行元素 4 & 5 & 6\\ %第二行元素 \end{array} \right) %右括号 \end{equation}</script>
wt=[0.1,0.3,0.7]T<script type="math/tex" id="MathJax-Element-18">w_t=[0.1,0.3,0.7]^T</script>
et=[0.2,0.5,0.6]<script type="math/tex" id="MathJax-Element-19">e_t=[0.2,0.5,0.6]</script>
M˜t(1)=Mt1(1)[1wt(1)et]<script type="math/tex" id="MathJax-Element-20">\widetilde{M}_{t}(1)=M_{t-1}(1)\circ[1-w_t(1)e_t]</script>
=[1,2,3](10.1[0.2,0.5,0.6])=[1,2,3][0.98,0.95,0.94]=[0.98,1.9,1.88]<script type="math/tex" id="MathJax-Element-21">=[1,2,3]\circ(1-0.1*[0.2,0.5,0.6])=[1,2,3]\circ[0.98,0.95,0.94]=[0.98,1.9,1.88]</script>

假设不考虑wt<script type="math/tex" id="MathJax-Element-22">w_t</script>的影响,我们能够简单的觉得et<script type="math/tex" id="MathJax-Element-23">e_t</script>的值代表将擦除的量,比方上例中的[0.2,0.5,0.6],能够觉得内存中每一个分量将分别被擦去原值的0.2,0.5,0.6,而wt<script type="math/tex" id="MathJax-Element-24">w_t</script>相当于每一个分量将要被改动的权重。
假设要全然擦除一个分量。仅仅须要相应的wt(i)<script type="math/tex" id="MathJax-Element-25">w_t(i)</script>和et<script type="math/tex" id="MathJax-Element-26">e_t</script>都为1。

et<script type="math/tex" id="MathJax-Element-27">e_t</script>为0时,将不进行不论什么改动。
Write Head还须要生成一个长度为M<script type="math/tex" id="MathJax-Element-28">M</script>的add<script type="math/tex" id="MathJax-Element-29">add</script>向量at<script type="math/tex" id="MathJax-Element-30">a_t</script>。在erase<script type="math/tex" id="MathJax-Element-31">erase</script>操作运行完之后,它将被“加”到相应的内存地址中。


t时刻的内存值将为:
Mt(i)=M˜t(i)+wt(i)at<script type="math/tex" id="MathJax-Element-32">M_t(i)=\widetilde{M}_{t}(i)+w_t(i)a_t</script>
显然,erase<script type="math/tex" id="MathJax-Element-33">erase</script>和add<script type="math/tex" id="MathJax-Element-34">add</script>操作都是可微的,它们的组合操作writing也同样是可微的。writing能够对随意地址的元素值进行随意精度的改动。

3.NTM的寻址策略

有两种寻址策略,

3.1 Content-base(基于内容的寻址):

产生一个待查询的值kt<script type="math/tex" id="MathJax-Element-35">k_t</script>,将该与Mt<script type="math/tex" id="MathJax-Element-36">M_{t}</script>中的全部N<script type="math/tex" id="MathJax-Element-37">N</script>个地址的值进行比較,最相近的那个Mt(i)<script type="math/tex" id="MathJax-Element-38">M_t(i)</script>即为待查询的值。
首先,须要进行寻址操作的Head(Read or Write)生成一个长度为M的key vector:kt<script type="math/tex" id="MathJax-Element-39">k_t</script>,然后将kt<script type="math/tex" id="MathJax-Element-40">k_t</script>与每一个Mt(i)<script type="math/tex" id="MathJax-Element-41">M_t(i)</script>进行类似度比較(类似度计算函数为K[u,v]<script type="math/tex" id="MathJax-Element-42">K[u,v]</script>)。最后将生成一个归一化的权重向量wct<script type="math/tex" id="MathJax-Element-43">w_t^c</script>,计算公式例如以下:
wct(i)=eβtK[kt,Mt(i)]jeβtK[kt,Mt(j)]<script type="math/tex" id="MathJax-Element-44">w_t^c(i)=\frac {e^{\beta_tK[k_t,M_t(i)]}} {\sum\limits_je^{\beta_tK[k_t,M_t(j)]}}</script>
当中,βtβt>0<script type="math/tex" id="MathJax-Element-45">\beta_t满足\beta_t>0</script>是一个调节因子。用以调节寻址焦点的范围。βt<script type="math/tex" id="MathJax-Element-46">\beta_t</script>越大。函数的曲线变得越发陡峭,焦点的范围也就越小。
类似度函数这里取余弦类似度:K[u,v]=uv||u||||v||<script type="math/tex" id="MathJax-Element-47">K[u,v]=\frac{u \cdot v}{||u||\cdot||v||}</script>

3.2 Location-base(基于位置的寻址):

直接使用内存地址进行寻址,跟传统的计算机系统类似,controller给出要訪问的内存地址,Head直接定位到该地址所相应的内存位置。

对于一些对内容不敏感的操作,比方乘法函数f(x,y)=xy<script type="math/tex" id="MathJax-Element-108">f(x,y)=xy</script>,显然该操作并不局限于x,y的详细值。x,y的值是易变的。重要的是能够从指定的地址中把它们读出来。这类问题更适合採用Location-base的寻址方式。


基于地址的寻址方式能够同一时候提升简单顺序訪问和随机地址訪问的效率。我们通过对移位权值进行旋转操作来实现优化。

比如。当前权值聚焦在一个位置,旋转操作1将把焦点移向下一个位置,而一个负的旋转操作将会把焦点移向上一个位置。


在旋转操作之前。将进行一个插入改动的操作(interpolation),每一个head将会输出一个改动因子gtgt[0,1]<script type="math/tex" id="MathJax-Element-109">g_t且g_t∈[0,1]</script>,该值用来混合上一个时刻的wt1<script type="math/tex" id="MathJax-Element-110">w_{t-1}</script>和当前时刻由内容寻址模块产生的wct<script type="math/tex" id="MathJax-Element-111">w_t^c</script>,最后产生门限权值wgt<script type="math/tex" id="MathJax-Element-112">w_t^g</script>:
wgt=gtwct+(1gt)wt1<script type="math/tex" id="MathJax-Element-113">w_t^g=g_tw_t^c+(1-g_t)w_{t-1}</script>
显然,gt<script type="math/tex" id="MathJax-Element-114">g_t</script>的大小决定了wct<script type="math/tex" id="MathJax-Element-115">w_t^c</script>所占的分量,gt<script type="math/tex" id="MathJax-Element-116">g_t</script>越大,系统就越倾向于使用Content-base Addressing。当gt=1<script type="math/tex" id="MathJax-Element-117">g_t=1</script>时。将全然依照Content-base方式进行寻址。
在上述的interpolation操作结束后,每一个head将会产生一个长度为N的移位权值向量st<script type="math/tex" id="MathJax-Element-118">s_t</script>,st<script type="math/tex" id="MathJax-Element-119">s_t</script>是定义在全部可能的整形移位上的一个归一化分布。比如。假设移位的范围在-1到1之间(即最多能够前后移动一个位置),则移位值将有3种可能:-1,0,1,相应这3个值。st<script type="math/tex" id="MathJax-Element-120">s_t</script>也将有3个权值。那该怎么求出这些权值呢?比較常见的做法是,把这个问题看做一个多分类问题。在Controller中使用一个softmax层来产生相应位移的权重值。在论文中还实验了一种方法:在Controller中产生一个缩放因子,该因子为移位位置上均匀分布的下界。比方,假设该缩放因子值为6.7。那么st(6)=0.3,st(7)=0.7<script type="math/tex" id="MathJax-Element-121">s_t(6)=0.3,s_t(7)=0.7</script>。st<script type="math/tex" id="MathJax-Element-122">s_t</script>的其余分量为0<script type="math/tex" id="MathJax-Element-123">0</script>(仅仅取整数索引)。
这里写图片描写叙述
st<script type="math/tex" id="MathJax-Element-124">s_t</script>生成之后,接下来就要使用st<script type="math/tex" id="MathJax-Element-125">s_t</script>对wgt<script type="math/tex" id="MathJax-Element-126">w_t^g</script>进行循环卷积操作。详细例如以下式:
w˜t(i)=j=0N1wgt(j)st(ij)<script type="math/tex" id="MathJax-Element-127">\widetilde{w}_t(i)=\sum\limits_{j=0}^{N-1}w_t^g(j)s_t(i-j)</script>
写成矩阵的形式例如以下:
这里写图片描写叙述
原始可改写为:w˜t=Stwgt<script type="math/tex" id="MathJax-Element-128">\widetilde{\textbf{w}}_t=\textbf{S}_t \textbf{w}_t^g</script>
由于卷积操作会使权值的分布趋于均匀化,这将导致本来集中于单个位置的焦点出现发散现象。为了解决问题,还须要对结果进行锐化操作。详细做法是Head产生一个因子γt1<script type="math/tex" id="MathJax-Element-129">\gamma_t≥1</script>,并通过例如以下操作来进行锐化:
wt(i)=w˜t(i)γtjw˜t(j)γt<script type="math/tex" id="MathJax-Element-130">w_t(i)=\frac{\widetilde{w}_t(i)^{\gamma_t}}{\sum_j \widetilde{w}_t(j)^{\gamma_t}}</script>
通过上述操作后,权值分布将变得“尖锐”。
我们通过一个简单的样例来说明:
假设N=5,当前焦点为1,三个位置-1,0,1相应的权值为0.1,0.8,0.1,wgt=0.060.10.650.150.04<script type="math/tex" id="MathJax-Element-131">\textbf{w}_t^g=\left[ \begin{array}{ccc} 0.06\\ 0.1\\ 0.65\\ 0.15\\ 0.04\\ \end{array} \right]</script>则
St=st(0)st(1)st(2)st(3)st(4)st(4)st(0)st(1)st(2)st(3)st(3)st(4)st(0)st(1)st(2)st(2)st(3)st(4)st(0)st(1)st(1)st(2)st(3)st(4)st(0)=0.10.80.10000.10.80.10000.10.80.10.1000.10.80.80.1000.1<script type="math/tex" id="MathJax-Element-132">\textbf{S}_t = \left[ \begin{array}{ccc} s_t(0) &s_t(4) & s_t(3) & s_t(2) & s_t(1)\\ s_t(1) &s_t(0) & s_t(4) & s_t(3) & s_t(2)\\ s_t(2) &s_t(1) & s_t(0) & s_t(4) & s_t(3)\\ s_t(3) &s_t(2) & s_t(1) & s_t(0) & s_t(4)\\ s_t(4) &s_t(3) & s_t(2) & s_t(1) & s_t(0)\\ \end{array} \right]=\left[ \begin{array}{ccc} 0.1&0 & 0 & 0.1 &0.8\\ 0.8 &0.1 & 0 & 0 & 0.1\\ 0.1 &0.8 & 0.1 & 0 & 0\\ 0 &0.1 & 0.8 & 0.1 & 0\\ 0 &0 & 0.1 & 0.8 & 0.1\\ \end{array} \right] </script>
所以有:
w˜t=Stwgt=0.10.80.10000.10.80.10000.10.80.10.1000.10.80.80.1000.1×0.060.10.650.150.04=0.0530.0620.1510.5450.189<script type="math/tex" id="MathJax-Element-133">\widetilde{\textbf{w}}_t=\textbf{S}_t \textbf{w}_t^g= \left[ \begin{array}{ccc} 0.1&0 & 0 & 0.1 &0.8\\ 0.8 &0.1 & 0 & 0 & 0.1\\ 0.1 &0.8 & 0.1 & 0 & 0\\ 0 &0.1 & 0.8 & 0.1 & 0\\ 0 &0 & 0.1 & 0.8 & 0.1\\ \end{array} \right] \times \left[ \begin{array}{ccc} 0.06\\ 0.1\\ 0.65\\ 0.15\\ 0.04\\ \end{array} \right]=\left[ \begin{array}{ccc} 0.053\\ 0.062\\ 0.151\\ 0.545\\ 0.189\\ \end{array} \right]</script>
γt=2<script type="math/tex" id="MathJax-Element-134">\gamma_t=2</script>,
wt=w˜γttjw˜t(j)γt=0.00780.01060.06300.82010.0986<script type="math/tex" id="MathJax-Element-135">\textbf{w}_t=\frac{\widetilde{\textbf{w}}_t^{\gamma_t}}{\sum_j \widetilde{w}_t(j)^{\gamma_t}}=\left[ \begin{array}{ccc} 0.0078\\ 0.0106\\ 0.0630\\ 0.8201\\ 0.0986\\ \end{array} \right]</script>
能够看出来,经过锐化处理后wt<script type="math/tex" id="MathJax-Element-136">\textbf{w}_t</script>不同元素直接的差异变得更明显了(即变得“尖锐”了)。内存操作焦点将更加突出。

整个寻址的步骤例如以下图:
这里写图片描写叙述

通过上图的内存寻址系统。我们能够实现三种方式的内存訪问:
1.直接通过内容寻址,即前边提到的Content-base方式;
2.通过对Content-base产生的权值进行选择和移位而产生新的寻址权值,在这样的模式下,运行内存操作焦点跳跃到基于Content-base的下一个位置,这将使操作Head能够读取位于一系列相邻的内存块中的数据;
3.仅仅通过上一时刻的权值来生成新的操作地址权值,而不依赖不论什么当前的Content-base值。这将同意Head进行顺序迭代读取(比方能够通过多个时刻的连续迭代,读取内存中一个连续的数组)

3.3 控制器网络(Controller Network)

NTM的结构中存在非常多的自由參数。比方内存的大小。读写头的数目,内存读取时的位移的范围。可是最重要的部分还是控制器的神经网络结构。比方,是选用递归网络还是前馈网络。假设选用LSTM,则其自有的隐层状态能够作为内存矩阵Mt<script type="math/tex" id="MathJax-Element-77">M_t</script>的补充。

假设把Controller与传统计算机的CPU进行类比,则Mt<script type="math/tex" id="MathJax-Element-78">M_t</script>就相当于传统计算机的内存(RAM),递归网络中的隐层状态就相当于寄存器(Registers),同意Controller混合跨越多个时间步的信息。还有一方面,前馈网络能够通过在不同的时刻读取内存中同样的位置来模拟递归网络的特性。此外。基于前馈网络的Controller网络操作更为透明。由于此时的读写操作相比RNN的内部状态更easy解释。当然前馈网络的局限性主要在于同一时候存在的读写头数目有限。

单一的Read Head每一个时间步仅仅能操作一个内存向量,而递归Controller则可通过内部存储器同一时候读取多个内存向量。

<script type="text/javascript"> $(function () { $('pre.prettyprint code').each(function () { var lines = $(this).text().split('\n').length; var $numbering = $('
    ').addClass('pre-numbering').hide(); $(this).addClass('has-numbering').parent().append($numbering); for (i = 1; i <= lines; i++) { $numbering.append($('
  • ').text(i)); }; $numbering.fadeIn(1700); }); }); </script>
---

相关课程

  1. Python核心技术与实战
    ‍ 景霄 | 从工程角度深入理解Python
posted on 2017-07-27 09:33  mthoutai  阅读(2234)  评论(0)    收藏  举报