为 Paddle2ONNX 适配 Roll 算子
1 简介
Roll 算子一般被用在 Swin 结构中,Paddle2ONNX 暂时不支持该算子,本教程介绍如何为 Paddle2ONNX 添加 Roll 算子。
2 实现过程
2.1 roll 算子简介
def roll(x, shifts, axis=None, name=None):
"""
Roll the `x` tensor along the given axis(axes). With specific 'shifts', elements that
roll beyond the last position are re-introduced at the first according to 'shifts'.
If an axis is not specified,
the tensor will be flattened before rolling and then restored to the original shape.
Args:
x (Tensor): The x tensor as input.
shifts (int|list|tuple): The number of places by which the elements
of the `x` tensor are shifted.
axis (int|list|tuple, optional): axis(axes) along which to roll. Default: None
name(str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
Tensor, A Tensor with same data type as `x`.
Examples:
.. code-block:: python
>>> import paddle
>>> x = paddle.to_tensor([[1.0, 2.0, 3.0],
... [4.0, 5.0, 6.0],
... [7.0, 8.0, 9.0]])
>>> out_z1 = paddle.roll(x, shifts=1)
>>> print(out_z1.numpy())
[[9. 1. 2.]
[3. 4. 5.]
[6. 7. 8.]]
>>> out_z2 = paddle.roll(x, shifts=1, axis=0)
>>> print(out_z2.numpy())
[[7. 8. 9.]
[1. 2. 3.]
[4. 5. 6.]]
>>> out_z3 = paddle.roll(x, shifts=1, axis=1)
>>> print(out_z3.numpy())
[[3. 1. 2.]
[6. 4. 5.]
[9. 7. 8.]]
"""
2.2 在 Paddle2ONNX 中实现 Roll 算子
首先在 paddle2onnx/mapper/tensor 下新建 roll.h 并添加对 RollMapper 类的定义
#pragma once
#include <string>
#include <vector>
#include "paddle2onnx/mapper/mapper.h"
namespace paddle2onnx {
// Maps paddle's `roll` operator onto ONNX nodes (implemented downstream with
// Slice + Concat; see roll.cc).
class RollMapper : public Mapper {
public:
// Forwards the parser/helper context to the base Mapper; the mapper itself
// keeps no roll-specific state.
RollMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
// Emits the ONNX subgraph for `roll`; only opset >= 7 features are needed.
void Opset7();
};
} // namespace paddle2onnx
接下来在 paddle2onnx/mapper/tensor 下新建 roll.cc 并添加对 RollMapper 类的实现
#include <limits>
#include "paddle2onnx/mapper/tensor/roll.h"

namespace paddle2onnx {
REGISTER_MAPPER(roll, RollMapper)

// Lowers paddle's `roll` op to ONNX using Slice + Concat:
// roll(x, shift) along one axis == Concat(x[-shift:], x[:-shift]).
// When `axis` is empty, paddle semantics flatten the input, roll along
// dim 0, and restore the original shape afterwards.
void RollMapper::Opset7() {
  auto input_info = GetInput("X");
  auto output_info = GetOutput("Out");

  std::vector<int64_t> shifts;
  GetAttr("shifts", &shifts);
  std::vector<int64_t> axis;
  GetAttr("axis", &axis);

  std::shared_ptr<ONNX_NAMESPACE::NodeProto> temp_node = nullptr;
  auto result_name = input_info[0].name;
  if (axis.empty()) {
    // No axis given: flatten, roll along dim 0, then reshape back.
    const int64_t flat_axis = 0;
    result_name = helper_->Flatten(result_name);
    // size_t index avoids the signed/unsigned comparison warning.
    for (size_t i = 0; i < shifts.size(); ++i) {
      int64_t shift = shifts[i];
      // tail = x[-shift:], head = x[:-shift]; Concat(tail, head) == roll.
      // Negative indices in Slice count from the back, so this also
      // handles negative (left) shifts.
      auto tail = helper_->Slice(result_name, {flat_axis}, {-shift},
                                 {(std::numeric_limits<int64_t>::max)()});
      auto head = helper_->Slice(result_name, {flat_axis}, {0}, {-shift});
      temp_node = helper_->MakeNode("Concat", {tail, head});
      AddAttribute(temp_node, "axis", flat_axis);
      result_name = temp_node->output(0);
    }
    helper_->Reshape(result_name, output_info[0].name, input_info[0].shape);
  } else {
    // Roll along each (shift, axis) pair in turn. The last Concat writes
    // directly to the op's declared output so no extra Identity is needed.
    for (size_t i = 0; i < shifts.size(); ++i) {
      int64_t shift = shifts[i];
      int64_t roll_axis = axis[i];
      auto tail = helper_->Slice(result_name, {roll_axis}, {-shift},
                                 {(std::numeric_limits<int64_t>::max)()});
      auto head = helper_->Slice(result_name, {roll_axis}, {0}, {-shift});
      if (i + 1 == shifts.size()) {
        temp_node = helper_->MakeNode("Concat", {tail, head},
                                      {output_info[0].name});
      } else {
        temp_node = helper_->MakeNode("Concat", {tail, head});
      }
      AddAttribute(temp_node, "axis", roll_axis);
      result_name = temp_node->output(0);
    }
  }
}
}  // namespace paddle2onnx

浙公网安备 33010602011771号