梳理caffe代码solver(十四)

梳理caffe代码solver(十四)

之前有一篇介绍solver的求解,也可以看官网的介绍:here ,和翻译版的介绍。

solver.hpp头文件的简单解析:

 

[cpp] view plain copy
 
  1. #ifndef CAFFE_SOLVER_HPP_  
  2. #define CAFFE_SOLVER_HPP_  
  3. #include <boost/function.hpp>  
  4. #include <string>  
  5. #include <vector>  
  6.   
  7. #include "caffe/net.hpp"  
  8. #include "caffe/solver_factory.hpp"  
  9.   
  10. namespace caffe {  
  11.   
  12. /** 
  13.   * @brief Enumeration of actions that a client of the Solver may request by 
  14.   * implementing the Solver's action request function, which a 
  15.   * a client may optionally provide in order to request early termination 
  16.   * or saving a snapshot without exiting. In the executable caffe, this 
  17.   * mechanism is used to allow the snapshot to be saved when stopping 
  18.   * execution with a SIGINT (Ctrl-C). 
  19.   */  
  20. //大概意思就是按Ctrl-C时,会保存当前训练时的模型,如果还在训练终端不小心被关闭时,可以接着上次继续训练  
  21.   namespace SolverAction {  
  22.     enum Enum {  
  23.       NONE = 0,  // Take no special action.  
  24.       STOP = 1,  //停止训练后,可以继续训练  
  25.       SNAPSHOT = 2  // Take a snapshot, and keep training.  
  26.     };  
  27.   }  
  28.   
  29. /** 
  30.  * @brief Type of a function that returns a Solver Action enumeration. 
  31.  */  
  32. //学过java的可以理解为回滚操作,比如银行账户钱从一个用户转到另一个账户时,中途发生点意外,一个用户钱已经减了,另一个却没有增加,这时需要回滚操作,  
  33. //就像这时训练的时候中断了,然后回滚,到上次断点,继续训练。  
  34. typedef boost::function<SolverAction::Enum()> ActionCallback;  
  35.   
  36. /** 
  37.  * @brief An interface for classes that perform optimization on Net%s. 
  38.  * 
  39.  * Requires implementation of ApplyUpdate to compute a parameter update 
  40.  * given the current state of the Net parameters. 
  41.  */  
  42. template <typename Dtype>  
  43. class Solver {  
  44.  public:  
  45.   explicit Solver(const SolverParameter& param,  
  46.       const Solver* root_solver = NULL);  
  47.   explicit Solver(const string& param_file, const Solver* root_solver = NULL);  
  48.   void Init(const SolverParameter& param);  
  49.   void InitTrainNet();  
  50.   void InitTestNets();  
  51.   
  52.   // Client of the Solver optionally may call this in order to set the function  
  53.   // that the solver uses to see what action it should take (e.g. snapshot or  
  54.   // exit training early).  
  55.   void SetActionFunction(ActionCallback func);  
  56.   SolverAction::Enum GetRequestedAction();  
  57.   //主函数,默认iter为0,非0的iter输入到预训练的网络中。  
  58.   virtual void Solve(const char* resume_file = NULL);  
  59.   inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }  
  60.   void Step(int iters);  
  61.   // The Restore method simply dispatches to one of the  
  62.   // RestoreSolverStateFrom___ protected methods. You should implement these  
  63.   // methods to restore the state from the appropriate snapshot type.  
  64.   //存储函数实现如何存储solver到快照模型中。应该实现RestoreSolverState()函数这个函数是存储来自SolverState缓冲的状态  
  65.   void Restore(const char* resume_file);  
  66.   
  67.   //Solver::Snapshot主要是基本的快照功能,存储学习的网络  
  68.   void Snapshot();  
  69.   virtual ~Solver() {}  
  70.   inline const SolverParameter& param() const { return param_; }  
  71.   inline shared_ptr<Net<Dtype> > net() { return net_; }  
  72.   inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {  
  73.     return test_nets_;  
  74.   }  
  75.   int iter() { return iter_; }  
  76.   //在迭代中调用特殊的点  
  77.   class Callback {  
  78.    protected:  
  79.     virtual void on_start() = 0;  
  80.     virtual void on_gradients_ready() = 0;  
  81.   
  82.     template <typename T>  
  83.     friend class Solver;  
  84.   };  
  85.   const vector<Callback*>& callbacks() const { return callbacks_; }  
  86.   void add_callback(Callback* value) {  
  87.     callbacks_.push_back(value);  
  88.   }  
  89.   
  90.   void CheckSnapshotWritePermissions();  
  91.   //返回slover的类型  
  92.   virtual inline const char* type() const { return ""; }  
  93.   
  94.  protected:  
  95.   //生成和应用当前迭代的更新的值  
  96.   virtual void ApplyUpdate() = 0;  
  97.   string SnapshotFilename(const string extension);  
  98.   string SnapshotToBinaryProto();  
  99.   string SnapshotToHDF5();  
  100.   // 测试程序  
  101.   void TestAll();  
  102.   void Test(const int test_net_id = 0);  
  103.   virtual void SnapshotSolverState(const string& model_filename) = 0;  
  104.   virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;  
  105.   virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;  
  106.   void DisplayOutputBlobs(const int net_id);  
  107.   void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);  
  108.   
  109.   SolverParameter param_;  
  110.   int iter_;//在测试的时候,需要迭代的次数,即test_iter* batchsize(测试集的)=测试集的大小,测试集batchsize可以在prototxt文件里设置    
  111.   int current_step_;  
  112.   shared_ptr<Net<Dtype> > net_;  
  113.   vector<shared_ptr<Net<Dtype> > > test_nets_;//test net可以有多个    
  114.   vector<Callback*> callbacks_;//嵌套类,暂时还不知道它的作用   
  115.   vector<Dtype> losses_;  
  116.   Dtype smoothed_loss_;  
  117.   
  118.   //在数据并行中,继续根solver层保持根nets(包含共享的层)  
  119.   const Solver* const root_solver_;  
  120.   //通过函数是选择确认按钮来选择保存还是退出快照。  
  121.   ActionCallback action_request_function_;  
  122.   
  123.   // True iff a request to stop early was received.  
  124.   bool requested_early_exit_;  
  125.   
  126.   DISABLE_COPY_AND_ASSIGN(Solver);  
  127. };  
  128.   
  129. //在多GPU计算时,仅仅计算梯度  
  130. template <typename Dtype>  
  131. class WorkerSolver : public Solver<Dtype> {  
  132.  public:  
  133.   explicit WorkerSolver(const SolverParameter& param,  
  134.       const Solver<Dtype>* root_solver = NULL)  
  135.       : Solver<Dtype>(param, root_solver) {}  
  136.   
  137.  protected:  
  138.   void ApplyUpdate() {}  
  139.   void SnapshotSolverState(const string& model_filename) {  
  140.     LOG(FATAL) << "Should not be called on worker solver.";  
  141.   }  
  142.   void RestoreSolverStateFromBinaryProto(const string& state_file) {  
  143.     LOG(FATAL) << "Should not be called on worker solver.";  
  144.   }  
  145.   void RestoreSolverStateFromHDF5(const string& state_file) {  
  146.     LOG(FATAL) << "Should not be called on worker solver.";  
  147.   }  
  148. };  
  149.   
  150. }  // namespace caffe  
  151.   
  152. #endif  // CAFFE_SOLVER_HPP_  

