Caffe源码学习笔记2:include/caffe/solver_factory.hpp
Caffe源码学习笔记2:include/caffe/solver_factory.hpp
简要说明:slover是什么?solver是caffe中实现训练模型参数更新的优化算法,solver类派生出的类可以对整个网络进行训练。在caffe中有很多solver子类,即不同的优化算法,如随机梯度下降(SGD)。
一个solver factory可以注册solvers,运行时,注册过的solvers通过SolverRegistry::CreateSolver(param)来调用。caffe提供两种方法注册一个solver,代码解析如下:
namespace caffe {
//声明Solver为模板类
template
class Solver;
//SolverRegistry 类不能实例化,所有的方法直接调用
template
class SolverRegistry {
public:
//Creator为函数指针类型;CreatorRegistry为标准map容器,储存函数指针;
typedef Solver* (*Creator)(const SolverParameter&);
typedef std::map CreatorRegistry;
//创建CreatorRegistry类型容器函数,返回其引用;
static CreatorRegistry& Registry() {
static CreatorRegistry* g_registry_ = new CreatorRegistry();
return *g_registry_;
}
// 向CreatorRegistry容器中增加Creator;
static void AddCreator(const string& type, Creator creator) {
CreatorRegistry& registry = Registry();
CHECK_EQ(registry.count(type), 0)
<< "Solver type " << type << " already registered.";
registry[type] = creator;
}
// Get a solver using a SolverParameter.通过SolverParameter返回Solver指针;
static Solver* CreateSolver(const SolverParameter& param) {
const string& type = param.type();
CreatorRegistry& registry = Registry();
CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
<< " (known types: " << SolverTypeListString() << ")";
return registry[type](param);
}
//获取CreatorRegistry容器中注册过的solver类型名,string列表储存;
static vector SolverTypeList() {
CreatorRegistry& registry = Registry();
vector solver_types;
for (typename CreatorRegistry::iterator iter = registry.begin();
iter != registry.end(); ++iter) {
solver_types.push_back(iter->first);
}
return solver_types;
}
private:
SolverRegistry() {}
//这个函数从solver_types列表中取出一个个string;
static string SolverTypeListString() {
vector solver_types = SolverTypeList();
string solver_types_str;
for (vector::iterator iter = solver_types.begin();
iter != solver_types.end(); ++iter) {
if (iter != solver_types.begin()) {
solver_types_str += ", ";
}
solver_types_str += *iter;
}
return solver_types_str;
}
};
template
class SolverRegisterer {
public:
//对SolverRegistry接口进行封装,功能是注册creator;
SolverRegisterer(const string& type,
Solver* (*creator)(const SolverParameter&)) {
// LOG(INFO) << "Registering solver type: " << type;
SolverRegistry::AddCreator(type, creator);
}
};
//注册方法一:注册一个solver creator
#define REGISTER_SOLVER_CREATOR(type, creator) \
static SolverRegisterer g_creator_f_##type(#type, creator); \
static SolverRegisterer g_creator_d_##type(#type, creator) \
//注册方法二:
#define REGISTER_SOLVER_CLASS(type) \
template \
Solver* Creator_##type##Solver( \
const SolverParameter& param) \
{ \
return new type##Solver(param); \
} \
REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
} // namespace caffe
总结:通过两种方法之一注册一个solver creator,注册过后通过CreateSolver调用。