caffe源码解析-solver_factory
caffe源码解析-solver_factory
声明:内容整理自
感谢Ldy和各位博主的无私分享。各位博主已经写的很好,个人做了一些梳理和补充,方便日后回顾。
目录:
- SolverRegistry 工厂类定义
- SolverRegistry 工厂类注册Solver子类
- 回顾命令行接口中train()函数
caffe提供了很多优化方法。它们被封装成继承自父类Solver的子类。如SGDSolver,AdamSolver等。为了能够灵活的实例化这些子类,采用了下图的工厂模式。(关于工厂模式可以参考 工厂模式 | 菜鸟教程)
先来研究一下工厂类 SolverRegistry。代码主要在include/caffe/solver_factory.hpp文件中。
SolverRegistry 工厂类定义:
我们知道工厂类的关键是一个可以实现选择的结构,它根据关键词来实例化相应的子类并返回。例如下图 ShapeFactory就是根据传入的字符串实例化不同的形状子类并返回。
但是代码并没有直接使用上面的分支结构实现选择。而是用了一个map容器,它以"特定Solver的type"为Key,以”特定Solver对应的Creator函数指针"为Value,通过这种键值映射关系也可以实现选择。 (本质上是用查表来替代分支实现选择,空间换时间)
1. map容器:
typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
typedef std::map<string, Creator> CreatorRegistry;
Creator是一个函数指针类型,指向的函数的参数为SolverParameter类型,返回类型为Solver<Dtype>* 。这个函数就是 ”特定Solver对应的Creator函数"。调用这个函数可以就得到实例化的子类。 (注意:容器的类型是被定义为CreatorRegistry)。
PS:为什么把响应的对象直接作为map容器的 Value值呢,这样岂不很直接? 个人理解是因为对象实例化还需要参数,用一个函数封装一下方便实例化时传参。
static CreatorRegistry& Registry() {
//静态变量
static CreatorRegistry* g_registry_ = new CreatorRegistry();
return *g_registry_;
}
CreatorRegistry& registry = Registry();
Registry的作用是,调用时返回一个指向 容器(CreatorRegistry)类型的静态变量 g_registry_。说白了g_registry_就是一个容器指针。因为这个变量是static的,所以即使多次调用这个函数,也只会得到一个g_registry_,而且在其他地方修改这个map里的内容,是存储在这个map中的。(看来的确是容器)
接下来的重点就是如何设计 “注册器”(把creator函数注册进map容器),“提取器”(从map容器中根据type提取creator函数)
2. “注册器”AddCreator如下:
// Adds a creator.
// 添加一个creator
static void AddCreator(const string& type, Creator creator) {
CreatorRegistry& registry = Registry();
CHECK_EQ(registry.count(type), 0)
<< "Solver type " << type << " already registered.";
registry[type] = creator;
}
先调用Registry()得到容器,确保容器内不存在要添加的type关键字,然后添加Key为type,Value为Creator函数的元素。
3. "提取器"CreateSolver函数如下:
// Get a solver using a SolverParameter.
// 静态成员函数,在caffe.cpp里直接调用,返回Solver指针
static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
// string类型的变量type,表示Solver的类型(‘SGD’/’Nestrov’等)
// 默认为SGD
const string& type = param.type();
// 定义了一个key类型为string,value类型为Creator的map:registry
// 返回为静态变量
CreatorRegistry& registry = Registry();
for (typename CreatorRegistry::iterator iter = registry.begin();
iter != registry.end(); ++iter)
{
std::cout<<"key:"<<iter->first<<"``` "
<<"value:"<<iter->second<<std::endl;}
/*
* 如果是一个已经register过的Solver类型,那么registry.count(type)应该为1,
* 然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个
* creator函数,将creator返回的Solver<Dtype>*返回。
*/
CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
<< " (known types: " << SolverTypeListString() << ")";
//通过static的g_registry_[type]获得type对应的solver的creator函数指针
return registry[type](param);
}
这个函数先定义了string类型的变量type,表示Solver的类型(‘SGD’/’Nestrov’等),然后定义了一个key类型为string,value类型为Creator的map:registry,其中Creator是一个函数指针类型,指向的函数的参数为SolverParameter类型,返回类型为Solver<Dtype>*。如果是一个已经register过的Solver类型,那么registry.count(type)应该为1,然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个creator函数,将creator返回的Solver<Dtype>*返回,也就是说调用”提取器“就可以得到实例化之后的特定Solver对象。(果然是工厂)
SolverRegistry 工厂类 注册Solver子类:
在include/caffe/solver_factory.hpp文件的核心内容就是上面的工厂类 SolverRegistry。接下来说说如何进行”注册“。注册不就是直接调用AddCreator注册器吗,有什么好讲的。非也非也。我们不妨先按照自己的想法试一试该怎么做(下面是个人理解,如有错误还望指正)。
”注册器“一般在Solver子类(如 SGDsolver)的定义之后调用,以SGDsolver为例。首先,必须封装一个Creator函数,然后调用AddCreator。
Solver<Dtype>* Creator_SGDSolver( const SolverParameter& param)
{
return new SGDSolver<Dtype>(param);
}
SolverRegistry<Dtype>::AddCreator('SGD', Creator_SGDSolver<float>);
SolverRegistry<Dtype>::AddCreator('SGD', Creator_SGDSolver<double>);
每一个Solver子类定义的后面都要加上类似的代码,对于Adadelta子类则变成:
Solver<Dtype>* Creator_AdadeltaSolver( const SolverParameter& param)
{
return new AdadeltaSolver<Dtype>(param);
}
SolverRegistry<Dtype>::AddCreator('Adadelta',Creator_AdadeltaSolver<float>);
SolverRegistry<Dtype>::AddCreator('Adadelta',Creator_AdadeltaSolver<double>);
这两段代码唯一的区别是 "SGD" 和 ”Adadelta“,所以自然想到要用宏来简化代码。
注册一个Solver子类就只需要一条语句:
REGISTER_SOLVER_CLASS(SGD);
不过源代码更加精简:把
SolverRegistry<Dtype>::AddCreator(#type, Creator_##type##Solver<float>); \
SolverRegistry<Dtype>::AddCreator(#type, Creator_##type##Solver<double>)
这两条语句也封装成了一个宏:
先定义了一个SolverRegisterer模板类,它只有一个构造函数,用来调用SolverRegistry<Dtype>::AddCreator。这样做的目的是: ”声明变量“ 触发 构造函数调用,进而”调用AddCreator“。也就是说现在你只需要声明一个 SolverRegisterer类的变量就会取调用AddCreator。
有了这一招,下面就可以定义REGISTER_SOLVER_CREATOR宏。 它分别定义了SolverRegisterer类的float和double类型的static变量(”声明变量“),这会 触发 调用各自的构造函数,而在SolverRegisterer的构造函数中”调用AddCreator"。
所以最终的注册器由原来的:
变成了:
至此我们已经把solver_factory.hpp中的所有代码讲解完毕。
include\caffe\solver_factory.hpp的全部代码
如下:
#ifndef CAFFE_SOLVER_FACTORY_H_
#define CAFFE_SOLVER_FACTORY_H_
#include <map>
#include <string>
#include <vector>
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
namespace caffe {
template <typename Dtype>
class Solver;
template <typename Dtype>
class SolverRegistry {
public:
//Creator是一个函数指针类型,指向的函数的参数为SolverParameter类型
//,返回类型为Solver<Dtype>*
typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
//
typedef std::map<string, Creator> CreatorRegistry;
static CreatorRegistry& Registry() {
//静态变量
static CreatorRegistry* g_registry_ = new CreatorRegistry();
return *g_registry_;
}
// Adds a creator.
// 添加一个creator
static void AddCreator(const string& type, Creator creator) {
CreatorRegistry& registry = Registry();
CHECK_EQ(registry.count(type), 0)
<< "Solver type " << type << " already registered.";
registry[type] = creator;
}
// Get a solver using a SolverParameter.
// 静态成员函数,在caffe.cpp里直接调用,返回Solver指针
static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
// string类型的变量type,表示Solver的类型(‘SGD’/’Nestrov’等)
// 默认为SGD
const string& type = param.type();
// 定义了一个key类型为string,value类型为Creator的map:registry
// 返回为静态变量
CreatorRegistry& registry = Registry();
for (typename CreatorRegistry::iterator iter = registry.begin();
iter != registry.end(); ++iter)
{
std::cout<<"key:"<<iter->first<<"``` "
<<"value:"<<iter->second<<std::endl;}
/*
* 如果是一个已经register过的Solver类型,那么registry.count(type)应该为1,
* 然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个
* creator函数,将creator返回的Solver<Dtype>*返回。
*/
CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
<< " (known types: " << SolverTypeListString() << ")";
//通过static的g_registry_[type]获得type对应的solver的creator函数指针
return registry[type](param