实现部分:

[cpp] view plain copy
 
    1. #include <cstdio>  
    2.   
    3. #include <string>  
    4. #include <vector>  
    5.   
    6. #include "caffe/solver.hpp"  
    7. #include "caffe/util/format.hpp"  
    8. #include "caffe/util/hdf5.hpp"  
    9. #include "caffe/util/io.hpp"  
    10. #include "caffe/util/upgrade_proto.hpp"  
    11.   
    12. namespace caffe {  
    13.   
    14. template<typename Dtype>  
    15. void Solver<Dtype>::SetActionFunction(ActionCallback func) {  
    16.   action_request_function_ = func;  
    17. }  
    18.   
    19. template<typename Dtype>  
    20. SolverAction::Enum Solver<Dtype>::GetRequestedAction() {  
    21.   if (action_request_function_) {  
    22.     // If the external request function has been set, call it.  
    23.     return action_request_function_();  
    24.   }  
    25.   return SolverAction::NONE;  
    26. }  
    27.   
    28. template <typename Dtype>  
    29. Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)  
    30.     : net_(), callbacks_(), root_solver_(root_solver),  
    31.       requested_early_exit_(false) {  
    32.   Init(param);  
    33. }  
    34. //会调用Init()方法进行初始化,即Solver scaffolding   
    35. template <typename Dtype>  
    36. Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)  
    37.     : net_(), callbacks_(), root_solver_(root_solver),  
    38.       requested_early_exit_(false) {  
    39.   SolverParameter param;  
    40.   ReadSolverParamsFromTextFileOrDie(param_file, ¶m);  
    41.   Init(param);  
    42. }  
    43. /* 
    44. 功能:初始化网络 
    45. 步骤: 
    46. 1. 设置随机数种子 
    47. 2. 申请一块Net空间以下面的构造函数进行初始化 
    48. param_file=train_net_,net_指向这块空间 
    49. 3. 如果有test_net,则申请一块Net空间,test_net_指向这块空间 
    50. 输入:SolverParameter类型的param 
    51. 输出:无 
    52. */  
    53. template <typename Dtype>  
    54. void Solver<Dtype>::Init(const SolverParameter& param) {  
    55.   CHECK(Caffe::root_solver() || root_solver_)  
    56.       << "root_solver_ needs to be set for all non-root solvers";  
    57.   LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "  
    58.     << std::endl << param.DebugString();  
    59.   param_ = param;//为solver类的数据成员param_赋值  
    60.   CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";  
    61.   CheckSnapshotWritePermissions();  
    62.   if (Caffe::root_solver() && param_.random_seed() >= 0) {  
    63.     Caffe::set_random_seed(param_.random_seed());  
    64. //调用Caffe命名空间里的set_random_seed函数,而不是caffe类的set_random_seed函数;param_.random_seed()实际  
    65. //上调用的是::google::protobuf::int64 random_seed()    
    66.   }  
    67.   // Scaffolding code  
    68.   InitTrainNet();  
    69.   if (Caffe::root_solver()) {  
    70.     InitTestNets();  
    71.     LOG(INFO) << "Solver scaffolding done.";  
    72.   }  
    73.   iter_ = 0;  
    74.   current_step_ = 0;  
    75. }  
    76.   
    77. template <typename Dtype>  
    78. void Solver<Dtype>::InitTrainNet() {  
    79.   const int num_train_nets = param_.has_net() + param_.has_net_param() +  
    80.       param_.has_train_net() + param_.has_train_net_param();  
    81.   const string& field_names = "net, net_param, train_net, train_net_param";  
    82. //只能有一个train net  
    83.   CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "  
    84.       << "using one of these fields: " << field_names;  
    85.   CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "  
    86.       << "one of these fields specifying a train_net: " << field_names;  
    87.   NetParameter net_param;  
    88.   if (param_.has_train_net_param()) {  
    89.     LOG_IF(INFO, Caffe::root_solver())  
    90.         << "Creating training net specified in train_net_param.";  
    91.     net_param.CopyFrom(param_.train_net_param());  
    92.   } else if (param_.has_train_net()) {  
    93.     LOG_IF(INFO, Caffe::root_solver())  
    94.         << "Creating training net from train_net file: " << param_.train_net();  
    95.     ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);  
    96.   }  
    97.   if (param_.has_net_param()) {  
    98.     LOG_IF(INFO, Caffe::root_solver())  
    99.         << "Creating training net specified in net_param.";  
    100.     net_param.CopyFrom(param_.net_param());  
    101.   }  
    102.   if (param_.has_net()) {  
    103.     LOG_IF(INFO, Caffe::root_solver())  
    104.         << "Creating training net from net file: " << param_.net();  
    105.     ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);  
    106.   }  
    107.   //设置正确的网络状态,训练从默认开始,然后融入通过网络层规定在任何状态,最后融入训练状态(最优解)  
    108.   NetState net_state;  
    109.   net_state.set_phase(TRAIN);  
    110. //从低到高获取state,最终从最高优先级SolverParameter类型中的train_state,显然这会覆盖掉之前获取的state。    
    111.   net_state.MergeFrom(net_param.state());  
    112. //这里获取的state可以为Netparameter中的state赋值,然后可以根据LayerParameter中的include和exclude来确定该层是否应该包含在网络中。    
    113.   net_state.MergeFrom(param_.train_state());  
    114. //这是Initialize train net 的一部分工作。InitTestNets也是如此  
    115.   net_param.mutable_state()->CopyFrom(net_state);  
    116.   if (Caffe::root_solver()) {  
    117. //调用模板类的构造函数,进行net的初始化    
    118.     net_.reset(new Net<Dtype>(net_param));  
    119.   } else {  
    120.     net_.reset(new Net<Dtype>(net_param, root_solver_->net_.get()));  
    121.   }  
    122. }  
    123. //需要注意的是TestNet可以有多个,而TrainNet只能有一个   
    124. template <typename Dtype>  
    125. void Solver<Dtype>::InitTestNets() {  
    126.   CHECK(Caffe::root_solver());  
    127.   const bool has_net_param = param_.has_net_param();  
    128.   const bool has_net_file = param_.has_net();  
    129.   const int num_generic_nets = has_net_param + has_net_file;  
    130.   CHECK_LE(num_generic_nets, 1)  
    131.       << "Both net_param and net_file may not be specified.";  
    132.   const int num_test_net_params = param_.test_net_param_size();  
    133.   const int num_test_net_files = param_.test_net_size();  
    134.   const int num_test_nets = num_test_net_params + num_test_net_files;  
    135.   if (num_generic_nets) {  
    136.       CHECK_GE(param_.test_iter_size(), num_test_nets)  
    137.           << "test_iter must be specified for each test network.";  
    138.   } else {  
    139.       CHECK_EQ(param_.test_iter_size(), num_test_nets)  
    140.           << "test_iter must be specified for each test network.";  
    141.   }  
    142. //可以有多个test net  
    143.   const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;  
    144. //num_test_net_instances由num_test_nets 和 num_generic_net_instances 组成,实际上也就是param_.test_iter_size()   
    145.   const int num_test_net_instances = num_test_nets + num_generic_net_instances;  
    146.   if (param_.test_state_size()) {  
    147.     CHECK_EQ(param_.test_state_size(), num_test_net_instances)  
    148.         << "test_state must be unspecified or specified once per test net.";  
    149.   }  
    150.   if (num_test_net_instances) {  
    151.     CHECK_GT(param_.test_interval(), 0);  
    152.   }  
    153.   int test_net_id = 0;  
    154.   vector<string> sources(num_test_net_instances);  
    155.   vector<NetParameter> net_params(num_test_net_instances);  
    156.   for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {  
    157.       sources[test_net_id] = "test_net_param";  
    158.       net_params[test_net_id].CopyFrom(param_.test_net_param(i));  
    159.   }  
    160.   for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {  
    161.       sources[test_net_id] = "test_net file: " + param_.test_net(i);  
    162.       ReadNetParamsFromTextFileOrDie(param_.test_net(i),  
    163.           &net_params[test_net_id]);  
    164.   }  
    165.   const int remaining_test_nets = param_.test_iter_size() - test_net_id;  
    166.   if (has_net_param) {  
    167.     for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {  
    168.       sources[test_net_id] = "net_param";  
    169.       net_params[test_net_id].CopyFrom(param_.net_param());  
    170.     }  
    171.   }  
    172.   if (has_net_file) {  
    173.     for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {  
    174.       sources[test_net_id] = "net file: " + param_.net();  
    175.       ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);  
    176.     }  
    177.   }  
    178.   test_nets_.resize(num_test_net_instances);  
    179.   for (int i = 0; i < num_test_net_instances; ++i) {  
    180. //设置正确的网络状态,训练从默认开始,然后融入通过网络层规定在任何状态,最后融入测试状态(最优解)  
    181.     NetState net_state;  
    182.     net_state.set_phase(TEST);  
    183.     net_state.MergeFrom(net_params[i].state());  
    184.     if (param_.test_state_size()) {  
    185.       net_state.MergeFrom(param_.test_state(i));  
    186.     }  
    187.     net_params[i].mutable_state()->CopyFrom(net_state);  
    188.     LOG(INFO)  
    189.         << "Creating test net (#" << i << ") specified by " << sources[i];  
    190.     if (Caffe::root_solver()) {  
    191.       test_nets_[i].reset(new Net<Dtype>(net_params[i]));  
    192.     } else {  
    193.       test_nets_[i].reset(new Net<Dtype>(net_params[i],  
    194.           root_solver_->test_nets_[i].get()));  
    195.     }  
    196.     test_nets_[i]->set_debug_info(param_.debug_info());  
    197.   }  
    198. }  
    199.   
    200. template <typename Dtype>  
    201. void Solver<Dtype>::Step(int iters) {  
    202.   const int start_iter = iter_;  
    203.   const int stop_iter = iter_ + iters;  
    204.   int average_loss = this->param_.average_loss();  
    205.   losses_.clear();  
    206.   smoothed_loss_ = 0;  
    207.   
    208.   while (iter_ < stop_iter) {  
    209.     // 0初始化参数  
    210.     net_->ClearParamDiffs();  
    211.     //test_initialization默认为true  
    212.     if (param_.test_interval() && iter_ % param_.test_interval() == 0  
    213.         && (iter_ > 0 || param_.test_initialization())  
    214.         && Caffe::root_solver()) {  
    215.       TestAll();  
    216.       if (requested_early_exit_) {  
    217.         // Break out of the while loop because stop was requested while testing.  
    218.         break;  
    219.       }  
    220.     }  
    221.   
    222.     for (int i = 0; i < callbacks_.size(); ++i) {  
    223.       callbacks_[i]->on_start();  
    224.     }  
    225.     const bool display = param_.display() && iter_ % param_.display() == 0;  
    226.     net_->set_debug_info(display && param_.debug_info());  
    227.     // accumulate the loss and gradient  
    228.     Dtype loss = 0;  
    229.     for (int i = 0; i < param_.iter_size(); ++i) {  
    230.       loss += net_->ForwardBackward();  
    231.     }  
    232.     loss /= param_.iter_size();//accumulate(累积) gradients over `iter_size` x `batch_size` instances。默认情况下,iter_size=1,即默认情况下,一个iteratio一个batch  
    233.     // average the loss across iterations for smoothed reporting  
    234.     UpdateSmoothedLoss(loss, start_iter, average_loss);  
    235.     // average_loss [default = 1]——> Display the loss averaged over the last average_loss iterations    
    236.     if (display) {  
    237.       LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_  
    238.           << ", loss = " << smoothed_loss_;  
    239.       const vector<Blob<Dtype>*>& result = net_->output_blobs();  
    240.       int score_index = 0;  
    241.       for (int j = 0; j < result.size(); ++j) {  
    242.         const Dtype* result_vec = result[j]->cpu_data();  
    243.         const string& output_name =  
    244.             net_->blob_names()[net_->output_blob_indices()[j]];  
    245.         const Dtype loss_weight =  
    246.             net_->blob_loss_weights()[net_->output_blob_indices()[j]];  
    247.         for (int k = 0; k < result[j]->count(); ++k) {  
    248.           ostringstream loss_msg_stream;  
    249.           if (loss_weight) {  
    250.             loss_msg_stream << " (* " << loss_weight  
    251.                             << " = " << loss_weight * result_vec[k] << " loss)";  
    252.           }  
    253.           LOG_IF(INFO, Caffe::root_solver()) << "    Train net output #"  
    254.               << score_index++ << ": " << output_name << " = "  
    255.               << result_vec[k] << loss_msg_stream.str();  
    256.         }  
    257.       }  
    258.     }  
    259.     for (int i = 0; i < callbacks_.size(); ++i) {  
    260.       callbacks_[i]->on_gradients_ready();  
    261.     }  
    262.     ApplyUpdate();  
    263.   
    264.     // Increment the internal iter_ counter -- its value should always indicate  
    265.     // the number of times the weights have been updated.  
    266.     ++iter_;  
    267.   
    268.     SolverAction::Enum request = GetRequestedAction();  
    269.   
    270.     // Save a snapshot if needed.  
    271.     if ((param_.snapshot()  
    272.          && iter_ % param_.snapshot() == 0  
    273.          && Caffe::root_solver()) ||  
    274.          (request == SolverAction::SNAPSHOT)) {  
    275.       Snapshot();  
    276.     }  
    277.     if (SolverAction::STOP == request) {  
    278.       requested_early_exit_ = true;  
    279.       // Break out of training loop.  
    280.       break;  
    281.     }  
    282.   }  
    283. }  
    284. /* 
    285. 对整个网络进行训练(也就是你运行Caffe训练某个模型)的时候,实际上是在运行caffe.cpp中的train()函数,而这个函数实际上是实例化一个Solver对象,初始化后调用了Solver中的Solve()方法   
    286. 调用此方法训练网络,其中会调用Step()方法来迭代,迭代 param_.max_iter() - iter_ 次 
    287. */  
    288. template <typename Dtype>  
    289. void Solver<Dtype>::Solve(const char* resume_file) {  
    290.   CHECK(Caffe::root_solver());  
    291.   LOG(INFO) << "Solving " << net_->name();  
    292.   LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();  
    293.   
    294.   //任何时候开始求解,初始化失败  
    295.   requested_early_exit_ = false;  
    296.   
    297.   if (resume_file) {  
    298.     LOG(INFO) << "Restoring previous solver status from " << resume_file;  
    299.     Restore(resume_file);  
    300.   }  
    301.   
    302.   // For a network that is trained by the solver, no bottom or top vecs  
    303.   // should be given, and we will just provide dummy vecs.  
    304.   //对于一个正在训练的网络,没有bottom或top向量被给,而且仅仅提供dummy vecs  
    305.   int start_iter = iter_;  
    306.   Step(param_.max_iter() - iter_);  
    307.   // If we haven't already, save a snapshot after optimization, unless  
    308.   // overridden by setting snapshot_after_train := false  
    309.   if (param_.snapshot_after_train()  
    310.       && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {  
    311.     Snapshot();  
    312.   }  
    313.   if (requested_early_exit_) {  
    314.     LOG(INFO) << "Optimization stopped early.";  
    315.     return;  
    316.   }  
    317.   //在优化完后,运行一个额外的训练和测试过程展示训练测试的loss或者输出。  
    318.   if (param_.display() && iter_ % param_.display() == 0) {  
    319.     int average_loss = this->param_.average_loss();  
    320.     Dtype loss;  
    321.     net_->Forward(&loss);  
    322.   
    323.     UpdateSmoothedLoss(loss, start_iter, average_loss);  
    324.   
    325.     LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;  
    326.   }  
    327.   if (param_.test_interval() && iter_ % param_.test_interval() == 0) {  
    328.     TestAll();  
    329.   }  
    330.   LOG(INFO) << "Optimization Done.";  
    331. }  
    332.   
    333. template <typename Dtype>  
    334. void Solver<Dtype>::TestAll() {  
    335.   for (int test_net_id = 0;  
    336.        test_net_id < test_nets_.size() && !requested_early_exit_;  
    337.        ++test_net_id) {  
    338.     Test(test_net_id);  
    339.   }  
    340. }  
    341.   
    342. template <typename Dtype>  
    343. void Solver<Dtype>::Test(const int test_net_id) {  
    344.   CHECK(Caffe::root_solver());  
    345.   LOG(INFO) << "Iteration " << iter_  
    346.             << ", Testing net (#" << test_net_id << ")";  
    347.   //检查是否有layer共享于多个网络  
    348.   CHECK_NOTNULL(test_nets_[test_net_id].get())->  
    349.       ShareTrainedLayersWith(net_.get());  
    350.   vector<Dtype> test_score;  
    351.   vector<int> test_score_output_id;  
    352.   const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];  
    353.   Dtype loss = 0;  
    354.   for (int i = 0; i < param_.test_iter(test_net_id); ++i) {  
    355.     SolverAction::Enum request = GetRequestedAction();  
    356.    
    357.     //如果在训练或测试中断请求发出后,随时执行保存快照  
    358.     while (request != SolverAction::NONE) {  
    359.         if (SolverAction::SNAPSHOT == request) {  
    360.           Snapshot();  
    361.         } else if (SolverAction::STOP == request) {  
    362.           requested_early_exit_ = true;  
    363.         }  
    364.         request = GetRequestedAction();  
    365.     }  
    366.     if (requested_early_exit_) {  
    367.       // break out of test loop.  
    368.       break;  
    369.     }  
    370.   
    371.     Dtype iter_loss;  
    372.     const vector<Blob<Dtype>*>& result =  
    373.         test_net->Forward(&iter_loss);  
    374.     if (param_.test_compute_loss()) {  
    375.       loss += iter_loss;  
    376.     }  
    377.     if (i == 0) {  
    378.       for (int j = 0; j < result.size(); ++j) {  
    379.         const Dtype* result_vec = result[j]->cpu_data();  
    380.         for (int k = 0; k < result[j]->count(); ++k) {  
    381.           test_score.push_back(result_vec[k]);  
    382.           test_score_output_id.push_back(j);  
    383.         }  
    384.       }  
    385.     } else {  
    386.       int idx = 0;  
    387.       for (int j = 0; j < result.size(); ++j) {  
    388.         const Dtype* result_vec = result[j]->cpu_data();  
    389.         for (int k = 0; k < result[j]->count(); ++k) {  
    390.           test_score[idx++] += result_vec[k];  
    391.         }  
    392.       }  
    393.     }  
    394.   }  
    395.   if (requested_early_exit_) {  
    396.     LOG(INFO)     << "Test interrupted.";  
    397.     return;  
    398.   }  
    399.   if (param_.test_compute_loss()) {  
    400.     loss /= param_.test_iter(test_net_id);  
    401.     LOG(INFO) << "Test loss: " << loss;  
    402.   }  
    403.   for (int i = 0; i < test_score.size(); ++i) {  
    404.     const int output_blob_index =  
    405.         test_net->output_blob_indices()[test_score_output_id[i]];  
    406.     const string& output_name = test_net->blob_names()[output_blob_index];  
    407.     const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];  
    408.     ostringstream loss_msg_stream;  
    409.     const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);//求多次迭代Loss的平均值,也就是求多个batch的平局值,因为一次迭代用的是一个test batch-size 的图片  
    410.     if (loss_weight) {  
    411.       loss_msg_stream << " (* " << loss_weight  
    412.                       << " = " << loss_weight * mean_score << " loss)";  
    413.     }  
    414.     LOG(INFO) << "    Test net output #" << i << ": " << output_name << " = "  
    415.               << mean_score << loss_msg_stream.str();  
    416.   }  
    417. }  
    418. //输出当前网络状态到一个文件中。   
    419. template <typename Dtype>  
    420. void Solver<Dtype>::Snapshot() {  
    421.   CHECK(Caffe::root_solver());  
    422.   string model_filename;  
    423.   switch (param_.snapshot_format()) {  
    424.   case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:  
    425.     model_filename = SnapshotToBinaryProto();  
    426.     break;  
    427.   case caffe::SolverParameter_SnapshotFormat_HDF5:  
    428.     model_filename = SnapshotToHDF5();  
    429.     break;  
    430.   default:  
    431.     LOG(FATAL) << "Unsupported snapshot format.";  
    432.   }  
    433.   
    434.   SnapshotSolverState(model_filename);  
    435. }  
    436. //check快照的写入权限  
    437. template <typename Dtype>  
    438. void Solver<Dtype>::CheckSnapshotWritePermissions() {  
    439.   if (Caffe::root_solver() && param_.snapshot()) {  
    440.     CHECK(param_.has_snapshot_prefix())  
    441.         << "In solver params, snapshot is specified but snapshot_prefix is not";  
    442.     string probe_filename = SnapshotFilename(".tempfile");  
    443.     std::ofstream probe_ofs(probe_filename.c_str());  
    444.     if (probe_ofs.good()) {  
    445.       probe_ofs.close();  
    446.       std::remove(probe_filename.c_str());  
    447.     } else {  
    448.       LOG(FATAL) << "Cannot write to snapshot prefix '"  
    449.           << param_.snapshot_prefix() << "'.  Make sure "  
    450.           << "that the directory exists and is writeable.";  
    451.     }  
    452.   }  
    453. }  
    454. //Snapshot的名字  
    455. template <typename Dtype>  
    456. string Solver<Dtype>::SnapshotFilename(const string extension) {  
    457.   return param_.snapshot_prefix() + "_iter_" + caffe::format_int(iter_)  
    458.     + extension;  
    459. }  
    460. //Snapshot保存为二进制proto的模型  
    461. template <typename Dtype>  
    462. string Solver<Dtype>::SnapshotToBinaryProto() {  
    463.   string model_filename = SnapshotFilename(".caffemodel");  
    464.   LOG(INFO) << "Snapshotting to binary proto file " << model_filename;  
    465.   NetParameter net_param;  
    466.   net_->ToProto(&net_param, param_.snapshot_diff());  
    467.   WriteProtoToBinaryFile(net_param, model_filename);  
    468.   return model_filename;  
    469. }  
    470. //Snapshot保存为HDF5模型  
    471. template <typename Dtype>  
    472. string Solver<Dtype>::SnapshotToHDF5() {  
    473.   string model_filename = SnapshotFilename(".caffemodel.h5");  
    474.   LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;  
    475.   net_->ToHDF5(model_filename, param_.snapshot_diff());  
    476.   return model_filename;  
    477. }  
    478. //从一个文件中读入网络状态,并可以从那个状态恢复。   
    479. template <typename Dtype>  
    480. void Solver<Dtype>::Restore(const char* state_file) {  
    481.   CHECK(Caffe::root_solver());  
    482.   string state_filename(state_file);  
    483.   if (state_filename.size() >= 3 &&  
    484.       state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {  
    485.     RestoreSolverStateFromHDF5(state_filename);  
    486.   } else {  
    487.     RestoreSolverStateFromBinaryProto(state_filename);  
    488.   }  
    489. }  
    490. //迭代时平均loss的smooth报告,翻译不是很准     
    491. template <typename Dtype>  
    492. void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,  
    493.     int average_loss) {  
    494.   if (losses_.size() < average_loss) {  
    495.     losses_.push_back(loss);  
    496.     int size = losses_.size();  
    497.     smoothed_loss_ = (smoothed_loss_ * (size - 1) + loss) / size;  
    498.   } else {  
    499.     int idx = (iter_ - start_iter) % average_loss;  
    500.     smoothed_loss_ += (loss - losses_[idx]) / average_loss;  
    501.     losses_[idx] = loss;  
    502.   }  
    503. }  
    504.   
    505. INSTANTIATE_CLASS(Solver);  
    506.   
    507. }  // namespace caffe  
posted @ 2017-07-11 16:00  菜鸡一枚  阅读(419)  评论(0编辑  收藏  举报