** "mindspore\lite\examples\transfer_learning\src\net_runner.cc"注释2**

一、代码用处

这段代码对数据库初始化,以及精确度的计算还有缺失值的fill以及处理,最后是main方法来运行

二、代码注释

float NetRunner::CalculateAccuracy(const std::vector<DataLabelTuple> &dataset) const {//计算准确性
  float accuracy = 0.0;//初始化准确性
  int tests = dataset.size() / batch_size_;

  session_->Eval();
  for (int i = 0; i < tests; i++) {//对数据集里的数据进行遍历
    auto labels = FillInputData(dataset, i);
    session_->RunGraph();
    auto outputsv = SearchOutputsForSize(batch_size_ * num_of_classes_);
    MS_ASSERT(outputsv != nullptr);
    auto scores = reinterpret_cast<float *>(outputsv->MutableData());
    for (int b = 0; b < batch_size_; b++) {
      int max_idx = 0;
      float max_score = scores[num_of_classes_ * b]; 
      for (int c = 0; c < num_of_classes_; c++) {
        if (scores[num_of_classes_ * b + c] > max_score) {
          max_score = scores[num_of_classes_ * b + c];//遍历得到最大的值
          max_idx = c;
        }
      }
      if (labels[b] == max_idx) accuracy += 1.0;//每当数据标签等于最大值,精确的加1
    }
  }
  session_->Train();
  accuracy /= static_cast<float>(batch_size_ * tests);//转换类型
  return accuracy;//返回精确度
}

int NetRunner::InitDB() {//初始化数据库函数
  if (data_size_ != 0) ds_.set_expected_data_size(data_size_);//导入数据
  int ret = ds_.Init(data_dir_, DS_OTHER);//将ret初始化
  num_of_classes_ = ds_.num_of_classes();
  if (verbose_) {
    std::cout << "dataset train/test/val size is:" << ds_.train_data().size() << "/" << ds_.test_data().size() << "/"
              << ds_.val_data().size() << std::endl;//分别输出数据库的 train/test/val size
  }

  if (ds_.test_data().size() == 0) {//判断是否有相应数据
    std::cout << "No relevant data was found in " << data_dir_ << std::endl;
    MS_ASSERT(ds_.test_data().size() != 0);//判断数据是否初始化成功
  }

  return ret;//返回函数ret
}

float NetRunner::GetLoss() const {//获取损失值函数
  auto outputsv = SearchOutputsForSize(1);  //搜索是单个的损失值
  MS_ASSERT(outputsv != nullptr);
  auto loss = reinterpret_cast<float *>(outputsv->MutableData());//转换类型
  return loss[0];//返回损失的第一个值
}

int NetRunner::TrainLoop() {
  session_->Train();//Train函数
  float min_loss = 1000.;
  float max_acc = 0.;
  for (int i = 0; i < cycles_; i++) {
    FillInputData(ds_.train_data());//填充输入数据
    session_->RunGraph(nullptr, verbose_ ? after_callback : nullptr);
    float loss = GetLoss();
    if (min_loss > loss) min_loss = loss;//判断loss的最小值是否大于loss,否则交换

    if (save_checkpoint_ != 0 && (i + 1) % save_checkpoint_ == 0) {
      auto cpkt_fn =
        ms_head_file_.substr(0, ms_head_file_.find_last_of('.')) + "_trained_" + std::to_string(i + 1) + ".ms";
      mindspore::lite::Model::Export(head_model_, cpkt_fn.c_str());
    }

    std::cout << i + 1 << ": Loss is " << loss << " [min=" << min_loss << "]" << std::endl;//输出缺失值
    if ((i + 1) % 20 == 0) {
      float acc = CalculateAccuracy(ds_.test_data());//计算精确度
      if (max_acc < acc) max_acc = acc;
      std::cout << "accuracy on test data = " << acc << " max accuracy = " << max_acc << std::endl;//输出精确度
      if (acc > 0.9) return 0;//精确度为0.9的时候就返回0
    }
  }
  return 0;
}

int NetRunner::Main() {//主函数,用来执行数据库运行操作
  InitAndFigureInputs();

  InitDB();

  TrainLoop();

  float acc = CalculateAccuracy(ds_.val_data());//输出 validation data精确度
  std::cout << "accuracy on validation data = " << acc << std::endl;

  if (cycles_ > 0 && head_model_ != nullptr) {
    auto trained_fn = ms_head_file_.substr(0, ms_head_file_.find_last_of('.')) + "_trained.ms";
    mindspore::lite::Model::Export(head_model_, trained_fn.c_str());
  }
  return 0;
}

void NetRunner::Usage() {
  std::cout << "Usage: net_runner -f <.ms head model file> -b <.ms backbone model file> -d <data_dir> "
            << "[-c <num of training cycles>] [-v (verbose mode)] "
            << "[-s <save checkpoint every X iterations>]" << std::endl;
}

bool NetRunner::ReadArgs(int argc, char *argv[]) {//读取参数
  int opt;
  while ((opt = getopt(argc, argv, "b:f:e:d:s:ihc:v")) != -1) {//getopt()用来分析命令行参数。参数argc和argv分别代表参数个数和内容,跟main()函数的命令行参数是一样的。参数 optstring为选项字符串, 告知 getopt()可以处理哪个选项以及哪个选项需要参数
    switch (opt) {
      case 'b':
        ms_backbone_file_ = std::string(optarg);//不同情况分别赋值
        break;
      case 'f':
        ms_head_file_ = std::string(optarg);
        break;
      case 'e':
        cycles_ = atoi(optarg);
        break;
      case 'd':
        data_dir_ = std::string(optarg);
        break;
      case 'v':
        verbose_ = true;
        break;
      case 's':
        save_checkpoint_ = atoi(optarg);
        break;
      case 'h':
      default://终结循环
        Usage();
        return false;
    }
  }
  return true;//返回true值
}

int main(int argc, char **argv) {//主函数
  NetRunner nr;

  if (nr.ReadArgs(argc, argv)) {看读取参数是否成功
    nr.Main();//运行主函数
  } else {
    return -1;
  }
  return 0;
}