Deformable ConvNets--Part2: Spatial Transfomer Networks(STN)

转自:https://blog.csdn.net/u011974639/article/details/79681455

Deformable ConvNet简介

关于Deformable Convolutional Networks的论文解读,共分为5个部分,本章是第二部分:

  • [ ] Part1: 快速学习实现仿射变换
  • [x] Part2: Spatial Transfomer Networks论文解读
  • [ ] Part3: TenosorFlow实现STN
  • [ ] Part4: Deformable Convolutional Networks论文解读
  • [ ] Part5: TensorFlow实现Deformable ConvNets

本章解读Spatial Transfomer Networks论文,看STN是如何将仿射变换加入到现有CNN架构上,并使之支持端对端训练。

Spatial Transformer Networks

STN:Spatial Transformer Networks

收录:NIPS 2015(Advances in Neural Information Processing Systems 28)

原文地址:STN

代码:



Motivation

数据问题

在实际场景,不同的场景下目标会存在不同状态,如下图:

这里写图片描述

对于相同目标物,不一样的尺度缩放、形变、背景干扰、观察视角,获取到的目标信息都是不同的,这极大的增加了目标检测等任务的难度。

针对这一问题,大多数先进系统的解决办法是,对模型的训练数据做数据增强,例如随机的crop(裁剪)、平移、放缩、旋转等操作,通过增加数据量,进而增加模型的泛化能力。

池化存在的问题

先引Hinton对池化操作的评价:

The pooling operation used in convolutional neural networks is a big mistake and the fact that it works so well is a disaster. (Geoffrey Hinton, Reddit AMA)

池化层在接收复杂输入的基础上,将复杂输入分成若干cell,提取成简单的cell,目标简单的旋转或平移,经过几层的池化提取的信息就比较相似了,通过不断简化输入聚合像素值,模型内部重复的池化操作,可让模型对目标有内在不变性

这里写图片描述

关于池化层存在的问题,在前面的blog里面已经讲了很多遍。池化操作是破坏性的,丢弃了75%的输入信息,虽然模型增加了健壮性但也丢弃了细节信息。

除了这一问题,池化层受限于自身架构上的限制,只能获取局部信息和固定的变换结构,这只有很小的感受野,通过不断加深网络层次,获得更大的感受野。我们不能随意的调整池化层的大小,因为这会急剧的降低特征映射的空间分辨率。



Spatial Transformer Networks

针对上面提到的问题,DeepMind的STN:Spatial Transformer Networks工作开创性的在CNN中引入通过加入空间变换层达到仿射变换目的。STN模型优点在于:

  • 模块化:STN可以很方便的集成到现存的CNN架构中
  • 可微分:STN可通过BP优化参数,支持end-2-end训练
  • 动态性:STN相比于池化等操作,可动态的将仿射变换应用于采样点上

仿射变换示意图:

这里写图片描述

  • 图(a)是常规的恒等变换,在输入U上采样网格,得到输出V
  • 图(b)是采样是做了仿射变换,输入U上采样使用的仿射变换后的网格,最终得到U

由Part1可知,当Tθ是2D方式变换,坐标变换为:

(xisyis)=Tθ(Gi)=Aθ(xityit1)=[θ11θ12θ13θ21θ22θ23](xityit1)

(xis,yis)是源特征映射(Source)中采样点,(xit,yit)是仿射变换目标输出。Aθ是仿射变换矩阵。注意到上式是从目标通过变换矩阵就计算出源输入。

仿射变换层

STN中核心的Spatial Transformer变换如下:

这里写图片描述

整个Spatial Transformer分为以下几个部分:

  • Localisation net,用于获取仿射矩阵,即各个θ值。
  • Grid generator,用于生成采样坐标
  • Sampler,实际对U的采样

注意一下:在Part1中,讲解使用双线性插值实现仿射变换,这是仿射变换层的核心基础。

Localisation net

localisation network的目标是在输入特征映射上应用卷积或FC层,获取到仿射变换矩阵参数θ,结构如下:

  • input: 特征映射U,shape为(H,W,C)
  • output: 仿射矩阵参数θ,shape为(6,)
  • architecture: 全连接层或卷积层

Localisation net在实际训练过程中不断学习变换的参数。

Grid generator

Grid generator干的事就是输出采样网格,即在输入中采样的点生成期望的转换输出。

普通的变换:

这里写图片描述

输入和输出的采样点是相同的。

仿射变换:

这里写图片描述

我们想通过目标采样网格经过仿射变换获取到实际在输入上采样网格点。

