caffe源码解析-solver_factory

caffe源码解析-solver_factory

声明:内容整理自

感谢Ldy和各位博主的无私分享。各位博主已经写的很好,个人做了一些梳理和补充,方便日后回顾。



目录:

  • SolverRegistry 工厂类定义
  • SolverRegistry 工厂类注册Solver子类
  • 回顾命令行接口中train()函数

 

caffe提供了很多优化方法。它们被封装成继承自父类Solver的子类。如SGDSolver,AdamSolver等。为了能够灵活的实例化这些子类,采用了下图的工厂模式。(关于工厂模式可以参考 工厂模式 | 菜鸟教程

先来研究一下工厂类 SolverRegistry。代码主要在include/caffe/solver_factory.hpp文件中。

 

SolverRegistry 工厂类定义:

我们知道工厂类的关键是一个可以实现选择的结构,它根据关键词来实例化相应的子类并返回。例如下图 ShapeFactory就是根据传入的字符串实例化不同的形状子类并返回。

但是代码并没有直接使用上面的分支结构实现选择。而是用了一个map容器,它以"特定Solver的type"为Key,以”特定Solver对应的Creator函数指针"为Value,通过这种键值映射关系也可以实现选择。 (本质上是用查表来替代分支实现选择,空间换时间)

 

1. map容器:

  typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
  typedef std::map<string, Creator> CreatorRegistry;

Creator是一个函数指针类型,指向的函数的参数为SolverParameter类型,返回类型为Solver<Dtype>* 。这个函数就是 ”特定Solver对应的Creator函数"。调用这个函数可以就得到实例化的子类。 (注意:容器的类型是被定义为CreatorRegistry)。

PS:为什么把响应的对象直接作为map容器的 Value值呢,这样岂不很直接? 个人理解是因为对象实例化还需要参数,用一个函数封装一下方便实例化时传参。
  static CreatorRegistry& Registry() {
	//静态变量
    static CreatorRegistry* g_registry_ = new CreatorRegistry();
    return *g_registry_;
  }
  
  CreatorRegistry& registry = Registry();

Registry的作用是,调用时返回一个指向 容器(CreatorRegistry)类型的静态变量 g_registry_。说白了g_registry_就是一个容器指针。因为这个变量是static的,所以即使多次调用这个函数,也只会得到一个g_registry_,而且在其他地方修改这个map里的内容,是存储在这个map中的。(看来的确是容器)

接下来的重点就是如何设计 “注册器”(把creator函数注册进map容器),“提取器”(从map容器中根据type提取creator函数)

2. “注册器”AddCreator如下:

// Adds a creator.
  // 添加一个creator
  static void AddCreator(const string& type, Creator creator) {
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 0)
        << "Solver type " << type << " already registered.";
    registry[type] = creator;
  }

先调用Registry()得到容器,确保容器内不存在要添加的type关键字,然后添加Key为type,Value为Creator函数的元素。

3. "提取器"CreateSolver函数如下:

// Get a solver using a SolverParameter.
  // 静态成员函数,在caffe.cpp里直接调用,返回Solver指针
  static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
	// string类型的变量type,表示Solver的类型(‘SGD’/’Nestrov’等)
	// 默认为SGD
    const string& type = param.type();
    // 定义了一个key类型为string,value类型为Creator的map:registry
    // 返回为静态变量
    CreatorRegistry& registry = Registry();


    for (typename CreatorRegistry::iterator iter = registry.begin();
             iter != registry.end(); ++iter)
    {
         std::cout<<"key:"<<iter->first<<"``` "
             <<"value:"<<iter->second<<std::endl;}

    /*
     * 如果是一个已经register过的Solver类型,那么registry.count(type)应该为1,
     * 然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个
     * creator函数,将creator返回的Solver<Dtype>*返回。
     */
    CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
        << " (known types: " << SolverTypeListString() << ")";
    //通过static的g_registry_[type]获得type对应的solver的creator函数指针
    return registry[type](param);
  }

