// Code comments
// Convert quantization parameters
// Copies the quantization parameters held by `primitive` (attribute "quant_params") onto the tensors that
// `dst_node` references, and sets `dst_node->quantType` from the holder.
//
// @param meta_graph  graph owning the tensors addressed by dst_node->inputIndex / dst_node->outputIndex.
// @param primitive   primitive whose "quant_params" attribute is expected to be a QuantParamHolder; if it is
//                    missing or of another type, a QuantParamHolder constructed from the node's input/output
//                    counts is used instead.
// @param dst_node    node being exported; its quantType is overwritten.
// @return RET_OK on success, RET_ERROR if CompressTensor fails for an input tensor.
int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
                                   const std::shared_ptr<mindspore::Primitive> &primitive,
                                   const std::unique_ptr<schema::CNodeT> &dst_node) {
  MS_ASSERT(meta_graph != nullptr);
  MS_ASSERT(primitive != nullptr);
  MS_ASSERT(dst_node != nullptr);
  MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam";
  dst_node->quantType = schema::QuantType_QUANT_NONE;

  // Look up the holder attached to the primitive; fall back to a holder sized to this node's inputs/outputs
  // when the attribute is absent or cannot be cast to QuantParamHolderPtr.
  auto quant_tensor_info_ptr = primitive->GetAttr("quant_params");
  QuantParamHolderPtr quant_param_holder = nullptr;
  if (quant_tensor_info_ptr == nullptr ||
      (quant_param_holder = quant_tensor_info_ptr->cast<QuantParamHolderPtr>()) == nullptr) {
    quant_param_holder = std::make_shared<QuantParamHolder>(dst_node->inputIndex.size(), dst_node->outputIndex.size());
  }
  QuantParamsVector input_quant_params = quant_param_holder->get_input_quant_params();
  QuantParamsVector output_quant_params = quant_param_holder->get_output_quant_params();
  dst_node->quantType = quant_param_holder->quant_type();

  // Inputs: entry i of the holder's input params belongs to input tensor dst_node->inputIndex[i].
  for (size_t i = 0; i < dst_node->inputIndex.size(); i++) {
    if (i >= input_quant_params.size()) {
      MS_LOG(INFO) << "node: " << dst_node->name << " has " << dst_node->inputIndex.size()
                   << " inputs, but only has " << input_quant_params.size() << " quant params";
      break;
    }
    auto tensor_input = meta_graph->allTensors[dst_node->inputIndex[i]].get();
    // Only fill tensors that have no params yet, so params already set on a shared tensor are not duplicated.
    if (tensor_input->quantParams.empty()) {
      for (const auto &input_quant_param : input_quant_params[i]) {
        auto input_quant_param_ptr = std::make_unique<schema::QuantParamT>(input_quant_param);
        MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
                      << " zp: " << input_quant_param_ptr->zeroPoint;
        input_quant_param_ptr->dstDtype = tensor_input->dataType;
        tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
      }
    }
    if (CompressTensor(tensor_input, dst_node) != RET_OK) {
      MS_LOG(ERROR) << "CompressTensor error";
      return RET_ERROR;
    }
  }

  // Outputs: entry k of the holder's output params belongs to output tensor dst_node->outputIndex[k].
  for (size_t output_idx = 0; output_idx < output_quant_params.size(); output_idx++) {
    // Without this guard a holder describing more outputs than the node has would index outputIndex out of range.
    if (output_idx >= dst_node->outputIndex.size()) {
      MS_LOG(INFO) << "node: " << dst_node->name << " has " << dst_node->outputIndex.size()
                   << " outputs, but has " << output_quant_params.size() << " output quant params";
      break;
    }
    auto output_tensor = meta_graph->allTensors[dst_node->outputIndex[output_idx]].get();
    for (const auto &channel_quant_param : output_quant_params[output_idx]) {
      // NOTE(review): the emptiness check runs per channel, so after the first channel's param is attached the
      // remaining channels are skipped (inputs receive every channel). Behavior kept as-is; confirm intended.
      if (output_tensor->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
        std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
          std::make_unique<schema::QuantParamT>(channel_quant_param);
        MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
                      << " zp: " << output_quant_param_ptr->zeroPoint;
        output_quant_param_ptr->dstDtype = output_tensor->dataType;
        output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr));
      }
    }
  }
  return RET_OK;
}
// Get the nodes of a subgraph
// Collects the nodes of subgraph `subgraph_index` in the order listed by that subgraph's nodeIndices.
// The returned pointers come from .get() on meta_graphT->nodes and are non-owning.
// Both the subgraph lookup and each node lookup go through .at(), i.e. they are bounds-checked.
std::vector<schema::CNodeT *> AnfExporter::GetSubgraphNodes(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                                                            const size_t &subgraph_index) {
  const auto &node_indices = meta_graphT->subGraph.at(subgraph_index)->nodeIndices;
  std::vector<schema::CNodeT *> collected;
  collected.reserve(node_indices.size());
  for (const auto node_index : node_indices) {
    collected.push_back(meta_graphT->nodes.at(node_index).get());
  }
  return collected;
}
// Set the graph input indices of a subgraph
// Registers the input tensors of subgraph `subgraph_index`.
// Only nodes that appear in graph_input_nodes_ are considered. For each such node, every input tensor whose
// nodeType is not NodeType_CNode and whose data buffer is empty is:
//   - marked NodeType_ValueNode with format NHWC;
//   - appended to the subgraph's inputIndices, unless already present;
//   - for the main graph (kMainGraphIndex), also appended to meta_graphT->inputIndex, with its dtype recorded in
//     TensorDataType under the position it takes in that list.
// @return always RET_OK.
int AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                                    const size_t &subgraph_index) {
  auto &subgraph = meta_graphT->subGraph.at(subgraph_index);
  for (auto *node : GetSubgraphNodes(meta_graphT, subgraph_index)) {
    if (!IsContain(graph_input_nodes_, node)) {
      continue;
    }
    for (auto input : node->inputIndex) {
      auto tensor = meta_graphT->allTensors[input].get();
      // Skip tensors produced by a CNode or that already hold data.
      if (tensor->nodeType == NodeType_CNode || !tensor->data.empty()) {
        continue;
      }
      tensor->nodeType = NodeType_ValueNode;
      tensor->format = schema::Format_NHWC;
      if (IsContain(subgraph->inputIndices, input)) {
        continue;
      }
      if (subgraph_index == kMainGraphIndex) {
        // inputIndex.size() is the slot this tensor is about to occupy in the graph's input list.
        TensorDataType::GetInstance()->UpdateGraphInputDType(meta_graphT->inputIndex.size(), tensor->dataType);
        meta_graphT->inputIndex.push_back(input);
      }
      subgraph->inputIndices.push_back(input);
    }
  }
  return RET_OK;
}