天池街景字符识别总结

天池街景字符识别总结

主要思想是通过检测的方法识别图片中字符,检测模型用的是CenterNet目标检测模型,骨干网络主要使用的是resnet-50,和dlav0_34

@mhxin 同学关于数据增强和整个实现的流程讲的很好,特别是一些数据增强的方法和比赛的技巧,强烈建议大家看看

@xingyizhou 这位是CenterNet作者,源代码和论文都可以在github上找到

代码主要是CenterNet,只要将数据格式转成COCO格式或者VOC格式就可以,感兴趣的同学可以自己看一下原始代码和论文,这里我主要介绍一下我所用的几个trick。

数据增强

以下是在训练街景字符识别过程中所采用的数据增强方法,分图像颜色变换和图像尺度变换,图像尺度变换用的是仿射变换,保持原始图像比例,天池的街景数据中训练集的数据有30000张,测试集有10000张,总体来说数据偏少,这类问题做好数据增强很重要,可以在测试模型的时候发现模型对哪一类图片识别效果不好,再针对性的去做数据增强,后期80%的时间花在数据增强上。下面是我采用的几种数据增强方法,实际使用过程中是按一定概率组合使用的

  • 旋转
  • 平移
  • 翻转
  • 随机裁剪
  • 随机遮挡
  • 随机区域亮度变化
  • 图像饱和度增强
  • 图像亮度和对比度增强

数据增强用两种方式离线数据增强在线数据增强,离线数据增强是值先增强训练数据扩大训练数据集,再进行训练,在线增强指在训练生成数据时按一定概率对数据进行数据增强,这样每一个batch的数据都不一样,需要注意的是最好设置固定的随机种子。两个方法各有利弊,如果训练资源充足的情况下建议使用在线数据增强,省去了大量数据集反复操作的麻烦

模型结构

enter image description here

网络输出两个head(也可以再加一个offset),其中hm有10个通道分别预测0~9在图像中的位置,wh预测目标框的宽和高。骨干网络可以用resnet系列,如果要跑在移动端上可以使用mobilenet等轻量化网络,以前搭过ctdet-mobilenet-v2,在移动端做目标检测效果还可以。

为什么用CenterNet来做字符检测?

  • anchor free 的目标检测模型,要调的参数较少
  • point base 对小物体检测较好
  • 可以用faster rcnn做

为什么不用OCR的方式做?

  • 字符类别少,且排布比较规则
  • 训练数据集较少,不适合OCR类型的训练
  • 试了一下CRNN效果不是很好
  • 自己对检测熟悉一点

一些技巧

下面是各个方面的一些提分技巧和训练技巧,比较零碎,也是类似竞赛用的比较多的一些方法。

训练策略

使用cycle learning rate策略训练,因为训练集总共只有40000张,还是比较少的,很容易出现过拟合,使用这种学习率的策略能够很好的避免过拟合,之前支付宝一个团队拿下kaggle比赛的冠军用的就是这个策略,实现起来也不难。

enter image description here

TTA(test time augmentation)

在跑测试集结果的时候,可以使用tta,对测试图片n次随机数据增强,n*tta投票取最终结果,这种方法能够提高0.2左右,但是很耗时,TTA是个非常重要的技巧,方式很多,实现起来也不难

FP16训练加速

对于跟我一样的一些显卡资源比较少的小伙伴,没有多块卡,训练调参会很慢,这里推荐大家使用NVIDIA的Apex库,Apex支持一下几种模式:

  • O0:纯FP32训练,可以作为accuracy的baseline
  • O1:混合精度训练(推荐使用),根据黑白名单自动决定使用FP16(GEMM, 卷积)还是FP32(Softmax)进行计算
  • O2:“几乎FP16”混合精度训练,不存在黑白名单,除了Batch norm,几乎都是用FP16计算
  • O3:纯FP16训练,很不稳定,但是可以作为speed的baseline

我们这里只是需要减少内存消耗,选择O1,基本上不会对进度造成影响,同时内存消耗至少减半,可以较大batch_size训练(只支持2080TI,或者V100,P100),关于混合精度训练推荐看浅谈混合精度训练这篇博客,这里不做详细展开了,只需要知道这种方法能够大大减少训练内存就行了,我在代码里面也增加这个opt。

模型集成

  • 可以使用同一个模型的不同epoch来测试结果,综合取最终结果
  • 训练不同模型如resnet-18和dla-34等多个模型,综合测试结果

requirements

apex 0.9.10.dev0 h5py 2.10.0 opencv-python 4.2.0.34 pandas 1.0.5 Pillow 7.1.2 progress 1.5 pycocotools 2.0.1 torch 1.5.1 torchvision 0.6.1 tqdm 4.46.1 urllib3 1.25.9

测试效果

enter image description here

enter image description here
enter image description here
enter image description here

总结

上面基本介绍了我用的一些技巧,我的观点是选一个好的检测模型,基本能达到0.92+的结果,后续就要在数据增强和一些技巧上下功夫了,大家用的模型都差不多,调参提升的空间不是很大,推荐使用faster Rcnn或者CenterNet,faster Rcnn在训练和调参上可能稍微麻烦一些。之前做过一些移动端目标检测的项目,如果有这方面的需求可以联系我。最后感谢感谢天池官方组织这次比赛,群里大神们讨论也都很积极,祝大赛越办越好。

posted @ 2020-09-13 23:01  别再闹了  阅读(654)  评论(0)    收藏  举报