具体来讲,Grid generator和Part1中讲的使用双线性插值实现仿射变换过程类似:

  • 首先,创建和输入U等同空间大小(H,W)的棋盘网格((xt,yt)覆盖了所有输入的点,这表示了输出特征中的目标坐标点)
  • 因为我们要在上述创建的棋盘网格做仿射变换,故对上述网格坐标做向量化处理,即[xt yt 1]形式
  • 将仿射矩阵的参数reshape成2×3的矩阵形式,并使用下面计算得到我们期望的采样点

(xisyis)=Tθ(Gi)=Aθ(xityit1)=[θ11θ12θ13θ21θ22θ23](xityit1)

注意到这里的仿射变换:

这里写图片描述

如图所示:实际上是通过目标网格找到实际采样网格~

[xs ys]表示我们在输入特征上应该采样的点。想想,上一节我们讲过,如果采样的点是分数该怎么办,是的,使用双线性插值来搞定~

Sampler

注意到,仿射矩阵参数是可学习的,这要求采样是可微分的,实际中用双线性插值实现仿射变换,刚好双线性插值是可微分的,这刚好符合要求

任何采样核:

对于(xis,yis)对应的仿射变换Tθ(G)在输入U上应用采样核得到输出V:

Vci=nHmWUnmck(xism;Φx)k(yisn;Φy)  i[1..HW]  c[1..C]

这里写图片描述

其中Φx,Φy是采样核k()的参数,Unmc在通道c上(n,m)坐标的值。Vic是像素i的输出值,具体指的是通道c中位置(xit,yit)。注意这样的采样应用于每个通道上。

integer sampling

理论上,任何采样核都可以使用,只要可微分即可,即可通过xis,yis定义出梯度,例如使用integer sampling kernel:

Vci=nHmWUnmcδ(xis+0.5m)δ(xis+0.5n)

其中x+0.5 表示x最近的整数值,δ()是Kronecker delta function。这样采样核等同于复制(xis,yis)最近的像素值,直接输出为(xit,yit)

双线性采样

对比双线性插值:

f(P)=f(Q11)(x2x)(y2y)+f(Q21)(xx1)(y2y)+f(Q12)(x2x)(yy1)+f(Q22)(xx1)(yy1)

在STN中使用双线性采样核的表达式总结如下:

Vci=nHmWUnmcmax(0,1|xism|)max(0,1|yisn|)

这个表达式和双线性插值是一样的。 同样是遍历周围的四个点(上述公式是遍历所有点,实际上是处理点周围的点),分母省去了,将一堆旁边项目使用max函数代替了。

双线性插值的反向传播

这样的双线性插值函数是支持反向传播,我们看看上述双线性采样的偏导数:

(1)VicUnmc=nHmWmax(0,1|xism|)max(0,1|yisn|)

(2)Vicxis=nHmWUnmcmax(0,1|yisn|)={0if |mxis|11if mxis1if m<xis

同理Vicyis类似。

双线性插值支持微分操作,允许梯度流回传到feature map上(2),并且回传到采样坐标上(1),这样就能调整学习到的变换参数θ,因为localisation netxisθyisθ可帮助修正矩阵。示意图如下:

这里写图片描述

双线性采样实现了整个BP,注意上述的公式是迭代所有输入位置,实际操作是指看周围的像素点。

在输入U上双线性采样得到shape为(H,W,C)的输出V,这意味着我们可以指定shape大小,达到上采样或下采样的目的。我们的设计不受限于双线性采样,其他的采样核也可以使用,重要的是支持可微分,这样才能反向传播训练localisation net

空间变换案例

Distorted MNIST

在Disturted MNIST上使用空间变换得到的结果:

这里写图片描述

可以看到这是如何精准的学习到健壮的分类模型,通过放缩和消除背景影响,定位的关键信息,再做标准化操作。

German Traffic Sign Recognition Benchmark (GTSRB) dataset

这里写图片描述

这里写图片描述

可以看到空间变换集中于关键信息上,移除了背景信息。

参考资料

Spatial Transformer Networks


常见的仿射变换

  • 平移变换:

    [10θ1301θ23](xityit1)=(xit+θ13yit+θ23)

  • 缩放变换:

    [θ11000θ120](xityit1)=(θ11xitθ12yit)

  • 旋转操作:

    [cos(α)sin(α)0sin(α)cos(α)0](xityit1)=(cos(α)xit+sin(α)yitsin(α)xit+cos(α)yit)

posted @ 2019-04-12 16:42  Le1B_o  阅读(400)  评论(0编辑  收藏  举报