为 Paddle2ONNX 修复 elementwise_floordiv 算子计算错误的问题

1 简介

elementwise_floordiv 算子在输入为 int32/int64 时被直接转换成了 ONNX 的 Div 算子。由于 Div 只做普通除法,并不保证向下取整(floor)的整除语义,转换后的模型计算结果与 Paddle 不一致,因此无法通过 CI 的校验。

2 实现过程

原核心实现代码如下:

// ORIGINAL (pre-fix) conversion of Paddle's elementwise_floordiv to ONNX.
//
// Reads inputs "X" and "Y" and output "Out", then emits either a bare "Div"
// node (when is_int is set) or "Div" followed by "Floor" (otherwise).
// The bare-Div path applies no Floor at all; that missing Floor on the
// integer path is the defect this snippet is shown to illustrate.
void ElementWiseFloordivMapper::Opset7() {
    auto input_x_info = GetInput("X");
    auto input_y_info = GetInput("Y");
    auto output_info = GetOutput("Out");

    // Flag the op as "integer" if either operand's dtype code is <= 3 or == 20.
    // NOTE(review): these magic numbers presumably correspond to the integer
    // members of the dtype enum -- confirm against the P2ODataType definition.
    bool is_int = false;
    if (input_x_info[0].dtype <= 3 || input_x_info[0].dtype == 20 ||
        input_y_info[0].dtype <= 3 || input_y_info[0].dtype == 20) {
        is_int = true;
    }
    // Y is used as-is (no Reshape) when axis_ is -1, axis_ is the last axis of
    // X, or X and Y have the same rank.
    if (axis_ == -1 || axis_ == input_x_info[0].Rank() - 1 ||
        input_x_info[0].Rank() == input_y_info[0].Rank()) {
        if (is_int) {
            // Defect: Div writes straight to the output; no Floor is applied.
            helper_->MakeNode("Div", {input_x_info[0].name, input_y_info[0].name},
                {output_info[0].name});
        } else {
            auto div_node = helper_->MakeNode(
            "Div", {input_x_info[0].name, input_y_info[0].name});
            helper_->MakeNode("Floor", {div_node->output(0)}, {output_info[0].name});
        }
    } else {
        // Otherwise reshape Y first: build a shape filled with 1s and copy Y's
        // dims into it starting at index axis_.
        // NOTE(review): the shape has axis_ + Rank(X) entries rather than
        // Rank(X) entries -- verify this is the intended length.
        std::vector<int64_t> broadcast_shape;
        broadcast_shape.resize(axis_ + input_x_info[0].Rank(), 1);
        for (auto i = 0; i < input_y_info[0].Rank(); ++i) {
            broadcast_shape[axis_ + i] = input_y_info[0].shape[i];
        }
        std::string broadcast_shape_node =
        helper_->Constant(GetOnnxDtype(P2ODataType::INT64), broadcast_shape);
        auto y_node = helper_->MakeNode(
        "Reshape", {input_y_info[0].name, broadcast_shape_node});
        if (is_int) {
            // Same defect as above: integer path emits Div without Floor.
            helper_->MakeNode("Div", {input_x_info[0].name, y_node->output(0)},
                {output_info[0].name});
        } else {
            auto div_node =
            helper_->MakeNode("Div", {input_x_info[0].name, y_node->output(0)});
            helper_->MakeNode("Floor", {div_node->output(0)}, {output_info[0].name});
        }
    }
}

可以看到,针对整数输入,原转换函数直接将 elementwise_floordiv 算子转换成了 Div 算子,缺少了 Floor 操作。因此修改为如下实现:先将两个输入统一转换为 FP32,依次经过 Div 和 Floor,最后再把结果转换回输出的数据类型:

// Converts Paddle's elementwise_floordiv to ONNX as
//   AutoCast(X -> FP32), AutoCast(Y -> FP32) -> [Reshape(Y)] -> Div -> Floor
//   -> AutoCast(FP32 -> output dtype).
//
// Every input dtype goes through the same Div + Floor path, so integer inputs
// no longer skip the Floor step.
// NOTE(review): routing int64 (or FP64) operands through FP32 presumably loses
// precision for large magnitudes -- confirm this is acceptable for the models
// being exported.
void ElementWiseFloordivMapper::Opset7() {
  auto input_x_info = GetInput("X");
  auto input_y_info = GetInput("Y");
  auto output_info = GetOutput("Out");

  auto div_input_0 = helper_->AutoCast(input_x_info[0].name, input_x_info[0].dtype, P2ODataType::FP32);
  auto div_input_1 = helper_->AutoCast(input_y_info[0].name, input_y_info[0].dtype, P2ODataType::FP32);

  // Name of the tensor used as the divisor; replaced by the reshaped Y below
  // when Y has to be aligned to X explicitly.
  std::string divisor = div_input_1;

  const bool y_usable_as_is = axis_ == -1 || axis_ == input_x_info[0].Rank() - 1 ||
                              input_x_info[0].Rank() == input_y_info[0].Rank();
  if (!y_usable_as_is) {
    // Reshape Y to a shape with exactly Rank(X) entries: Y's dims are placed
    // at [axis_, axis_ + Rank(Y)) and every other entry is 1. The shape must
    // not be longer than Rank(X): with right-aligned broadcasting, extra
    // trailing 1s would shift Y's dims onto the wrong axes of X and raise the
    // rank of the result.
    // Assumes axis_ + Rank(Y) <= Rank(X), as the axis attribute requires.
    std::vector<int64_t> broadcast_shape;
    broadcast_shape.resize(input_x_info[0].Rank(), 1);
    for (auto i = 0; i < input_y_info[0].Rank(); ++i) {
      broadcast_shape[axis_ + i] = input_y_info[0].shape[i];
    }
    std::string broadcast_shape_node = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), broadcast_shape);
    auto y_node = helper_->MakeNode("Reshape", {div_input_1, broadcast_shape_node});
    divisor = y_node->output(0);
  }

  // Shared tail: floor(x / y) in FP32, then cast back to the output dtype.
  auto div_node = helper_->MakeNode("Div", {div_input_0, divisor});
  auto floor_output = helper_->MakeNode("Floor", {div_node->output(0)});
  helper_->AutoCast(floor_output->output(0), output_info[0].name, P2ODataType::FP32, output_info[0].dtype);
}

3 参考资料

posted @ 2024-11-29 17:48  Zheng-Bicheng  阅读(30)  评论(0)    收藏  举报