这个函数先定义了string类型的变量type,表示Solver的类型(‘SGD’/’Nestrov’等),然后定义了一个key类型为string,value类型为Creator的map:registry,其中Creator是一个函数指针类型,指向的函数的参数为SolverParameter类型,返回类型为Solver<Dtype>*。如果是一个已经register过的Solver类型,那么registry.count(type)应该为1,然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个creator函数,将creator返回的Solver<Dtype>*返回,也就是说调用”提取器“就可以得到实例化之后的特定Solver对象。(果然是工厂)

 

SolverRegistry 工厂类 注册Solver子类:

在include/caffe/solver_factory.hpp文件的核心内容就是上面的工厂类 SolverRegistry。接下来说说如何进行”注册“。注册不就是直接调用AddCreator注册器吗,有什么好讲的。非也非也。我们不妨先按照自己的想法试一试该怎么做(下面是个人理解,如有错误还望指正)。

”注册器“一般在Solver子类(如 SGDsolver)的定义之后调用,以SGDsolver为例。首先,必须封装一个Creator函数,然后调用AddCreator。

  Solver<Dtype>* Creator_SGDSolver( const SolverParameter& param)
  {                                                                          
    return new SGDSolver<Dtype>(param);
  }
  SolverRegistry<Dtype>::AddCreator('SGD', Creator_SGDSolver<float>);
  SolverRegistry<Dtype>::AddCreator('SGD', Creator_SGDSolver<double>);     

 

每一个Solver子类定义的后面都要加上类似的代码,对于Adadelta子类则变成:

Solver<Dtype>* Creator_AdadeltaSolver( const SolverParameter& param)
{                                                                          
    return new AdadeltaSolver<Dtype>(param);
}
  SolverRegistry<Dtype>::AddCreator('Adadelta',Creator_AdadeltaSolver<float>);
SolverRegistry<Dtype>::AddCreator('Adadelta',Creator_AdadeltaSolver<double>);

这两段代码唯一的区别是 "SGD" 和 ”Adadelta“,所以自然想到要用宏来简化代码。

注册一个Solver子类就只需要一条语句:

REGISTER_SOLVER_CLASS(SGD);

不过源代码更加精简:把

  SolverRegistry<Dtype>::AddCreator(#type, Creator_##type##Solver<float>); \
  SolverRegistry<Dtype>::AddCreator(#type, Creator_##type##Solver<double>)  

这两条语句也封装成了一个宏:

先定义了一个SolverRegisterer模板类,它只有一个构造函数,用来调用SolverRegistry<Dtype>::AddCreator。这样做的目的是: ”声明变量“ 触发 构造函数调用,进而”调用AddCreator“。也就是说现在你只需要声明一个 SolverRegisterer类的变量就会取调用AddCreator。

有了这一招,下面就可以定义REGISTER_SOLVER_CREATOR宏。 它分别定义了SolverRegisterer类的float和double类型的static变量(”声明变量“),这会 触发 调用各自的构造函数,而在SolverRegisterer的构造函数中”调用AddCreator"。

所以最终的注册器由原来的:

变成了:

至此我们已经把solver_factory.hpp中的所有代码讲解完毕。

include\caffe\solver_factory.hpp的全部代码

如下:

#ifndef CAFFE_SOLVER_FACTORY_H_
#define CAFFE_SOLVER_FACTORY_H_

#include <map>
#include <string>
#include <vector>

#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

template <typename Dtype>
class Solver;

template <typename Dtype>
class SolverRegistry {
 public:
  //Creator是一个函数指针类型,指向的函数的参数为SolverParameter类型
  //,返回类型为Solver<Dtype>*
  typedef Solver<Dtype>* (*Creator)(const SolverParameter&);

  //
  typedef std::map<string, Creator> CreatorRegistry;

  static CreatorRegistry& Registry() {
	//静态变量
    static CreatorRegistry* g_registry_ = new CreatorRegistry();
    return *g_registry_;
  }

  // Adds a creator.
  // 添加一个creator
  static void AddCreator(const string& type, Creator creator) {
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 0)
        << "Solver type " << type << " already registered.";
    registry[type] = creator;
  }

  // Get a solver using a SolverParameter.
  // 静态成员函数,在caffe.cpp里直接调用,返回Solver指针
  static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
	// string类型的变量type,表示Solver的类型(‘SGD’/’Nestrov’等)
	// 默认为SGD
    const string& type = param.type();
    // 定义了一个key类型为string,value类型为Creator的map:registry
    // 返回为静态变量
    CreatorRegistry& registry = Registry();


    for (typename CreatorRegistry::iterator iter = registry.begin();
             iter != registry.end(); ++iter)
    {
         std::cout<<"key:"<<iter->first<<"``` "
             <<"value:"<<iter->second<<std::endl;}

    /*
     * 如果是一个已经register过的Solver类型,那么registry.count(type)应该为1,
     * 然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个
     * creator函数,将creator返回的Solver<Dtype>*返回。
     */
    CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
        << " (known types: " << SolverTypeListString() << ")";
    //通过static的g_registry_[type]获得type对应的solver的creator函数指针
    return registry[type](param);
  }

  static vector<string> SolverTypeList() {
    CreatorRegistry& registry = Registry();
    vector<string> solver_types;
    for (typename CreatorRegistry::iterator iter = registry.begin();
         iter != registry.end(); ++iter) {
      solver_types.push_back(iter->first);
    }
    return solver_types;
  }

 private:
  // Solver registry should never be instantiated - everything is done with its
  // static variables.
  // Solver registry不应该被实例化,因为所有的成员都是静态变量
  // 构造函数是私有的,所有成员函数都是静态的,可以通过类调用
  SolverRegistry() {}

  static string SolverTypeListString() {
    vector<string> solver_types = SolverTypeList();
    string solver_types_str;
    for (vector<string>::iterator iter = solver_types.begin();
         iter != solver_types.end(); ++iter) {
      if (iter != solver_types.begin()) {
        solver_types_str += ", ";
      }
      solver_types_str += *iter;
    }
    return solver_types_str;
  }
};


