梳理caffe代码data_reader(十一)
上一篇讲的blocking_queue到底做了一件什么事情呢?刚刚看完就有点忘了,再过一会儿估计就忘光了……
顾名思义,阻塞队列就像食堂窗口前排队打饭的队伍:先到窗口的先打饭(先进先出)。为什么它既安全又高效呢?一是有秩序:入队和出队都在锁的保护下进行,多个线程同时操作也不会乱,就像交通有了秩序;二是有了秩序,运行起来反而更快:队列为空时,取数据的线程会阻塞等待,直到有数据放进来才被唤醒,而不必反复空转查询。
我们就看看数据是怎么进行排队的?
头文件:
- #ifndef CAFFE_DATA_READER_HPP_
- #define CAFFE_DATA_READER_HPP_
- #include <map>
- #include <string>
- #include <vector>
- #include "caffe/common.hpp"
- #include "caffe/internal_thread.hpp"
- #include "caffe/util/blocking_queue.hpp"
- #include "caffe/util/db.hpp"
- namespace caffe {
- /**
- * @brief Reads data from a source to queues available to data layers.
- * A single reading thread is created per source, even if multiple solvers
- * are running in parallel, e.g. for multi-GPU training. This makes sure
- * databases are read sequentially, and that each solver accesses a different
- * subset of the database. Data is distributed to solvers in a round-robin
- * way to keep parallel training deterministic.
- */
- /*
- 从数据源读取数据,放入供数据层使用的队列。每个数据源只创建一个读取线程,即使有多个solver并行运行(例如多GPU训练)也是如此。这样可以保证数据库被顺序读取,并且每个solver访问的是数据库中不同的数据子集。数据以轮询(round-robin)的方式分发给各个solver,从而保证并行训练的结果是确定的。
- */
- class DataReader {
- public:
- explicit DataReader(const LayerParameter& param);
- ~DataReader();
- //
- inline BlockingQueue<Datum*>& free() const {
- return queue_pair_->free_;
- }
- inline BlockingQueue<Datum*>& full() const {
- return queue_pair_->full_;
- }
- protected:
- // Queue pairs are shared between a body and its readers
- class QueuePair {
- public:
- explicit QueuePair(int size);
- ~QueuePair();
- //定义了两个阻塞队列free_和full_
- BlockingQueue<Datum*> free_;
- BlockingQueue<Datum*> full_;
- DISABLE_COPY_AND_ASSIGN(QueuePair);
- };
- // A single body is created per source
- //继承InternalThread 这个类的
- class Body : public InternalThread {
- public:
- explicit Body(const LayerParameter& param);
- virtual ~Body();
- protected:
- //重写了InternalThread内部的InternalThreadEntry函数,此外还添加了read_one函数
- void InternalThreadEntry();
- void read_one(db::Cursor* cursor, QueuePair* qp);
- const LayerParameter param_;
- BlockingQueue<shared_ptr<QueuePair> > new_queue_pairs_;
- //内部有DataReader的友元
- friend class DataReader;
- DISABLE_COPY_AND_ASSIGN(Body);
- };
- // A source is uniquely identified by its layer name + path, in case
- // the same database is read from two different locations in the net.
- static inline string source_key(const LayerParameter& param) {
- return param.name() + ":" + param.data_param().source();
- }
- const shared_ptr<QueuePair> queue_pair_;
- shared_ptr<Body> body_;
- static map<const string, boost::weak_ptr<DataReader::Body> > bodies_;
- DISABLE_COPY_AND_ASSIGN(DataReader);
- };
- } // namespace caffe
- #endif // CAFFE_DATA_READER_HPP_
实现部分:
- #include <boost/thread.hpp>
- #include <map>
- #include <string>
- #include <vector>
- #include "caffe/common.hpp"
- #include "caffe/data_reader.hpp"
- #include "caffe/layers/data_layer.hpp"
- #include "caffe/proto/caffe.pb.h"
- namespace caffe {
- using boost::weak_ptr;
- map<const string, weak_ptr<DataReader::Body> > DataReader::bodies_;
- static boost::mutex bodies_mutex_;
- DataReader::DataReader(const LayerParameter& param)
- : queue_pair_(new QueuePair( //
- param.data_param().prefetch() * param.data_param().batch_size())) {
- // Get or create a body
- boost::mutex::scoped_lock lock(bodies_mutex_);
- string key = source_key(param);
- weak_ptr<Body>& weak = bodies_[key];
- body_ = weak.lock();
- if (!body_) {
- body_.reset(new Body(param));
- bodies_[key] = weak_ptr<Body>(body_);
- }
- body_->new_queue_pairs_.push(queue_pair_);
- }
- DataReader::~DataReader() {
- string key = source_key(body_->param_);
- body_.reset();
- boost::mutex::scoped_lock lock(bodies_mutex_);
- if (bodies_[key].expired()) {
- bodies_.erase(key);
- }
- }
- //根据给定的size初始化的若干个Datum的实例到free里面
- DataReader::QueuePair::QueuePair(int size) {
- // Initialize the free queue with requested number of datums
- for (int i = 0; i < size; ++i) {
- free_.push(new Datum());
- }
- }
- //将full_和free_这两个队列里面的Datum对象全部delete。
- DataReader::QueuePair::~QueuePair() {
- Datum* datum;
- while (free_.try_pop(&datum)) {
- delete datum;
- }
- while (full_.try_pop(&datum)) {
- delete datum;
- }
- }
- //Body类的构造函数,实际上是给定网络的参数,然后开始启动内部线程
- DataReader::Body::Body(const LayerParameter& param)
- : param_(param),
- new_queue_pairs_() {
- StartInternalThread();// 调用InternalThread内部的函数来初始化运行环境以及新建线程去执行虚函数InternalThreadEntry的内容
- }
- // 析构,停止线程
- DataReader::Body::~Body() {
- StopInternalThread();
- }
- // 自己实现的需要执行的函数
- // 首先打开数据库,然后设置游标,然后设置QueuePair指针容器
- void DataReader::Body::InternalThreadEntry() {
- // 获取所给定的数据源的类型来得到DB的指针
- shared_ptr<db::DB> db(db::GetDB(param_.data_param().backend()));
- // 从网络参数中给定的DB的位置打开DB
- db->Open(param_.data_param().source(), db::READ);
- // 新建游标指针
- shared_ptr<db::Cursor> cursor(db->NewCursor());
- // 新建QueuePair指针容器,QueuePair里面包含了free_和full_这两个阻塞队列
- vector<shared_ptr<QueuePair> > qps;
- try {
- // 根据网络参数的阶段来设置solver_count
- int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1;
- // To ensure deterministic runs, only start running once all solvers
- // are ready. But solvers need to peek on one item during initialization,
- // so read one item, then wait for the next solver.
- for (int i = 0; i < solver_count; ++i) {
- shared_ptr<QueuePair> qp(new_queue_pairs_.pop());
- read_one(cursor.get(), qp.get());// 读取一个数据
- qps.push_back(qp);压入
- }
- // Main loop
- while (!must_stop()) {
- for (int i = 0; i < solver_count; ++i) {
- read_one(cursor.get(), qps[i].get());
- }
- // Check no additional readers have been created. This can happen if
- // more than one net is trained at a time per process, whether single
- // or multi solver. It might also happen if two data layers have same
- // name and same source.
- CHECK_EQ(new_queue_pairs_.size(), 0);
- }
- } catch (boost::thread_interrupted&) {
- // Interrupted exception is expected on shutdown
- }
- }
- // 从数据库中获取一个数据
- void DataReader::Body::read_one(db::Cursor* cursor, QueuePair* qp) {
- // 从QueuePair中的free_队列pop出一个
- Datum* datum = qp->free_.pop();
- // TODO deserialize in-place instead of copy?
- // 然后解析cursor中的值
- datum->ParseFromString(cursor->value());
- // 然后压入QueuePair中的full_队列
- qp->full_.push(datum);
- // go to the next iter
- // 游标指向下一个
- cursor->Next();
- if (!cursor->valid()) {
- DLOG(INFO) << "Restarting data prefetching from start.";
- cursor->SeekToFirst();// 如果游标指向的位置已经无效了则指向第一个位置
- }
- }
- } // namespace caffe
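总结一下:DataReader通过db封装层(db::DB和db::Cursor)读取数据;它的Body类继承自对boost线程库做了简单封装的InternalThread,在后台线程中完成读取;读到的数据再经由自己实现的阻塞队列BlockingQueue(free_和full_两个队列)交给数据层使用。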