Code annotations for AnfExporter
// Sets the output tensor indices of a (sub)graph from its Return node.
int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const size_t subgraph_index,
const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *return_node) {
MS_ASSERT(nullptr != meta_graphT);
MS_ASSERT(nullptr != return_node);  // assert the preconditions; execution continues only if the pointers are non-null
for (size_t i = 1; i < cnode->inputs().size(); i++) {  // iterate over the return node's inputs (input 0 is the primitive itself)
auto input_node = cnode->input(i);
if (input_node == nullptr) {  // validate the input before using it
MS_LOG(ERROR) << "output node is nullptr";  // report the error and return
return RET_NULL_PTR;
} else if (input_node->isa<CNode>()) {
auto ret = ConvertInputCNode(input_node, return_node);  // convert the CNode input into output indices on return_node
if (ret != RET_OK) {  // check whether the conversion succeeded
MS_LOG(ERROR) << "obtain outputs failed";
return ret;
}
} else if (input_node->isa<Parameter>()) {  // parameter inputs are logged and skipped
MS_LOG(INFO) << "the node " << input_node->fullname_with_scope().c_str() << "is parameter node";
continue;
} else {  // any other node type is not a valid output node, report an error
MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node";
return RET_ERROR;
}
}
for (unsigned int &i : return_node->inputIndex) {  // iterate over the collected input tensor indices
if (subgraph_index == kMainGraphIndex) {  // only the main graph records global output indices
auto &tensor = meta_graphT->allTensors.at(i);
TensorDataType::GetInstance()->UpdateGraphOutputDType(meta_graphT->outputIndex.size(), tensor->dataType);  // update the recorded data type of this graph output
meta_graphT->outputIndex.push_back(i);  // append i to the global output indices
}
meta_graphT->subGraph.at(subgraph_index)->outputIndices.push_back(i);
}
return RET_OK;
}
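
The index bookkeeping in the second loop above can be illustrated with a minimal standalone sketch. MiniGraph, MiniSubGraph and RecordOutput are made-up names used only here (and kMainGraphIndex is assumed to be 0); they are not MindSpore types. Only the main graph contributes to the global output list, while every subgraph records its own outputIndices:

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical, simplified stand-ins for schema::MetaGraphT / schema::SubGraphT.
struct MiniSubGraph {
  std::vector<unsigned int> outputIndices;
};
struct MiniGraph {
  std::vector<unsigned int> outputIndex;  // global outputs (main graph only)
  std::vector<MiniSubGraph> subGraph;
};

constexpr std::size_t kMainGraphIndex = 0;  // assumed value for this sketch

// Mirrors the second loop of SetGraphoutputIndex: record a tensor index as a
// subgraph output, and also globally when we are in the main graph.
void RecordOutput(MiniGraph *g, std::size_t subgraph_index, unsigned int tensor_idx) {
  if (subgraph_index == kMainGraphIndex) {
    g->outputIndex.push_back(tensor_idx);
  }
  g->subGraph.at(subgraph_index).outputIndices.push_back(tensor_idx);
}

int main() {
  MiniGraph g;
  g.subGraph.resize(2);
  RecordOutput(&g, 0, 3);  // main-graph output: recorded globally and locally
  RecordOutput(&g, 1, 7);  // subgraph output: recorded locally only
  std::printf("global outputs: %zu, subgraph 1 outputs: %zu\n", g.outputIndex.size(), g.subGraph.at(1).outputIndices.size());
  return 0;
}
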
// Returns true if func_graph has already been exported as a subgraph.
bool AnfExporter::HasExported(const FuncGraphPtr &func_graph) {
if (fg_subgraph_map_.find(func_graph) != fg_subgraph_map_.end()) {  // func_graph is already present in the subgraph map, so return true
return true;
}
return false;
}
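
HasExported is just a map lookup. The following sketch shows the same pattern in isolation, assuming fg_subgraph_map_ is a std::map keyed by the graph pointer; MiniExporter, MarkExported and the FuncGraph placeholder are invented for this sketch only:

#include <cstddef>
#include <map>
#include <memory>

struct FuncGraph {};  // placeholder, not the real mindspore::FuncGraph
using FuncGraphPtr = std::shared_ptr<FuncGraph>;

class MiniExporter {
 public:
  // Same check as AnfExporter::HasExported: a graph counts as exported once it has a recorded subgraph index.
  bool HasExported(const FuncGraphPtr &func_graph) const {
    return fg_subgraph_map_.find(func_graph) != fg_subgraph_map_.end();
  }
  void MarkExported(const FuncGraphPtr &func_graph, std::size_t subgraph_index) {
    fg_subgraph_map_[func_graph] = subgraph_index;
  }

 private:
  std::map<FuncGraphPtr, std::size_t> fg_subgraph_map_;
};

int main() {
  MiniExporter exporter;
  auto fg = std::make_shared<FuncGraph>();
  if (!exporter.HasExported(fg)) {
    exporter.MarkExported(fg, 1);  // e.g. exported as subgraph #1
  }
  return exporter.HasExported(fg) ? 0 : 1;  // returns 0: the graph is now marked as exported
}
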
int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
const size_t &subgraph_index, const bool &keep_graph, const bool &copy_primitive) {
int ret = RET_OK;  // initialize the return status
auto cnodes = GetOrderedCNodes(func_graph);  // get the CNodes of the function graph in order
for (const auto &cnode : cnodes) {  // process every CNode
auto prim = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));  // fetch the Primitive held in input 0
std::unique_ptr<schema::PrimitiveT> primT;
if (prim == nullptr) {  // input 0 is not a Primitive
auto fg = GetValueNode<FuncGraphPtr>(cnode->input(0));  // try to read it as a FuncGraph instead
if (fg != nullptr) {
auto partial_cnode = CreatePartialCnode(fg, cnode);  // wrap the call to the sub-graph in a Partial CNode
prim = GetValueNode<std::shared_ptr<Primitive>>(partial_cnode->input(0));
primT = GetPrimitiveT(partial_cnode->input(0));  // fetch the corresponding schema primitive
MS_ASSERT(primT != nullptr);
auto pos = fg_subgraph_map_.find(fg);  // look up fg in the exported-subgraph map
if (pos != fg_subgraph_map_.end()) {  // fg has already been exported, so reuse its subgraph index
MS_ASSERT(primT->value.AsPartialFusion() != nullptr);
primT->value.AsPartialFusion()->sub_graph_index = fg_subgraph_map_.at(fg);
} else {
size_t next_subgraph_index = meta_graphT->subGraph.size();
MS_ASSERT(primT->value.AsPartialFusion() != nullptr);
primT->value.AsPartialFusion()->sub_graph_index = next_subgraph_index;
ret = ExportSubgraph(fg, meta_graphT, keep_graph, copy_primitive, cnode);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ExportSubgraph failed";
return ret;
}
}
} else {
MS_LOG(ERROR) << "primitive_c is nullptr";
ret = RET_MEMORY_FAILED;
break;
}
}
RemoveIfDepend(cnode);  // strip Depend wrappers from the node's inputs
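// Graph-structure pseudo ops (Depend, TupleGetItem, MakeTuple) are skipped and never emitted as standalone nodes.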
if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::lite::kNameTupleGetItem ||
prim->name() == mindspore::lite::kNameMakeTuple) {
continue;
}
if (prim->name() == "make_tuple") {
continue;
}
RemoveIfMakeTuple(cnode);
auto node = std::make_unique<schema::CNodeT>();  // create a new schema::CNodeT
if (node == nullptr) {  // check that construction succeeded
MS_LOG(ERROR) << "object failed to be constructed";
ret = RET_MEMORY_FAILED;
break;
}
if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {  // the Return node marks the graph outputs
node->name = mindspore::lite::kNameReturn;  // set the node name
ret = SetGraphoutputIndex(cnode, subgraph_index, meta_graphT, node.get());  // record the graph output indices
if (ret != RET_OK) {  // bail out on failure
MS_LOG(ERROR) << "SetOpOutputN failed";
break;
}
continue;
}
if (primT == nullptr) {
primT = GetPrimitiveT(cnode->input(0));
}
node->name = cnode->fullname_with_scope();
node->primitive = std::move(primT);
// std::move does not move anything by itself; it merely casts an lvalue to an rvalue reference so the value can then be moved from
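// Read the optional kDeviceType attribute; default to -1 when the node does not carry one.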
auto device_type_attr = cnode->GetAttr(mindspore::ops::kDeviceType);
node->deviceType = (device_type_attr != nullptr) ? GetValue<int32_t>(device_type_attr) : -1;
ret = SetOpInputNode(cnode, meta_graphT, node.get());  // populate the node's input tensors
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetOpInputNode failed";
break;
}
SetOpOutputNode(cnode, meta_graphT, node.get());
ret = ConvertQuantParam(meta_graphT, prim, node);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvertQuantParam failed";
break;
}
auto status = SetPostTrainOutputTensorType(meta_graphT, prim, node);  // set the output tensor data type for post-training quantization
if (status != RET_OK) {
MS_LOG(ERROR) << "set quant output tensor data type failed.";
break;
}
meta_graphT->nodes.push_back(std::move(node));  // append the node to meta_graphT
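// node_idx_ counts exported nodes across the whole MetaGraphT; record this node's index in the current subgraph's node list.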
meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx_++);
}
return ret;
}
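
The PartialFusion branch above either reuses a previously assigned subgraph index or hands out the next free slot and exports the sub-graph. Below is a reduced sketch of that control flow, with invented names (GetOrExportSubgraph, MiniExporter, and the MetaGraphT/SubGraphT/FuncGraph stand-ins) and a placeholder where the real ExportSubgraph call would run:

#include <cstddef>
#include <map>
#include <memory>
#include <vector>

struct FuncGraph {};  // placeholder types for this sketch only
using FuncGraphPtr = std::shared_ptr<FuncGraph>;
struct SubGraphT {};
struct MetaGraphT {
  std::vector<std::unique_ptr<SubGraphT>> subGraph;
};

class MiniExporter {
 public:
  // Mirrors the else-branch of Anf2Fb: reuse the recorded index if the graph was already
  // exported, otherwise the next free slot in subGraph becomes its subgraph index.
  std::size_t GetOrExportSubgraph(const FuncGraphPtr &fg, MetaGraphT *meta_graph) {
    auto pos = fg_subgraph_map_.find(fg);
    if (pos != fg_subgraph_map_.end()) {
      return pos->second;  // already exported: reuse the index
    }
    std::size_t next_subgraph_index = meta_graph->subGraph.size();
    fg_subgraph_map_[fg] = next_subgraph_index;
    meta_graph->subGraph.push_back(std::make_unique<SubGraphT>());  // placeholder for the real ExportSubgraph call
    return next_subgraph_index;
  }

 private:
  std::map<FuncGraphPtr, std::size_t> fg_subgraph_map_;
};

int main() {
  MiniExporter exporter;
  MetaGraphT meta_graph;
  auto fg = std::make_shared<FuncGraph>();
  std::size_t first = exporter.GetOrExportSubgraph(fg, &meta_graph);   // exported here, gets index 0
  std::size_t second = exporter.GetOrExportSubgraph(fg, &meta_graph);  // reuses index 0
  return (first == second) ? 0 : 1;
}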