template <typename Dtype>
class SolverRegisterer {
 public:
  SolverRegisterer(const string& type,
		  	  	  	 // 指针函数
                     Solver<Dtype>* (*creator)(const SolverParameter&))
  {
    // LOG(INFO) << "Registering solver type: " << type;
    SolverRegistry<Dtype>::AddCreator(type, creator);
  }
};

/*
分别定义了SolverRegisterer这个模板类的float和double类型的static变量,这会去调用各自
的构造函数,而在SolverRegisterer的构造函数中调用了之前提到的SolverRegistry类的
AddCreator函数,这个函数就是将刚才定义的Creator_SGDSolver这个函数的指针存到
g_registry指向的map里面。
*/
#define REGISTER_SOLVER_CREATOR(type, creator)                                 \
  static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>);    \
  static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>)   \
/*
这个宏会定义一个名为Creator_SGDSolver的函数,这个函数即为Creator类型的指针指向的函数,
在这个函数中调用了SGDSolver的构造函数,并将构造的这个变量得到的指针返回,这也就是Creator
类型函数的作用:构造一个对应类型的Solver对象,将其指针返回。然后在这个宏里又调用了
REGISTER_SOLVER_CREATOR这个宏
*/
#define REGISTER_SOLVER_CLASS(type)                                            \
  template <typename Dtype>                                                    \
  Solver<Dtype>* Creator_##type##Solver(                                       \
      const SolverParameter& param)                                            \
  {                                                                            \
    return new type##Solver<Dtype>(param);                                     \
  }                                                                            \
  REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)

}  // namespace caffe

#endif  // CAFFE_SOLVER_FACTORY_H_

 

回顾命令行接口train()函数:

前面 caffe源码解析-命令行接口 一文的train()函数中有如下代码

  shared_ptr<caffe::Solver<float> > //初始化
      solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

理解了Solver工厂类,这条语句就很容易理解了。就是调用 caffe::SolverRegistry<float>::CreateSolver(solver_param) 得到一个Solver子类对象(如:SGDsolver),用智能指针solver指向它。

 

posted @ 2017-05-31 08:50  菜鸡一枚  阅读(492)  评论(0编辑  收藏  举报