梳理caffe代码data_transformer(十二)

梳理caffe代码data_transformer(十二)

data_transformer详细注释看头文件和实现部分:

头文件:

 

[cpp] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. /////////////////TransformationParameter的caffe消息定义  
  2. /* 
  3. // Message that stores parameters used to apply transformation 
  4. // to the data layer's data 
  5. message TransformationParameter { 
  6.   // For data pre-processing, we can do simple scaling and subtracting the 
  7.   // data mean, if provided. Note that the mean subtraction is always carried 
  8.   // out before scaling. 
  9.   optional float scale = 1 [default = 1]; 
  10.   // Specify if we want to randomly mirror data. 
  11.   optional bool mirror = 2 [default = false]; 
  12.   // Specify if we would like to randomly crop an image. 
  13.   optional uint32 crop_size = 3 [default = 0]; 
  14.   // mean_file and mean_value cannot be specified at the same time 
  15.   optional string mean_file = 4; 
  16.   // if specified can be repeated once (would substract it from all the channels) 
  17.   // or can be repeated the same number of times as channels 
  18.   // (would subtract them from the corresponding channel) 
  19.   repeated float mean_value = 5; 
  20.   // Force the decoded image to have 3 color channels. 
  21.   optional bool force_color = 6 [default = false]; 
  22.   // Force the decoded image to have 1 color channels. 
  23.   optional bool force_gray = 7 [default = false]; 
  24. */  
  25. /* 
  26. DataTransformer类主要负责对数据进行预处理, 比如减去均值、进行crop,镜像,强制设置为彩色强制设置为灰度图像以及像素值的缩放,此外该类还将Datum、const vector<Datum>、cv::Mat&、vector<cv::Mat> 、Blob<Dtype>*类型的数据变换到目标大小的blob。负责对上述类型的数据推断其shape。 
  27. */  
  28. #ifndef CAFFE_DATA_TRANSFORMER_HPP  
  29. #define CAFFE_DATA_TRANSFORMER_HPP  
  30. #include <vector>  
  31. #include "caffe/blob.hpp"  
  32. #include "caffe/common.hpp"  
  33. #include "caffe/proto/caffe.pb.h"  
  34.   
  35. namespace caffe {  
  36.   
  37. /** 
  38.  * @brief Applies common transformations to the input data, such as 
  39.  * scaling, mirroring, substracting the image mean... 
  40.  */  
  41. template <typename Dtype>  
  42. class DataTransformer {  
  43.  public:  
  44.   explicit DataTransformer(const TransformationParameter& param, Phase phase);  
  45.   virtual ~DataTransformer() {}  
  46.   
  47.   /** 
  48.    * @brief Initialize the Random number generations if needed by the 
  49.    *    transformation. 
  50.    */  
  51. // 初始化随机数生成器,因为在对数据进行变换的时候有可能用到,比如说打乱数据的输入顺序  
  52.   void InitRand();  
  53.   
  54.   /** 
  55.    * @brief Applies the transformation defined in the data layer's 
  56.    * transform_param block to the data. 
  57.    * 
  58.    * @param datum 
  59.    *    Datum containing the data to be transformed. 
  60.    * @param transformed_blob 
  61.    *    This is destination blob. It can be part of top blob's data if 
  62.    *    set_cpu_data() is used. See data_layer.cpp for an example. 
  63.    */  
  64. // 对Datum的数据进行变换,放入到transformed_blob中  
  65.   void Transform(const Datum& datum, Blob<Dtype>* transformed_blob);  
  66.   
  67.   /** 
  68.    * @brief Applies the transformation defined in the data layer's 
  69.    * transform_param block to a vector of Datum. 
  70.    * 
  71.    * @param datum_vector 
  72.    *    A vector of Datum containing the data to be transformed. 
  73.    * @param transformed_blob 
  74.    *    This is destination blob. It can be part of top blob's data if 
  75.    *    set_cpu_data() is used. See memory_layer.cpp for an example. 
  76.    */  
  77. // 对Datum容器的数据进行变换翻入到transformed_blob  
  78.   void Transform(const vector<Datum> & datum_vector,  
  79.                 Blob<Dtype>* transformed_blob);  
  80.   
  81. #ifdef USE_OPENCV  
  82.   /** 
  83.    * @brief Applies the transformation defined in the data layer's 
  84.    * transform_param block to a vector of Mat. 
  85.    * 
  86.    * @param mat_vector 
  87.    *    A vector of Mat containing the data to be transformed. 
  88.    * @param transformed_blob 
  89.    *    This is destination blob. It can be part of top blob's data if 
  90.    *    set_cpu_data() is used. See memory_layer.cpp for an example. 
  91.    */  
  92. // 如果定义OpenCV还可能对mat容器数据类型的数据进行变换  
  93.   void Transform(const vector<cv::Mat> & mat_vector,  
  94.                 Blob<Dtype>* transformed_blob);  
  95.   
  96.   /** 
  97.    * @brief Applies the transformation defined in the data layer's 
  98.    * transform_param block to a cv::Mat 
  99.    * 
  100.    * @param cv_img 
  101.    *    cv::Mat containing the data to be transformed. 
  102.    * @param transformed_blob 
  103.    *    This is destination blob. It can be part of top blob's data if 
  104.    *    set_cpu_data() is used. See image_data_layer.cpp for an example. 
  105.    */  
  106. // 将opencv读取的单个图像转换到blob中去  
  107.   void Transform(const cv::Mat& cv_img, Blob<Dtype>* transformed_blob);  
  108. #endif  // USE_OPENCV  
  109.   
  110.   /** 
  111.    * @brief Applies the same transformation defined in the data layer's 
  112.    * transform_param block to all the num images in a input_blob. 
  113.    * 
  114.    * @param input_blob 
  115.    *    A Blob containing the data to be transformed. It applies the same 
  116.    *    transformation to all the num images in the blob. 
  117.    * @param transformed_blob 
  118.    *    This is destination blob, it will contain as many images as the 
  119.    *    input blob. It can be part of top blob's data. 
  120.    */  
  121. // 将输入的blob进行变换,可能是取出blob的中的一部分数据到新的blob  
  122.   void Transform(Blob<Dtype>* input_blob, Blob<Dtype>* transformed_blob);  
  123.   
  124.   /** 
  125.    * @brief Infers the shape of transformed_blob will have when 
  126.    *    the transformation is applied to the data. 
  127.    * 
  128.    * @param datum 
  129.    *    Datum containing the data to be transformed. 
  130.    */  
  131. // 根据Datum获取blob的形状  
  132.   vector<int> InferBlobShape(const Datum& datum);  
  133.   /** 
  134.    * @brief Infers the shape of transformed_blob will have when 
  135.    *    the transformation is applied to the data. 
  136.    *    It uses the first element to infer the shape of the blob. 
  137.    * 
  138.    * @param datum_vector 
  139.    *    A vector of Datum containing the data to be transformed. 
  140.    */  
  141. // 根据Datum容器获取blob的形状  
  142.   vector<int> InferBlobShape(const vector<Datum> & datum_vector);  
  143.   /** 
  144.    * @brief Infers the shape of transformed_blob will have when 
  145.    *    the transformation is applied to the data. 
  146.    *    It uses the first element to infer the shape of the blob. 
  147.    * 
  148.    * @param mat_vector 
  149.    *    A vector of Mat containing the data to be transformed. 
  150.    */  
  151. #ifdef USE_OPENCV  
  152. // 根据Mat容器获取blob的形状  
  153.   vector<int> InferBlobShape(const vector<cv::Mat> & mat_vector);  
  154.   /** 
  155.    * @brief Infers the shape of transformed_blob will have when 
  156.    *    the transformation is applied to the data. 
  157.    * 
  158.    * @param cv_img 
  159.    *    cv::Mat containing the data to be transformed. 
  160.    */  
  161. // 根据Mat获取blob的形状  
  162.   vector<int> InferBlobShape(const cv::Mat& cv_img);  
  163. #endif  // USE_OPENCV  
  164.   
  165.  protected:  
  166.    /** 
  167.    * @brief Generates a random integer from Uniform({0, 1, ..., n-1}). 
  168.    * 
  169.    * @param n 
  170.    *    The upperbound (exclusive) value of the random number. 
  171.    * @return 
  172.    *    A uniformly random integer value from ({0, 1, ..., n-1}). 
  173.    */  
  174. // 生成从0到n-1的服从均匀分布的随机数,要求继承他的都必须实现如何生成随机数  
  175.   virtual int Rand(int n);  
  176. // 将给定的Datum进行转换  
  177.   void Transform(const Datum& datum, Dtype* transformed_data);  
  178.   
  179.   // 变换所使用的参数  
  180.   TransformationParameter param_;  
  181.   // 随机数生成器的种子  
  182.   shared_ptr<Caffe::RNG> rng_;  
  183.   // 是训练还是测试?  
  184.   Phase phase_;  
  185.   // 数据均值 blob  
  186.   Blob<Dtype> data_mean_;  
  187.   // 数据均值blob的容器  
  188.   vector<Dtype> mean_values_;  
  189. };  
  190.   
  191. }  // namespace caffe  
  192.   
  193. #endif  // CAFFE_DATA_TRANSFORMER_HPP_  

