为 Paddle2ONNX 适配 swish 算子

1 简介

在 PaddlePaddle2.6 中,swish 算子在 PaddleInference 上发生了变化,删除掉了 beta 这个 Attr,因此我们需要想办法自行适配它。

2 适配过程

原解析 swish 算子的核心代码如下:

void SwishMapper::Opset7() {
  auto input_info = GetInput("X");
  auto output_info = GetOutput("Out");

  std::string beta_node =
      helper_->Constant({}, GetOnnxDtype(input_info[0].dtype), beta_);
  // TODO(jiangjiajun) eliminate multiply with a constant of value 1
  // TODO(jiangjiajun) eliminate add with a constant of value 0
  auto beta_x_node = helper_->MakeNode("Mul", {input_info[0].name, beta_node});
  auto sigmod_node = helper_->MakeNode("Sigmoid", {beta_x_node->output(0)});
  helper_->MakeNode("Mul", {input_info[0].name, sigmod_node->output(0)},
                    {output_info[0].name});
}

如果仅需要适配 PaddlePaddle2.6,只需要改动为(同时还需要在类的构造函数中删除对 beta 参数的读取):

void SwishMapper::Opset7() {
  auto input_info = GetInput("X");
  auto output_info = GetOutput("Out");
  auto sigmod_node = helper_->MakeNode("Sigmoid", {input_info[0].name});
  helper_->MakeNode("Mul", {input_info[0].name, sigmod_node->output(0)},
                    {output_info[0].name});
}

考虑到要兼容 PaddlePaddle2.5 之前的用户,因此不能直接删除掉 beta 这个参数,进一步修改如下:

void SwishMapper::Opset7() {
  auto input_info = GetInput("X");
  auto output_info = GetOutput("Out");
  std::shared_ptr<paddle2onnx::NodeProto> sigmod_node = nullptr;

  if (HasAttr("beta")) {
    float temp_beta = 1.0;
    GetAttr("beta", &temp_beta);
    std::string beta_node = helper_->Constant({}, GetOnnxDtype(input_info[0].dtype), temp_beta);
    auto beta_x_node = helper_->MakeNode("Mul", {input_info[0].name, beta_node});
    sigmod_node = helper_->MakeNode("Sigmoid", {beta_x_node->output(0)});
  } else {
    sigmod_node = helper_->MakeNode("Sigmoid", {input_info[0].name});
  }

  helper_->MakeNode("Mul", {input_info[0].name, sigmod_node->output(0)},
                    {output_info[0].name});
}

3 参考链接

https://github.com/PaddlePaddle/Paddle2ONNX/pull/1190
https://github.com/PaddlePaddle/Paddle2ONNX/pull/1197

posted @ 2024-11-29 17:44  Zheng-Bicheng  阅读(60)  评论(0)    收藏  举报