为 Paddle2ONNX 适配 swish 算子
1 简介
在 PaddlePaddle2.6 中,swish 算子在 PaddleInference 上发生了变化,删除掉了 beta 这个 Attr,因此我们需要想办法自行适配它。
2 适配过程
原解析 swish 算子的核心代码如下:
void SwishMapper::Opset7() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");
std::string beta_node =
helper_->Constant({}, GetOnnxDtype(input_info[0].dtype), beta_);
// TODO(jiangjiajun) eliminate multiply with a constant of value 1
// TODO(jiangjiajun) eliminate add with a constant of value 0
auto beta_x_node = helper_->MakeNode("Mul", {input_info[0].name, beta_node});
auto sigmod_node = helper_->MakeNode("Sigmoid", {beta_x_node->output(0)});
helper_->MakeNode("Mul", {input_info[0].name, sigmod_node->output(0)},
{output_info[0].name});
}
如果仅需要适配 PaddlePaddle2.6,只需要改动为(同时还需要在类的构造函数中删除对 beta 参数的读取):
void SwishMapper::Opset7() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");
auto sigmod_node = helper_->MakeNode("Sigmoid", {input_info[0].name});
helper_->MakeNode("Mul", {input_info[0].name, sigmod_node->output(0)},
{output_info[0].name});
}
考虑到要兼容 PaddlePaddle2.5 之前的用户,因此不能直接删除掉 beta 这个参数,进一步修改如下:
void SwishMapper::Opset7() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");
std::shared_ptr<paddle2onnx::NodeProto> sigmod_node = nullptr;
if (HasAttr("beta")) {
float temp_beta = 1.0;
GetAttr("beta", &temp_beta);
std::string beta_node = helper_->Constant({}, GetOnnxDtype(input_info[0].dtype), temp_beta);
auto beta_x_node = helper_->MakeNode("Mul", {input_info[0].name, beta_node});
sigmod_node = helper_->MakeNode("Sigmoid", {beta_x_node->output(0)});
} else {
sigmod_node = helper_->MakeNode("Sigmoid", {input_info[0].name});
}
helper_->MakeNode("Mul", {input_info[0].name, sigmod_node->output(0)},
{output_info[0].name});
}
3 参考链接
● https://github.com/PaddlePaddle/Paddle2ONNX/pull/1190
● https://github.com/PaddlePaddle/Paddle2ONNX/pull/1197

浙公网安备 33010602011771号