实现:

 

 

[cpp] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. //DataTransformer需要输入的是blob,所以需要看一下里面的参数,因此再把这一部分内容的proto贴出来,这是新版的caffe  
  2. /* 
  3. // Specifies the shape (dimensions) of a Blob. 
  4. message BlobShape { 
  5.   repeated int64 dim = 1 [packed = true]; 
  6.  
  7. message BlobProto { 
  8.   optional BlobShape shape = 7; 
  9.   repeated float data = 5 [packed = true]; 
  10.   repeated float diff = 6 [packed = true]; 
  11.   repeated double double_data = 8 [packed = true]; 
  12.   repeated double double_diff = 9 [packed = true]; 
  13.  
  14.   // 4D dimensions -- deprecated.  Use "shape" instead. 
  15.   optional int32 num = 1 [default = 0]; 
  16.   optional int32 channels = 2 [default = 0]; 
  17.   optional int32 height = 3 [default = 0]; 
  18.   optional int32 width = 4 [default = 0]; 
  19. */  
  20. /////////////////TransformationParameter的caffe消息定义  
  21. /* 
  22. // Message that stores parameters used to apply transformation 
  23. // to the data layer's data 
  24. message TransformationParameter { 
  25.   // For data pre-processing, we can do simple scaling and subtracting the 
  26.   // data mean, if provided. Note that the mean subtraction is always carried 
  27.   // out before scaling. 
  28.   optional float scale = 1 [default = 1]; 
  29.   // Specify if we want to randomly mirror data. 
  30.   optional bool mirror = 2 [default = false]; 
  31.   // Specify if we would like to randomly crop an image. 
  32.   optional uint32 crop_size = 3 [default = 0]; 
  33.   // mean_file and mean_value cannot be specified at the same time 
  34.   optional string mean_file = 4; 
  35.   // if specified can be repeated once (would substract it from all the channels) 
  36.   // or can be repeated the same number of times as channels 
  37.   // (would subtract them from the corresponding channel) 
  38.   repeated float mean_value = 5; 
  39.   // Force the decoded image to have 3 color channels. 
  40.   optional bool force_color = 6 [default = false]; 
  41.   // Force the decoded image to have 1 color channels. 
  42.   optional bool force_gray = 7 [default = false]; 
  43. */  
  44. #ifdef USE_OPENCV  
  45. #include <opencv2/core/core.hpp>  
  46. #endif  // USE_OPENCV  
  47.   
  48. #include <string>  
  49. #include <vector>  
  50.   
  51. #include "caffe/data_transformer.hpp"  
  52. #include "caffe/util/io.hpp"  
  53. #include "caffe/util/math_functions.hpp"  
  54. #include "caffe/util/rng.hpp"  
  55.   
  56. namespace caffe {  
  57. // 构造函数  
  58. template<typename Dtype>  
  59. DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param,  
  60.     Phase phase)  
  61.     : param_(param), phase_(phase) {  
  62.   // check if we want to use mean_file  
  63.   // 判断是否有平均值文件  
  64.   if (param_.has_mean_file()) {  
  65.     CHECK_EQ(param_.mean_value_size(), 0) <<  
  66.       "Cannot specify mean_file and mean_value at the same time";  
  67.     // 平均值文件的路径  
  68.     const string& mean_file = param.mean_file();  
  69.     if (Caffe::root_solver()) {  
  70.       LOG(INFO) << "Loading mean file from: " << mean_file;  
  71.     }  
  72.     BlobProto blob_proto;// 调用google/protobuf?? ,用于加速运算的数据接口,有时间再详细了解其应用方法   
  73. //这个函数是实现了从二进制文件中读取数据到blob_proto中,猜测函数来自第3方库的google/protobuf模块   
  74.     ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);  
  75.     data_mean_.FromProto(blob_proto);// 调用Blob类的成员函数FromRroto从BlobProto中加载数据   
  76.   }  
  77.   // check if we want to use mean_value  
  78.   if (param_.mean_value_size() > 0) {  
  79.     CHECK(param_.has_mean_file() == false) <<  
  80.       "Cannot specify mean_file and mean_value at the same time";  
  81.     for (int c = 0; c < param_.mean_value_size(); ++c) {  
  82.       mean_values_.push_back(param_.mean_value(c));//将元素param_.mean_value(c)加入到mean_values_容器的最后一位  
  83.     }  
  84.   }  
  85. }  
  86.   
  87. /*提前先描述一下数据层的Datum, 
  88. Datum数据结构,Caffe并不是把向量和矩阵直接放进数据库的,而是将数据通过caffe.proto里定义的一个datum类来封装。数据库里放的是一个个的datum序列化成的字符串。Datum的定义摘录如下: 
  89. message Datum { 
  90.   optional int32 channels = 1; 
  91.   optional int32 height = 2; 
  92.   optional int32 width = 3; 
  93.   // the actual image data, in bytes 
  94.   optional bytes data = 4; 
  95.   optional int32 label = 5; 
  96.   // Optionally, the datum could also hold float data. 
  97.   repeated float float_data = 6; 
  98.   // If true data contains an encoded image that need to be decoded 
  99.   optional bool encoded = 7 [default = false]; 
  100. 一个Datum有三个维度,channels, height,和width,可以看做是少了num维度的Blob。存放数据的地方有两个:byte_data和float_data,分别存放整数型和浮点型数据。图像数据一般是整形,放在byte_data里,特征向量一般是浮点型,放在float_data里。label存放数据的类别标签,是整数型。encoded标识数据是否需要被解码(里面有可能放的是JPEG或者PNG之类经过编码的数据)。Datum这个数据结构将数据和标签封装在一起,兼容整形和浮点型数据。经过Protobuf编译后,可以在Python和C++中都提供高效的访问。同时Protubuf还为它提供了序列化与反序列化的功能。存放进LMDB的就是Datum序列化生成的字符串。 
  101. Caffe中关于LMDB的代码有三类:生成数据集、读取数据集、生成特征向量。接下来就分别针对三者进行分析。 
  102. 生成数据集: 
  103. 生成数据集的代码在examples,随数据集提供,比如MNIST。 
  104. 首先,创建访问LMDB所需的一些变量: 
  105. MDB_env *mdb_env; 
  106. MDB_dbi mdb_dbi; 
  107. MDB_val mdb_key, mdb_data; 
  108. MDB_txn *mdb_txn; 
  109. ... 
  110. mdb_env是整个数据库环境的句柄,mdb_dbi是环境中一个数据库的句柄,mdb_key和mdb_data用来存放向数据库中输入数据的“值”。mdb_txn是数据库事物操作的句柄,”txn”是”transaction”的缩写。 
  111. 然后,创建数据库环境,创建并打开数据库: 
  112. if (db_backend == "lmdb") {  // lmdb 
  113.   LOG(INFO) << "Opening lmdb " << db_path; 
  114.   CHECK_EQ(mkdir(db_path, 0744), 0) 
  115.       << "mkdir " << db_path << "failed"; 
  116.   CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed"; 
  117.   CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS)  // 1TB 
  118.       << "mdb_env_set_mapsize failed"; 
  119.   CHECK_EQ(mdb_env_open(mdb_env, db_path, 0, 0664), MDB_SUCCESS) 
  120.       << "mdb_env_open failed"; 
  121.   CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS) 
  122.       << "mdb_txn_begin failed"; 
  123.   CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS) 
  124.       << "mdb_open failed. Does the lmdb already exist? "; 
  125. } else { 
  126.   LOG(FATAL) << "Unknown db backend " << db_backend; 
  127. mkdir(db_path, 0744)为数据库创建文件夹,如果文件夹已经存在,程序会报错退出。也就是说,程序不会覆盖已有的数据库。已有的数据库如果不要了,需要手动删除。需要注意的是,LMDB的一个环境中是可以有多个数据库的,数据库之间以名字区分。mdb_open()的第二个参数实际上就是数据库的名称(char *)。当一个环境中只有一个数据库的时候,这个参数可以给NULL。最后,为每一个图像创建Datum对象,向对象内写入数据,然后将其序列化成字符串,将字符串放入数据库中: 
  128. Datum datum; 
  129. datum.set_channels(1); 
  130. datum.set_height(rows); 
  131. datum.set_width(cols); 
  132. for (int item_id = 0; item_id < num_items; ++item_id) { 
  133.   image_file.read(pixels, rows * cols); 
  134.   label_file.read(&label, 1); 
  135.   datum.set_data(pixels, rows*cols); 
  136.   datum.set_label(label); 
  137.   snprintf(key_cstr, kMaxKeyLength, "%08d", item_id); 
  138.   datum.SerializeToString(&value); 
  139.   string keystr(key_cstr); 
  140.  
  141.   // Put in db 
  142.   if (db_backend == "lmdb") {  // lmdb 
  143.     mdb_data.mv_size = value.size(); 
  144.     mdb_data.mv_data = reinterpret_cast<void*>(&value[0]); 
  145.     mdb_key.mv_size = keystr.size(); 
  146.     mdb_key.mv_data = reinterpret_cast<void*>(&keystr[0]); 
  147.     CHECK_EQ(mdb_put(mdb_txn, mdb_dbi, &mdb_key, &mdb_data, 0), MDB_SUCCESS) 
  148.         << "mdb_put failed"; 
  149.   } else { 
  150.     LOG(FATAL) << "Unknown db backend " << db_backend; 
  151.   } 
  152.  
  153.   if (++count % 1000 == 0) { 
  154.     // Commit txn 
  155.     if (db_backend == "lmdb") {  // lmdb 
  156.       CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) 
  157.           << "mdb_txn_commit failed"; 
  158.       CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS) 
  159.           << "mdb_txn_begin failed"; 
  160.     } else { 
  161.       LOG(FATAL) << "Unknown db backend " << db_backend; 
  162.     } 
  163.   } 
  164. 放入数据的Key是图像的编号,前面补0至8位。MDB_val类型的mdb_data和mdb_key中存放的是数据来源的指针,以及数据的长度。mdb_put()函数将数据存入数据库。每隔1000个图像commit一次数据库。只有commit之后,数据才真正写入磁盘。 
  165. 读取数据集: 
  166. Caffe中读取LMDB数据集的代码是DataLayer,用在网络的最下层,提供数据。DataLayer采用顺序遍历的方式读取数据,不支持打乱数据顺序,只能随机跳过前若干个数据。 
  167. 首先,在DataLayer的DataLayerSetUp方法中,打开数据库,并获取迭代器cursor_: 
  168. db_.reset(db::GetDB(this->layer_param_.data_param().backend())); 
  169. db_->Open(this->layer_param_.data_param().source(), db::READ); 
  170. cursor_.reset(db_->NewCursor()); 
  171. 然后,在每一次的数据预取时,InternalThreadEntry()方法中,从数据库中读取字符串,反序列化为Datum对象,再从Datum对象中取出数据: 
  172. Datum datum; 
  173. datum.ParseFromString(cursor_->value()); 
  174. 其中,cursor_->value()获取序列化后的字符串。datum.ParseFromString()方法对字符串进行反序列化。 
  175. 最后,要将cursor_向前推进: 
  176. cursor_->Next(); 
  177. if (!cursor_->valid()) { 
  178.   DLOG(INFO) << "Restarting data prefetching from start." 
  179.       cursor_->SeekToFirst(); 
  180. 如果cursor->valid()返回false,说明数据库已经遍历到头,这时需要将cursor_重置回数据库开头。不支持样本随机排序应该是DataLayer的致命弱点。如果数据库的key能够统一,其实可以通过对key随机枚举的方式实现。所以caffe定义了一个随机生成器RNG。 
  181. */  
  182. template<typename Dtype>  
  183. void DataTransformer<Dtype>::Transform(const Datum& datum,  
  184.                                        Dtype* transformed_data) {  
  185.   // 参考TransformationParameter的定义  
  186.   const string& data = datum.data();  
  187.   const int datum_channels = datum.channels();//数据的channel  
  188.   const int datum_height = datum.height();//数据的行数  
  189.   const int datum_width = datum.width();// 数据的列数  
  190.   
  191.   const int crop_size = param_.crop_size();// crop大小  
  192.   const Dtype scale = param_.scale();// 缩放比例  
  193.   const bool do_mirror = param_.mirror() && Rand(2);// 该参数用于在镜像位置对数据处理  
  194.   const bool has_mean_file = param_.has_mean_file();// 是否有均值文件  
  195.   const bool has_uint8 = data.size() > 0;// 数据是否为uint8还是float类型的  
  196.   const bool has_mean_values = mean_values_.size() > 0;// 是否有每个channel的均值  
  197.   
  198.   // 检查合法性  
  199.   CHECK_GT(datum_channels, 0);  
  200.   CHECK_GE(datum_height, crop_size);  
  201.   CHECK_GE(datum_width, crop_size);  
  202.   
  203.   Dtype* mean = NULL;  
  204. /* 
  205. 前面有介绍这一部分CHECK内容,glog提供了多个便利的宏来处理特定关系的判定。具体有: 
  206. 1,判定大小关系 
  207. CHECK_EQ, CHECK_NE, CHECK_LE, CHECK_LT, CHECK_GE, CHECK_GT,使用这些宏需要注意类型一致,如果出现类型不一致的,可使用static_cast转换。 
  208. 2,判定指针是否为空 
  209. CHECK_NOTNULL(some_ptr),可用于对象初始化的时候。 
  210. 3,判定字符串是否相等 
  211. CHECK_STREQ, CHECK_STRNE, CHECK_STRCASEEQ,CHECK_STRCASENE。可进行大小写敏感或不敏感字符串来分别判定。 
  212. 4, 判定浮点是否相等或相近 
  213. CHECK_DOUBLE_EQ,CHECK_NEAR。这两个宏都需要指定一个可容忍的偏差上限。 
  214. */  
  215.   if (has_mean_file) {// 检查mean_file是否与数据的参数一致  
  216.     CHECK_EQ(datum_channels, data_mean_.channels());  
  217.     CHECK_EQ(datum_height, data_mean_.height());  
  218.     CHECK_EQ(datum_width, data_mean_.width());  
  219.     mean = data_mean_.mutable_cpu_data();  
  220.   }  
  221.   if (has_mean_values) {  
  222.     CHECK(mean_values_.size() == 1 || mean_values_.size() == datum_channels) <<  
  223.      "Specify either 1 mean_value or as many as channels: " << datum_channels;  
  224.     if (datum_channels > 1 && mean_values_.size() == 1) {  
  225.       // Replicate the mean_value for simplicity  
  226.       for (int c = 1; c < datum_channels; ++c) {  
  227.         mean_values_.push_back(mean_values_[0]);  
  228.       }  
  229.     }  
  230.   }  
  231.   
  232.   int height = datum_height;  
  233.   int width = datum_width;  
  234.   
  235.   // 根据是否需要crop来生成h_off和w_off  
  236.   int h_off = 0;  
  237.   int w_off = 0;  
  238.   if (crop_size) {// 如果crop_size不为0  
  239.     height = crop_size;  
  240.     width = crop_size;  
  241.     // We only do random crop when we do training.  
  242.     // 在训练的时候随机crop图像块,这里需要自己实现Rand这个函数来确定是如何随机的  
  243.     if (phase_ == TRAIN) {  
  244.       h_off = Rand(datum_height - crop_size + 1);// 产生从0到datum_height - crop_size的随机数  
  245.       w_off = Rand(datum_width - crop_size + 1);  
  246.     } else {// 测试的时候不用随机,取图像的中心  
  247.       h_off = (datum_height - crop_size) / 2;  
  248.       w_off = (datum_width - crop_size) / 2;  
  249.     }  
  250.   }  
  251.   
  252.   // 对数据进行变换,主要是将原来的像素值减去均值,然后乘以scale这么一个操作  
  253.   // 如果需要crop则最终转换的Blob的大小即为crop*crop  
  254.   // 如果不是,则最终的Blob大小即为datum_height*datum_width  
  255.   Dtype datum_element;  
  256.   int top_index, data_index;  
  257.   for (int c = 0; c < datum_channels; ++c) {  
  258.     for (int h = 0; h < height; ++h) {  
  259.       for (int w = 0; w < width; ++w) {  
  260.         data_index = (c * datum_height + h_off + h) * datum_width + w_off + w;// 获取数据的索引,我不是很明白怎么计算的?  
  261.         if (do_mirror) {// 是否需要在镜像位置转换  
  262.           top_index = (c * height + h) * width + (width - 1 - w);//在宽这个坐标上做文章,来实现镜像  
  263.         } else {//  
  264.           top_index = (c * height + h) * width + w;  
  265.         }  
  266.         if (has_uint8) {// 数据如果是uint8则进行转换  
  267.           datum_element =  
  268.             static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));  
  269.         } else {// 否则就是float  
  270.           datum_element = datum.float_data(data_index);  
  271.         }  
  272.         if (has_mean_file) {// 如果有mean_file,则原来的像素值减去均值,然后乘以scale  
  273.           transformed_data[top_index] =  
  274.             (datum_element - mean[data_index]) * scale;  
  275.         } else {  
  276.           if (has_mean_values) {// 否则减去该channel的均值(每个channel有其一个均值),然后乘以scale  
  277.             transformed_data[top_index] =  
  278.               (datum_element - mean_values_[c]) * scale;  
  279.           } else {// 否则如果没有均值那么就直接乘以scale即可  
  280.             transformed_data[top_index] = datum_element * scale;  
  281.           }  
  282.         }  
  283.       }  
  284.     }  
  285.   }  
  286. }  
  287.   
  288.   
  289. template<typename Dtype>  
  290. void DataTransformer<Dtype>::Transform(const Datum& datum,  
  291.                                        Blob<Dtype>* transformed_blob) {  
  292.   // If datum is encoded, decoded and transform the cv::image.  
  293.   if (datum.encoded()) {//  检查是否编码了,如果是则解码  
  294. #ifdef USE_OPENCV  
  295.     // 先检查是不是两个属性都设置, 如果是则说明参数设置有误  
  296.     CHECK(!(param_.force_color() && param_.force_gray()))  
  297.         << "cannot set both force_color and force_gray";  
  298.     cv::Mat cv_img;  
  299.     if (param_.force_color() || param_.force_gray()) {  
  300.         // 如果强制彩色或者强制灰度图像一个成立则使用DecodeDatumToCVMat解码  
  301.     // If force_color then decode in color otherwise decode in gray.  
  302.       cv_img = DecodeDatumToCVMat(datum, param_.force_color());  
  303.     } else {// 否则使用DecodeDatumToCVMatNative解码  
  304.       cv_img = DecodeDatumToCVMatNative(datum);  
  305.     }  
  306.     // Transform the cv::image into blob.  
  307.     // 变换  
  308.     return Transform(cv_img, transformed_blob);  
  309. #else  
  310.     LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";  
  311. #endif  // USE_OPENCV  
  312.   } else {// 如果没有编码则,检查force_color和force_gray是否设置,如果设置则不合法,因为该选项只适合于编码后的数据  
  313.     if (param_.force_color() || param_.force_gray()) {  
  314.       LOG(ERROR) << "force_color and force_gray only for encoded datum";  
  315.     }  
  316.   }  
  317.   
  318.   const int crop_size = param_.crop_size();  
  319.   const int datum_channels = datum.channels();  
  320.   const int datum_height = datum.height();  
  321.   const int datum_width = datum.width();  
  322.   
  323.   // Check dimensions.  
  324.   const int channels = transformed_blob->channels();  
  325.   const int height = transformed_blob->height();  
  326.   const int width = transformed_blob->width();  
  327.   const int num = transformed_blob->num();  
  328.   
  329.   CHECK_EQ(channels, datum_channels);  
  330.   CHECK_LE(height, datum_height);  
  331.   CHECK_LE(width, datum_width);  
  332.   CHECK_GE(num, 1);  
  333.   
  334.   if (crop_size) {  
  335.     CHECK_EQ(crop_size, height);  
  336.     CHECK_EQ(crop_size, width);  
  337.   } else {  
  338.     CHECK_EQ(datum_height, height);  
  339.     CHECK_EQ(datum_width, width);  
  340.   }  
  341.   // 继续变换数据  
  342.   Dtype* transformed_data = transformed_blob->mutable_cpu_data();  
  343.   Transform(datum, transformed_data);  
  344. }  
  345.   
  346. template<typename Dtype>  
  347. void DataTransformer<Dtype>::Transform(const vector<Datum> & datum_vector,  
  348.                                        Blob<Dtype>* transformed_blob) {  
  349.   const int datum_num = datum_vector.size();  
  350.   // 变换到的目标blob的形状  
  351.   const int num = transformed_blob->num();  
  352.   const int channels = transformed_blob->channels();  
  353.   const int height = transformed_blob->height();  
  354.   const int width = transformed_blob->width();  
  355.   
  356.   CHECK_GT(datum_num, 0) << "There is no datum to add";  
  357.   CHECK_LE(datum_num, num) <<  
  358.     "The size of datum_vector must be no greater than transformed_blob->num()";  
  359.   // 新建一个uni_blob,里面只有一个batch  
  360.   Blob<Dtype> uni_blob(1, channels, height, width);  
  361.   for (int item_id = 0; item_id < datum_num; ++item_id) {  
  362.     int offset = transformed_blob->offset(item_id);  
  363.     uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + offset);  
  364.     Transform(datum_vector[item_id], &uni_blob);  
  365.   }  
  366. }  
  367.   
  368. #ifdef USE_OPENCV  
  369. template<typename Dtype>  
  370. void DataTransformer<Dtype>::Transform(const vector<cv::Mat> & mat_vector,  
  371.                                        Blob<Dtype>* transformed_blob) {  
  372.   // 获取mat的参数  
  373.   const int mat_num = mat_vector.size();  
  374.   const int num = transformed_blob->num();  
  375.   const int channels = transformed_blob->channels();  
  376.   const int height = transformed_blob->height();  
  377.   const int width = transformed_blob->width();  
  378.   
  379.   CHECK_GT(mat_num, 0) << "There is no MAT to add";  
  380.   CHECK_EQ(mat_num, num) <<  
  381.     "The size of mat_vector must be equals to transformed_blob->num()";  
  382.   //  同上  
  383.   Blob<Dtype> uni_blob(1, channels, height, width);  
  384.   for (int item_id = 0; item_id < mat_num; ++item_id) {  
  385.     int offset = transformed_blob->offset(item_id);  
  386.     uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + offset);  
  387.     Transform(mat_vector[item_id], &uni_blob);  
  388.   }  
  389. }  
  390.   
  391. // 如果是图像的话,需要减去均值乘以scale,判断是不是需要做镜像处理  
  392. // 逻辑与前面类似  
  393. template<typename Dtype>  
  394. void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,  
  395.                                        Blob<Dtype>* transformed_blob) {  
  396.   const int crop_size = param_.crop_size();  
  397.   const int img_channels = cv_img.channels();  
  398.   const int img_height = cv_img.rows;  
  399.   const int img_width = cv_img.cols;  
  400.   
  401.   // Check dimensions.  
  402.   const int channels = transformed_blob->channels();  
  403.   const int height = transformed_blob->height();  
  404.   const int width = transformed_blob->width();  
  405.   const int num = transformed_blob->num();  
  406.   
  407.   CHECK_EQ(channels, img_channels);  
  408.   CHECK_LE(height, img_height);  
  409.   CHECK_LE(width, img_width);  
  410.   CHECK_GE(num, 1);  
  411.   
  412.   CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";  
  413.   
  414.   const Dtype scale = param_.scale();  
  415.   const bool do_mirror = param_.mirror() && Rand(2);  
  416.   const bool has_mean_file = param_.has_mean_file();  
  417.   const bool has_mean_values = mean_values_.size() > 0;  
  418.   
  419.   CHECK_GT(img_channels, 0);  
  420.   CHECK_GE(img_height, crop_size);  
  421.   CHECK_GE(img_width, crop_size);  
  422.   
  423.   Dtype* mean = NULL;  
  424.   if (has_mean_file) {  
  425.     CHECK_EQ(img_channels, data_mean_.channels());  
  426.     CHECK_EQ(img_height, data_mean_.height());  
  427.     CHECK_EQ(img_width, data_mean_.width());  
  428.     mean = data_mean_.mutable_cpu_data();  
  429.   }  
  430.   if (has_mean_values) {  
  431.     CHECK(mean_values_.size() == 1 || mean_values_.size() == img_channels) <<  
  432.      "Specify either 1 mean_value or as many as channels: " << img_channels;  
  433.     if (img_channels > 1 && mean_values_.size() == 1) {  
  434.       // Replicate the mean_value for simplicity  
  435.       for (int c = 1; c < img_channels; ++c) {  
  436.         mean_values_.push_back(mean_values_[0]);  
  437.       }  
  438.     }  
  439.   }  
  440.   
  441.   int h_off = 0;  
  442.   int w_off = 0;  
  443.   cv::Mat cv_cropped_img = cv_img;  
  444.   if (crop_size) {  
  445.     CHECK_EQ(crop_size, height);  
  446.     CHECK_EQ(crop_size, width);  
  447.     // We only do random crop when we do training.  
  448.     if (phase_ == TRAIN) {  
  449.       h_off = Rand(img_height - crop_size + 1);  
  450.       w_off = Rand(img_width - crop_size + 1);  
  451.     } else {  
  452.       h_off = (img_height - crop_size) / 2;  
  453.       w_off = (img_width - crop_size) / 2;  
  454.     }  
  455.     cv::Rect roi(w_off, h_off, crop_size, crop_size);  
  456.     cv_cropped_img = cv_img(roi);  
  457.   } else {  
  458.     CHECK_EQ(img_height, height);  
  459.     CHECK_EQ(img_width, width);  
  460.   }  
  461.   
  462.   CHECK(cv_cropped_img.data);  
  463.   
  464.   Dtype* transformed_data = transformed_blob->mutable_cpu_data();  
  465.   int top_index;  
  466.   for (int h = 0; h < height; ++h) {  
  467.     const uchar* ptr = cv_cropped_img.ptr<uchar>(h);  
  468.     int img_index = 0;  
  469.     for (int w = 0; w < width; ++w) {  
  470.       for (int c = 0; c < img_channels; ++c) {  
  471.         if (do_mirror) {  
  472.           top_index = (c * height + h) * width + (width - 1 - w);  
  473.         } else {  
  474.           top_index = (c * height + h) * width + w;  
  475.         }  
  476.         // int top_index = (c * height + h) * width + w;  
  477.         Dtype pixel = static_cast<Dtype>(ptr[img_index++]);  
  478.         if (has_mean_file) {  
  479.           int mean_index = (c * img_height + h_off + h) * img_width + w_off + w;  
  480.           transformed_data[top_index] =  
  481.             (pixel - mean[mean_index]) * scale;  
  482.         } else {  
  483.           if (has_mean_values) {  
  484.             transformed_data[top_index] =  
  485.               (pixel - mean_values_[c]) * scale;  
  486.           } else {  
  487.             transformed_data[top_index] = pixel * scale;  
  488.           }  
  489.         }  
  490.       }  
  491.     }  
  492.   }  
  493. }  
  494. #endif  // USE_OPENCV  
  495.   
  496. template<typename Dtype>  
  497. void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,  
  498.                                        Blob<Dtype>* transformed_blob) {  
  499.   const int crop_size = param_.crop_size();  
  500.   const int input_num = input_blob->num();  
  501.   const int input_channels = input_blob->channels();  
  502.   const int input_height = input_blob->height();  
  503.   const int input_width = input_blob->width();  
  504.   
  505.   if (transformed_blob->count() == 0) {  
  506.     // Initialize transformed_blob with the right shape.  
  507.     if (crop_size) {  
  508.       transformed_blob->Reshape(input_num, input_channels,  
  509.                                 crop_size, crop_size);  
  510.     } else {  
  511.       transformed_blob->Reshape(input_num, input_channels,  
  512.                                 input_height, input_width);  
  513.     }  
  514.   }  
  515.   
  516.   const int num = transformed_blob->num();  
  517.   const int channels = transformed_blob->channels();  
  518.   const int height = transformed_blob->height();  
  519.   const int width = transformed_blob->width();  
  520.   const int size = transformed_blob->count();  
  521.   
  522.   CHECK_LE(input_num, num);  
  523.   CHECK_EQ(input_channels, channels);  
  524.   CHECK_GE(input_height, height);  
  525.   CHECK_GE(input_width, width);  
  526.   
  527.   
  528.   const Dtype scale = param_.scale();  
  529.   const bool do_mirror = param_.mirror() && Rand(2);  
  530.   const bool has_mean_file = param_.has_mean_file();  
  531.   const bool has_mean_values = mean_values_.size() > 0;  
  532.   
  533.   int h_off = 0;  
  534.   int w_off = 0;  
  535.   if (crop_size) {  
  536.     CHECK_EQ(crop_size, height);  
  537.     CHECK_EQ(crop_size, width);  
  538.     // We only do random crop when we do training.  
  539.     if (phase_ == TRAIN) {  
  540.       h_off = Rand(input_height - crop_size + 1);  
  541.       w_off = Rand(input_width - crop_size + 1);  
  542.     } else {  
  543.       h_off = (input_height - crop_size) / 2;  
  544.       w_off = (input_width - crop_size) / 2;  
  545.     }  
  546.   } else {  
  547.     CHECK_EQ(input_height, height);  
  548.     CHECK_EQ(input_width, width);  
  549.   }  
  550.   
  551.   // 如果有均值文件则  
  552.   Dtype* input_data = input_blob->mutable_cpu_data();  
  553.   if (has_mean_file) {  
  554.     CHECK_EQ(input_channels, data_mean_.channels());  
  555.     CHECK_EQ(input_height, data_mean_.height());  
  556.     CHECK_EQ(input_width, data_mean_.width());  
  557.     for (int n = 0; n < input_num; ++n) {  
  558.       int offset = input_blob->offset(n);  
  559.       /* 
  560.          template <typename Dtype> 
  561.        void caffe_sub(const int N, const Dtype* a, const Dtype* b, Dtype* y); 
  562.        math_function中定义的caffe_sub目的是矩阵相减input_data(以offset开始的矩阵) = input_data(以offset开始的矩阵) - data_mean_ 
  563.     */  
  564.       caffe_sub(data_mean_.count(), input_data + offset,  
  565.             data_mean_.cpu_data(), input_data + offset);  
  566.     }  
  567.   }  
  568.   // 如果每个channel有均值则  
  569.   if (has_mean_values) {  
  570.     CHECK(mean_values_.size() == 1 || mean_values_.size() == input_channels) <<  
  571.      "Specify either 1 mean_value or as many as channels: " << input_channels;  
  572.     if (mean_values_.size() == 1) {  
  573.       caffe_add_scalar(input_blob->count(), -(mean_values_[0]), input_data);  
  574.     } else {  
  575.       for (int n = 0; n < input_num; ++n) {  
  576.         for (int c = 0; c < input_channels; ++c) {  
  577.           int offset = input_blob->offset(n, c);  
  578.           // 给nput_data[offset]地址开始的每一个元素加上一个-mean_values_[c]  
  579.           caffe_add_scalar(input_height * input_width, -(mean_values_[c]),  
  580.             input_data + offset);  
  581.         }  
  582.       }  
  583.     }  
  584.   }  
  585.   
  586.   // 如果啥均值都没有则直接复制  
  587.   Dtype* transformed_data = transformed_blob->mutable_cpu_data();  
  588.   
  589.   for (int n = 0; n < input_num; ++n) {  
  590.     int top_index_n = n * channels;  
  591.     int data_index_n = n * channels;  
  592.     for (int c = 0; c < channels; ++c) {  
  593.       int top_index_c = (top_index_n + c) * height;  
  594.       int data_index_c = (data_index_n + c) * input_height + h_off;  
  595.       for (int h = 0; h < height; ++h) {  
  596.         int top_index_h = (top_index_c + h) * width;  
  597.         int data_index_h = (data_index_c + h) * input_width + w_off;  
  598.         if (do_mirror) {  
  599.           int top_index_w = top_index_h + width - 1;  
  600.           for (int w = 0; w < width; ++w) {  
  601.             transformed_data[top_index_w-w] = input_data[data_index_h + w];  
  602.           }  
  603.         } else {  
  604.           for (int w = 0; w < width; ++w) {  
  605.             transformed_data[top_index_h + w] = input_data[data_index_h + w];  
  606.           }  
  607.         }  
  608.       }  
  609.     }  
  610.   }  
  611.   if (scale != Dtype(1)) {  
  612.     DLOG(INFO) << "Scale: " << scale;  
  613.     caffe_scal(size, scale, transformed_data);  
  614.   }  
  615. }  
  616.   
  617. template<typename Dtype>  
  618. vector<int> DataTransformer<Dtype>::InferBlobShape(const Datum& datum) {  
  619.   if (datum.encoded()) {  
  620. #ifdef USE_OPENCV // 如果使用OpenCV则可以用先转换为CVMat,然后在推断blob的形状  
  621.     CHECK(!(param_.force_color() && param_.force_gray()))  
  622.         << "cannot set both force_color and force_gray";  
  623.     cv::Mat cv_img;  
  624.     if (param_.force_color() || param_.force_gray()) {  
  625.     // If force_color then decode in color otherwise decode in gray.  
  626.       cv_img = DecodeDatumToCVMat(datum, param_.force_color());  
  627.     } else {  
  628.       cv_img = DecodeDatumToCVMatNative(datum);  
  629.     }  
  630.     // InferBlobShape using the cv::image.  
  631.     return InferBlobShape(cv_img);  
  632. #else  
  633.     LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";  
  634. #endif  // USE_OPENCV  
  635.   }  
  636.   
  637.   // 否则直接粗暴地从datum里面获取形状的数据  
  638.   const int crop_size = param_.crop_size();  
  639.   const int datum_channels = datum.channels();  
  640.   const int datum_height = datum.height();  
  641.   const int datum_width = datum.width();  
  642.   // Check dimensions.  
  643.   CHECK_GT(datum_channels, 0);  
  644.   CHECK_GE(datum_height, crop_size);  
  645.   CHECK_GE(datum_width, crop_size);  
  646.   // Build BlobShape.  
  647.   vector<int> shape(4);  
  648.   shape[0] = 1;  
  649.   shape[1] = datum_channels;  
  650.   shape[2] = (crop_size)? crop_size: datum_height;  
  651.   shape[3] = (crop_size)? crop_size: datum_width;  
  652.   return shape;  
  653. }  
  654.   
  655. template<typename Dtype>  
  656. vector<int> DataTransformer<Dtype>::InferBlobShape(  
  657.     const vector<Datum> & datum_vector) {  
  658.   const int num = datum_vector.size();  
  659.   CHECK_GT(num, 0) << "There is no datum to in the vector";  
  660.   // Use first datum in the vector to InferBlobShape.  
  661.   // 使用第一个来进行推断  
  662.   vector<int> shape = InferBlobShape(datum_vector[0]);  
  663.   // Adjust num to the size of the vector.  
  664.   shape[0] = num;  
  665.   return shape;  
  666. }  
  667.   
  668. #ifdef USE_OPENCV  
  669. // 如果使用OpenCV  
  670. // 使用CVMat中的信息来推断形状  
  671. template<typename Dtype>  
  672. vector<int> DataTransformer<Dtype>::InferBlobShape(const cv::Mat& cv_img) {  
  673.   const int crop_size = param_.crop_size();  
  674.   const int img_channels = cv_img.channels();  
  675.   const int img_height = cv_img.rows;  
  676.   const int img_width = cv_img.cols;  
  677.   // Check dimensions.  
  678.   CHECK_GT(img_channels, 0);  
  679.   CHECK_GE(img_height, crop_size);  
  680.   CHECK_GE(img_width, crop_size);  
  681.   // Build BlobShape.  
  682.   vector<int> shape(4);  
  683.   shape[0] = 1;  
  684.   shape[1] = img_channels;  
  685.   shape[2] = (crop_size)? crop_size: img_height;  
  686.   shape[3] = (crop_size)? crop_size: img_width;  
  687.   return shape;  
  688. }  
  689.   
  690. template<typename Dtype>  
  691. vector<int> DataTransformer<Dtype>::InferBlobShape(  
  692.     const vector<cv::Mat> & mat_vector) {  
  693.   const int num = mat_vector.size();  
  694.   CHECK_GT(num, 0) << "There is no cv_img to in the vector";  
  695.   // Use first cv_img in the vector to InferBlobShape.  
  696.   // 使用第一个来推断  
  697.   vector<int> shape = InferBlobShape(mat_vector[0]);  
  698.   // Adjust num to the size of the vector.  
  699.   shape[0] = num;  
  700.   return shape;  
  701. }  
  702. #endif  // USE_OPENCV  
  703.   
  704. // 初始化随机数种子  
  705. template <typename Dtype>  
  706. void DataTransformer<Dtype>::InitRand() {  
  707.   // 要么需要镜像要么训练阶段和需要crop同时满足的情况下才初始化随机数种子  
  708.   const bool needs_rand = param_.mirror() ||  
  709.       (phase_ == TRAIN && param_.crop_size());  
  710.   if (needs_rand) {  
  711.     const unsigned int rng_seed = caffe_rng_rand();// 获得随机数种子(通过熵池或者时间生成种子)  
  712.     rng_.reset(new Caffe::RNG(rng_seed));//初始化随机数种子并实例化随机数生成器  
  713.   } else {  
  714.     rng_.reset();//否则随机数生成器设置为空  
  715.   }  
  716. }  
  717.   
  718. // 产生从0到n的随机数  
  719. template <typename Dtype>  
  720. int DataTransformer<Dtype>::Rand(int n) {  
  721.   CHECK(rng_);  
  722.   CHECK_GT(n, 0);  
  723.   caffe::rng_t* rng =  
  724.       static_cast<caffe::rng_t*>(rng_->generator());  
  725.   return ((*rng)() % n);  
  726. }  
  727.   
  728. INSTANTIATE_CLASS(DataTransformer);  
  729. /* 
  730. 初始化类的宏定义是这样的,前面有讲过,这里再给出来 
  731. #define INSTANTIATE_CLASS(classname) \ 
  732.   char gInstantiationGuard##classname; \ 
  733.   template class classname<float>; \ 
  734.   template class classname<double> 
  735. */  
  736.   
  737. }  // namespace caffe  
posted @ 2016-04-12 09:55  菜鸡一枚  阅读(9672)  评论(0编辑  收藏  举报