梳理caffe代码data_reader(十一)

梳理caffe代码data_reader(十一)

上一篇的blocking_queue到底干了一件什么事情呢?刚刚看完就有点忘记了,再过一会估计忘光了。。。

顾名思义,阻塞队列,就是一个正在排队的打饭队列,先到窗口的先打饭,为什么会高效安全呢?一是像交通有秩序,二是有了秩序是不是交通运行起来就快了。

我们就看看数据是怎么进行排队的?

头文件:

 

[cpp] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. #ifndef CAFFE_DATA_READER_HPP_  
  2. #define CAFFE_DATA_READER_HPP_  
  3.   
  4. #include <map>  
  5. #include <string>  
  6. #include <vector>  
  7.   
  8. #include "caffe/common.hpp"  
  9. #include "caffe/internal_thread.hpp"  
  10. #include "caffe/util/blocking_queue.hpp"  
  11. #include "caffe/util/db.hpp"  
  12.   
  13. namespace caffe {  
  14.   
  15. /** 
  16.  * @brief Reads data from a source to queues available to data layers. 
  17.  * A single reading thread is created per source, even if multiple solvers 
  18.  * are running in parallel, e.g. for multi-GPU training. This makes sure 
  19.  * databases are read sequentially, and that each solver accesses a different 
  20.  * subset of the database. Data is distributed to solvers in a round-robin 
  21.  * way to keep parallel training deterministic. 
  22.  */  
  23. /* 
  24. 从共享的资源读取数据然后排队输入到数据层,每个资源创建单个线程,即便是使用多个GPU在并行任务中求解。这就保证对于频繁读取数据库,并且每个求解的线程使用的子数据是不同的。数据成功设计就是这样使在求解时数据保持一种循环地并行训练。 
  25. */  
  26. class DataReader {  
  27.  public:  
  28.   explicit DataReader(const LayerParameter& param);  
  29.   ~DataReader();  
  30. //  
  31.   inline BlockingQueue<Datum*>& free() const {  
  32.     return queue_pair_->free_;  
  33.   }  
  34.   inline BlockingQueue<Datum*>& full() const {  
  35.     return queue_pair_->full_;  
  36.   }  
  37.   
  38.  protected:  
  39.   // Queue pairs are shared between a body and its readers  
  40.   class QueuePair {  
  41.    public:  
  42.     explicit QueuePair(int size);  
  43.     ~QueuePair();  
  44. //定义了两个阻塞队列free_和full_  
  45.     BlockingQueue<Datum*> free_;  
  46.     BlockingQueue<Datum*> full_;  
  47.   
  48.   DISABLE_COPY_AND_ASSIGN(QueuePair);  
  49.   };  
  50.   
  51.   // A single body is created per source  
  52. //继承InternalThread 这个类的  
  53.   class Body : public InternalThread {  
  54.    public:  
  55.     explicit Body(const LayerParameter& param);  
  56.     virtual ~Body();  
  57.   
  58.    protected:  
  59. //重写了InternalThread内部的InternalThreadEntry函数,此外还添加了read_one函数  
  60.     void InternalThreadEntry();  
  61.     void read_one(db::Cursor* cursor, QueuePair* qp);  
  62.   
  63.     const LayerParameter param_;  
  64.     BlockingQueue<shared_ptr<QueuePair> > new_queue_pairs_;  
  65. //内部有DataReader的友元  
  66.     friend class DataReader;  
  67.   
  68.   DISABLE_COPY_AND_ASSIGN(Body);  
  69.   };  
  70.   
  71.   // A source is uniquely identified by its layer name + path, in case  
  72.   // the same database is read from two different locations in the net.  
  73.   static inline string source_key(const LayerParameter& param) {  
  74.     return param.name() + ":" + param.data_param().source();  
  75.   }  
  76.   
  77.   const shared_ptr<QueuePair> queue_pair_;  
  78.   shared_ptr<Body> body_;  
  79.   
  80.   static map<const string, boost::weak_ptr<DataReader::Body> > bodies_;  
  81.   
  82. DISABLE_COPY_AND_ASSIGN(DataReader);  
  83. };  
  84.   
  85. }  // namespace caffe  
  86.   
  87. #endif  // CAFFE_DATA_READER_HPP_  

实现部分:

