caffe学习(3):SVHN on caffe

caffe学习(3):SVHN on caffe

这几天因为做实验和学习Tornado的缘故,一直没时间把上次没完成的工作做完,今天补上。今天提供The Street View House Numbers即SVHN数据集在caffe上训练的过程。

一.数据准备

SVHN是一个真实世界的街道门牌号数字识别数据集.The Street View House Numbers (SVHN) Dataset,我们可以从这里下载数据,为方便转换,我们下载train_32x32.mat和test_32x32.mat,.mat文件中包含两个变量,X是一个4D的矩阵,维度是(32,32,3,n),n是数据个数,y是label变量,接下来我们先使用一段script看一下前十张图:

[python] view plain copy
 
  1. import scipy.io as sio  
  2. import matplotlib.pyplot as plt  
  3.   
  4. print 'Loading Matlab data.'  
  5. mat=sio.loadmat('train_32x32.mat')  
  6. data=mat['X']  
  7. label=mat['y']  
  8. for i in range(10):  
  9.     plt.subplot(2,5,i+1)  
  10.     plt.title(label[i][0])  
  11.     plt.imshow(data[...,i])  
  12.     plt.axis('off')  
  13. plt.show()  

可以看出,.mat文件中的数字是已经被crop出来的单个数字,接下来使用另一个script将其转换为lmdb数据:

[python] view plain copy
 
  1. import numpy as np  
  2. import caffe  
  3. import lmdb  
  4. import scipy.io as sio  
  5. import random  
  6. from caffe.proto import caffe_pb2  
  7.   
  8. def main():  
  9.     train=sio.loadmat('train_32x32.mat')  
  10.     test=sio.loadmat('test_32x32.mat')  
  11.   
  12.     train_data=train['X']  
  13.     train_label=train['y']  
  14.     test_data=test['X']  
  15.     test_label=test['y']  
  16.   
  17.     train_data = np.swapaxes(train_data, 0, 3)  
  18.     train_data = np.swapaxes(train_data, 1, 2)  
  19.     train_data = np.swapaxes(train_data, 2, 3)  
  20.   
  21.     test_data = np.swapaxes(test_data, 0, 3)  
  22.     test_data = np.swapaxes(test_data, 1, 2)  
  23.     test_data = np.swapaxes(test_data, 2, 3)  
  24.   
  25.     N=train_label.shape[0]  
  26.     map_size=train_data.nbytes*10  
  27.     env=lmdb.open('svhn_train_lmdb',map_size=map_size)  
  28.     txn=env.begin(write=True)  
  29.   
  30. #shuffle the training data  
  31.     r=list(range(N))  
  32.     random.shuffle(r)  
  33.   
  34.     count=0  
  35.     for i in r:  
  36. datum=caffe_pb2.Datum()  
  37.         label=int(train_label[i][0])  
  38. if label==10:  
  39.             label=0  
  40.         datum=caffe.io.array_to_datum(train_data[i],label)  
  41.         str_id='{:08}'.format(count)  
  42.         txn.put(str_id,datum.SerializeToString())  
  43.   
  44.         count += 1  
  45.         if count % 1000 == 0:  
  46. print('already handled with {} pictures'.format(count))  
  47.             txn.commit()  
  48.             txn = env.begin(write=True)  
  49.   
  50.     txn.commit()  
  51.     env.close()  
  52.   
  53.     map_size = test_data.nbytes * 10  
  54.     env = lmdb.open('svhn_test_lmdb', map_size=map_size)  
  55.     txn = env.begin(write=True)  
  56.     count = 0  
  57.     for i in range(test_label.shape[0]):  
  58. datum = caffe_pb2.Datum()  
  59.         label = int(test_label[i][0])  
  60. if label == 10:  
  61.             label = 0  
  62.         datum = caffe.io.array_to_datum(test_data[i], label)  
  63.         str_id = '{:08}'.format(count)  
  64.         txn.put(str_id, datum.SerializeToString())  
  65.   
  66.         count += 1  
  67.         if count % 1000 == 0:  
  68. print('already handled with {} pictures'.format(count))  
  69.             txn.commit()  
  70.             txn = env.begin(write=True)  
  71.   
  72.     txn.commit()  
  73.     env.close()  
  74.   
  75. if __name__=='__main__':  
  76.     main()  

这样就可以得到svhn_train_lmdb和svhn_test_lmdb了

二.Data Pre-processing

SVHN比较简单,我们不做任何data augmentation操作,只通过上篇文章的script计算出其图像均值:

[python] view plain copy
 
  1. import caffe  
  2. import lmdb  
  3. import numpy as np  
  4. from caffe.proto import caffe_pb2  
  5. import time  
  6.   
  7. lmdb_env=lmdb.open('svhn_train_lmdb')  
  8. lmdb_txn=lmdb_env.begin()  
  9. lmdb_cursor=lmdb_txn.cursor()  
  10. datum=caffe_pb2.Datum()  
  11.   
  12. N=0  
  13. mean = np.zeros((1, 3, 32, 32))  
  14. beginTime = time.time()  
  15. for key,value in lmdb_cursor:  
  16.     datum.ParseFromString(value)  
  17.     data=caffe.io.datum_to_array(datum)  
  18.     image=data.transpose(1,2,0)  
  19.     mean[0,0] += image[:, :, 0]  
  20.     mean[0,1] += image[:, :, 1]  
  21.     mean[0,2] += image[:, :, 2]  
  22.     N+=1  
  23.     if N % 1000 == 0:  
  24.         elapsed = time.time() - beginTime  
  25. print("Processed {} images in {:.2f} seconds. "  
  26.               "{:.2f} images/second.".format(N, elapsed,  
  27.                                              N / elapsed))  
  28. mean[0]/=N  
  29. blob = caffe.io.array_to_blobproto(mean)  
  30. with open('mean.binaryproto', 'wb') as f:  
  31.     f.write(blob.SerializeToString())  
  32.   
  33. lmdb_env.close()  

三.实验

这里我们采用caffe自带的cifar_full模型进行训练:

[plain] view plain copy
 
  1. caffe train -solver=solver.prototxt -gpu 0  

最后得到的model的准确率为94.03%,效果还是很好的

四.总结

经过这两篇文章,可以看出,对于一般的数据集,如果要在caffe中训练的话,一般有以下几步:

1.data->lmdb:将数据转换为lmdb数据,其实caffe也支持很多其他格式的输入,如IMAGEDATA,HDF5DATA,但经过实验,这些数据消耗的大量io操作会大大加剧训练的时间

2.data augmentation:常见的几种数据加强方式均在上文cifar100中有所阐释

3.data pre-processing:对于图像数据来说,最常见的数据预处理就是减去图像的均值

4.model designing:最后一步自然是设计模型,进行训练了

到这里对caffe训练过程已经是非常熟悉了,下一步让我们深入一点,看一下caffe的源码结构和实现细节,敬请期待!

PS:文中的Script和训练配置文件均在github上:

posted @ 2017-07-06 15:33  菜鸡一枚  阅读(406)  评论(0)    收藏  举报