Caffe源码学习笔记2:include/caffe/solver_factory.hpp

Caffe源码学习笔记2:include/caffe/solver_factory.hpp

简要说明:slover是什么?solver是caffe中实现训练模型参数更新的优化算法,solver类派生出的类可以对整个网络进行训练。在caffe中有很多solver子类,即不同的优化算法,如随机梯度下降(SGD)。
一个solver factory可以注册solvers,运行时,注册过的solvers通过SolverRegistry::CreateSolver(param)来调用。caffe提供两种方法注册一个solver,代码解析如下:

namespace caffe {
//声明Solver为模板类
template
class Solver;

//SolverRegistry 类不能实例化,所有的方法直接调用
template
class SolverRegistry {
 public:
  //Creator为函数指针类型;CreatorRegistry为标准map容器,储存函数指针;
  typedef Solver* (*Creator)(const SolverParameter&);
  typedef std::map CreatorRegistry;
 
  //创建CreatorRegistry类型容器函数,返回其引用;
  static CreatorRegistry& Registry() {
    static CreatorRegistry* g_registry_ = new CreatorRegistry();
    return *g_registry_;
  }

  // 向CreatorRegistry容器中增加Creator;
  static void AddCreator(const string& type, Creator creator) {
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 0)
        << "Solver type " << type << " already registered.";
    registry[type] = creator;
  }

  // Get a solver using a SolverParameter.通过SolverParameter返回Solver指针;
  static Solver* CreateSolver(const SolverParameter& param) {
    const string& type = param.type();
    CreatorRegistry& registry = Registry();
    CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
        << " (known types: " << SolverTypeListString() << ")";
    return registry[type](param);
  }
 
  //获取CreatorRegistry容器中注册过的solver类型名,string列表储存;
  static vector SolverTypeList() {
    CreatorRegistry& registry = Registry();
    vector solver_types;
    for (typename CreatorRegistry::iterator iter = registry.begin();
         iter != registry.end(); ++iter) {
      solver_types.push_back(iter->first);
    }
    return solver_types;
  }

private:
  SolverRegistry() {}

  //这个函数从solver_types列表中取出一个个string;
  static string SolverTypeListString() {
    vector solver_types = SolverTypeList();
    string solver_types_str;
    for (vector::iterator iter = solver_types.begin();
         iter != solver_types.end(); ++iter) {
      if (iter != solver_types.begin()) {
        solver_types_str += ", ";
      }
      solver_types_str += *iter;
    }
    return solver_types_str;
  }
};


template
class SolverRegisterer {
 public:
  //对SolverRegistry接口进行封装,功能是注册creator;
  SolverRegisterer(const string& type,
      Solver* (*creator)(const SolverParameter&)) {
    // LOG(INFO) << "Registering solver type: " << type;
    SolverRegistry::AddCreator(type, creator);
  }
};

//注册方法一:注册一个solver creator
#define REGISTER_SOLVER_CREATOR(type, creator)                                 \
  static SolverRegisterer g_creator_f_##type(#type, creator);    \
  static SolverRegisterer g_creator_d_##type(#type, creator)   \
 
 //注册方法二:
#define REGISTER_SOLVER_CLASS(type)                                            \
  template                                                     \
  Solver* Creator_##type##Solver(                                       \
      const SolverParameter& param)                                            \
  {                                                                            \
    return new type##Solver(param);                                     \
  }                                                                            \
  REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)

}  // namespace caffe

总结:通过两种方法之一注册一个solver creator,注册过后通过CreateSolver调用。

 

posted @ 2017-05-06 11:08  菜鸡一枚  阅读(277)  评论(0编辑  收藏  举报