[cpp] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. #include <boost/thread.hpp>  
  2. #include <map>  
  3. #include <string>  
  4. #include <vector>  
  5.   
  6. #include "caffe/common.hpp"  
  7. #include "caffe/data_reader.hpp"  
  8. #include "caffe/layers/data_layer.hpp"  
  9. #include "caffe/proto/caffe.pb.h"  
  10.   
  11. namespace caffe {  
  12.   
  13. using boost::weak_ptr;  
  14.   
  15. map<const string, weak_ptr<DataReader::Body> > DataReader::bodies_;  
  16. static boost::mutex bodies_mutex_;  
  17.   
  18. DataReader::DataReader(const LayerParameter& param)  
  19.     : queue_pair_(new QueuePair(  //  
  20.         param.data_param().prefetch() * param.data_param().batch_size())) {  
  21.   // Get or create a body  
  22.   boost::mutex::scoped_lock lock(bodies_mutex_);  
  23.   string key = source_key(param);  
  24.   weak_ptr<Body>& weak = bodies_[key];  
  25.   body_ = weak.lock();  
  26.   if (!body_) {  
  27.     body_.reset(new Body(param));  
  28.     bodies_[key] = weak_ptr<Body>(body_);  
  29.   }  
  30.   body_->new_queue_pairs_.push(queue_pair_);  
  31. }  
  32.   
  33. DataReader::~DataReader() {  
  34.   string key = source_key(body_->param_);  
  35.   body_.reset();  
  36.   boost::mutex::scoped_lock lock(bodies_mutex_);  
  37.   if (bodies_[key].expired()) {  
  38.     bodies_.erase(key);  
  39.   }  
  40. }  
  41.   
  42. //根据给定的size初始化的若干个Datum的实例到free里面  
  43.   
  44. DataReader::QueuePair::QueuePair(int size) {  
  45.   // Initialize the free queue with requested number of datums  
  46.   for (int i = 0; i < size; ++i) {  
  47.     free_.push(new Datum());  
  48.   }  
  49. }  
  50. //将full_和free_这两个队列里面的Datum对象全部delete。  
  51. DataReader::QueuePair::~QueuePair() {  
  52.   Datum* datum;  
  53.   while (free_.try_pop(&datum)) {  
  54.     delete datum;  
  55.   }  
  56.   while (full_.try_pop(&datum)) {  
  57.     delete datum;  
  58.   }  
  59. }  
  60. //Body类的构造函数,实际上是给定网络的参数,然后开始启动内部线程  
  61. DataReader::Body::Body(const LayerParameter& param)  
  62.     : param_(param),  
  63.       new_queue_pairs_() {  
  64.   StartInternalThread();// 调用InternalThread内部的函数来初始化运行环境以及新建线程去执行虚函数InternalThreadEntry的内容  
  65. }  
  66. // 析构,停止线程  
  67. DataReader::Body::~Body() {  
  68.   StopInternalThread();  
  69. }  
  70.   
  71. // 自己实现的需要执行的函数  
  72. // 首先打开数据库,然后设置游标,然后设置QueuePair指针容器  
  73. void DataReader::Body::InternalThreadEntry() {  
  74.   // 获取所给定的数据源的类型来得到DB的指针  
  75.   shared_ptr<db::DB> db(db::GetDB(param_.data_param().backend()));  
  76.   // 从网络参数中给定的DB的位置打开DB  
  77.   db->Open(param_.data_param().source(), db::READ);  
  78.   // 新建游标指针  
  79.   shared_ptr<db::Cursor> cursor(db->NewCursor());  
  80.   // 新建QueuePair指针容器,QueuePair里面包含了free_和full_这两个阻塞队列  
  81.   vector<shared_ptr<QueuePair> > qps;  
  82.   try {  
  83.     // 根据网络参数的阶段来设置solver_count  
  84.     int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1;  
  85.   
  86.     // To ensure deterministic runs, only start running once all solvers  
  87.     // are ready. But solvers need to peek on one item during initialization,  
  88.     // so read one item, then wait for the next solver.  
  89.     for (int i = 0; i < solver_count; ++i) {  
  90.       shared_ptr<QueuePair> qp(new_queue_pairs_.pop());  
  91.       read_one(cursor.get(), qp.get());// 读取一个数据  
  92.       qps.push_back(qp);压入  
  93.     }  
  94.     // Main loop  
  95.     while (!must_stop()) {  
  96.       for (int i = 0; i < solver_count; ++i) {  
  97.         read_one(cursor.get(), qps[i].get());  
  98.       }  
  99.       // Check no additional readers have been created. This can happen if  
  100.       // more than one net is trained at a time per process, whether single  
  101.       // or multi solver. It might also happen if two data layers have same  
  102.       // name and same source.  
  103.       CHECK_EQ(new_queue_pairs_.size(), 0);  
  104.     }  
  105.   } catch (boost::thread_interrupted&) {  
  106.     // Interrupted exception is expected on shutdown  
  107.   }  
  108. }  
  109.   
  110. // 从数据库中获取一个数据  
  111. void DataReader::Body::read_one(db::Cursor* cursor, QueuePair* qp) {  
  112.   // 从QueuePair中的free_队列pop出一个  
  113.   Datum* datum = qp->free_.pop();  
  114.   // TODO deserialize in-place instead of copy?  
  115.   // 然后解析cursor中的值  
  116.   datum->ParseFromString(cursor->value());  
  117.   // 然后压入QueuePair中的full_队列  
  118.   qp->full_.push(datum);  
  119.   
  120.   // go to the next iter  
  121.   // 游标指向下一个  
  122.   cursor->Next();  
  123.   if (!cursor->valid()) {  
  124.     DLOG(INFO) << "Restarting data prefetching from start.";  
  125.     cursor->SeekToFirst();// 如果游标指向的位置已经无效了则指向第一个位置  
  126.   }  
  127. }  
  128.   
  129. }  // namespace caffe  

数据层就是调用了封装层的DB来读取数据,此外还简单封装了boost的线程库,然后自己封装了个阻塞队列。

posted @ 2016-04-11 16:58  菜鸡一枚  阅读(1453)  评论(0编辑  收藏  举报