**"mindspore\lite\examples\transfer_learning\src\net_runner.cc" 注释2**
一、代码用处
这段代码实现了数据集的初始化、输入数据的填充与准确率的计算、损失值(loss)的获取、训练循环以及命令行参数的解析,最后由 main 函数启动运行。
二、代码注释
// Computes top-1 classification accuracy of the current model over `dataset`.
//
// The dataset is consumed in whole batches of batch_size_ samples; a trailing partial batch is not evaluated.
// The session is put into evaluation mode while the batches run and is switched back to training mode before
// the function returns.
//
// Returns the fraction of evaluated samples whose highest-scoring class index equals the sample's label, or
// 0.0 when the dataset does not contain a single full batch.
float NetRunner::CalculateAccuracy(const std::vector<DataLabelTuple> &dataset) const {
  const int tests = static_cast<int>(dataset.size() / batch_size_);
  int correct = 0;

  session_->Eval();
  for (int i = 0; i < tests; i++) {
    // FillInputData loads batch `i` into the model inputs and hands back the labels of that batch.
    auto labels = FillInputData(dataset, i);
    session_->RunGraph();
    // NOTE(review): presumably selects the output tensor holding one score per (sample, class) - confirm.
    auto outputsv = SearchOutputsForSize(batch_size_ * num_of_classes_);
    MS_ASSERT(outputsv != nullptr);
    auto scores = reinterpret_cast<float *>(outputsv->MutableData());
    for (int b = 0; b < batch_size_; b++) {
      // Scores of sample `b` occupy num_of_classes_ consecutive floats.
      const float *row = scores + num_of_classes_ * b;
      int max_idx = 0;
      float max_score = row[0];
      // Class 0 already seeds the maximum, so the scan starts at class 1.
      for (int c = 1; c < num_of_classes_; c++) {
        if (row[c] > max_score) {
          max_score = row[c];
          max_idx = c;
        }
      }
      if (labels[b] == max_idx) {
        ++correct;
      }
    }
  }
  session_->Train();

  // With no full batch nothing was evaluated; dividing would yield 0/0 (NaN), so report 0 instead.
  if (tests <= 0) {
    return 0.0f;
  }
  return static_cast<float>(correct) / static_cast<float>(batch_size_ * tests);
}
// Initializes the dataset member from data_dir_ and caches the number of classes it reports.
//
// When data_size_ is non-zero it is passed to the dataset as the expected data size before initialization.
// In verbose mode the train/test/val split sizes are printed. If the test split is empty a diagnostic is
// printed and MS_ASSERT is triggered.
//
// Returns whatever the dataset's Init call returned.
int NetRunner::InitDB() {
  if (data_size_ != 0) {
    ds_.set_expected_data_size(data_size_);
  }

  const int status = ds_.Init(data_dir_, DS_OTHER);
  num_of_classes_ = ds_.num_of_classes();

  if (verbose_) {
    std::cout << "dataset train/test/val size is:" << ds_.train_data().size() << "/" << ds_.test_data().size() << "/"
              << ds_.val_data().size() << std::endl;
  }

  const bool has_test_data = (ds_.test_data().size() != 0);
  if (!has_test_data) {
    std::cout << "No relevant data was found in " << data_dir_ << std::endl;
    MS_ASSERT(ds_.test_data().size() != 0);
  }

  return status;
}
// Returns the first float stored in the output found by SearchOutputsForSize(1).
// NOTE(review): presumably that single-element output is the scalar training loss - confirm against the model.
float NetRunner::GetLoss() const {
  auto outputsv = SearchOutputsForSize(1);
  MS_ASSERT(outputsv != nullptr);
  const float *loss_value = reinterpret_cast<float *>(outputsv->MutableData());
  return *loss_value;
}
// Runs up to cycles_ training iterations on the training split.
//
// Each iteration fills the inputs from the training data, runs the graph (passing after_callback only in
// verbose mode), and prints the current loss together with the lowest loss seen so far. When save_checkpoint_
// is non-zero, the head model is exported every save_checkpoint_ iterations to
// "<head file without extension>_trained_<iteration>.ms". Every 20th iteration the accuracy on the test
// split is measured and printed, and training stops early once it exceeds 0.9.
//
// Always returns 0.
int NetRunner::TrainLoop() {
  session_->Train();

  float best_loss = 1000.;
  float best_acc = 0.;

  for (int i = 0; i < cycles_; i++) {
    const int step = i + 1;  // 1-based iteration number used for reporting and scheduling

    FillInputData(ds_.train_data());
    session_->RunGraph(nullptr, verbose_ ? after_callback : nullptr);

    const float loss = GetLoss();
    if (loss < best_loss) {
      best_loss = loss;
    }

    if (save_checkpoint_ != 0 && step % save_checkpoint_ == 0) {
      // Strip the extension from the head model path and tag the file with the iteration number.
      const auto stem = ms_head_file_.substr(0, ms_head_file_.find_last_of('.'));
      const auto cpkt_fn = stem + "_trained_" + std::to_string(step) + ".ms";
      mindspore::lite::Model::Export(head_model_, cpkt_fn.c_str());
    }

    std::cout << step << ": Loss is " << loss << " [min=" << best_loss << "]" << std::endl;

    if (step % 20 != 0) {
      continue;
    }

    const float acc = CalculateAccuracy(ds_.test_data());
    if (acc > best_acc) {
      best_acc = acc;
    }
    std::cout << "accuracy on test data = " << acc << " max accuracy = " << best_acc << std::endl;
    if (acc > 0.9) {
      return 0;  // good enough: stop training early
    }
  }

  return 0;
}
// Drives a complete run: prepares the inputs, initializes the dataset, trains, prints the accuracy on the
// validation split, and - when at least one training cycle was requested and a head model exists - exports
// the head model to "<head file without extension>_trained.ms".
//
// Always returns 0.
int NetRunner::Main() {
  InitAndFigureInputs();
  InitDB();
  TrainLoop();

  const float val_acc = CalculateAccuracy(ds_.val_data());
  std::cout << "accuracy on validation data = " << val_acc << std::endl;

  const bool should_export = (cycles_ > 0) && (head_model_ != nullptr);
  if (!should_export) {
    return 0;
  }

  const auto stem = ms_head_file_.substr(0, ms_head_file_.find_last_of('.'));
  const auto trained_fn = stem + "_trained.ms";
  mindspore::lite::Model::Export(head_model_, trained_fn.c_str());
  return 0;
}
// Prints the command-line synopsis of the tool to standard output.
void NetRunner::Usage() {
  static const char kUsageText[] =
    "Usage: net_runner -f <.ms head model file> -b <.ms backbone model file> -d <data_dir> "
    "[-c <num of training cycles>] [-v (verbose mode)] "
    "[-s <save checkpoint every X iterations>]";
  std::cout << kUsageText << std::endl;
}
// Parses the command-line options into the runner's configuration members.
//
// Recognized options:
//   -b <file>  backbone model file          -> ms_backbone_file_
//   -f <file>  head model file              -> ms_head_file_
//   -c <n>     number of training cycles    -> cycles_  (-e <n> is accepted as an equivalent spelling)
//   -d <dir>   dataset directory            -> data_dir_
//   -s <n>     checkpoint interval          -> save_checkpoint_
//   -v         verbose mode                 -> verbose_
//   -h         print usage
//
// Returns true when every option was consumed successfully. Returns false after printing the usage text
// when -h is given or when getopt reports an option that is not handled here (including -i, which the
// option string accepts but which has no handler).
bool NetRunner::ReadArgs(int argc, char *argv[]) {
  int opt;
  while ((opt = getopt(argc, argv, "b:f:e:d:s:ihc:v")) != -1) {
    switch (opt) {
      case 'b':
        ms_backbone_file_ = std::string(optarg);
        break;
      case 'f':
        ms_head_file_ = std::string(optarg);
        break;
      // The usage text documents -c for the cycle count; previously only -e was handled, so the documented
      // flag fell through to the usage/abort path. Both spellings now set the cycle count.
      case 'c':
      case 'e':
        cycles_ = atoi(optarg);
        break;
      case 'd':
        data_dir_ = std::string(optarg);
        break;
      case 'v':
        verbose_ = true;
        break;
      case 's':
        save_checkpoint_ = atoi(optarg);
        break;
      case 'h':
      default:
        Usage();
        return false;
    }
  }
  return true;
}
// Program entry point: parses the command line into a NetRunner and runs it.
// Returns -1 when argument parsing fails, otherwise the value returned by NetRunner::Main().
int main(int argc, char **argv) {
  NetRunner nr;
  // Check whether the arguments were read successfully before running.
  if (!nr.ReadArgs(argc, argv)) {
    return -1;
  }
  return nr.Main();
}