mlir 基本操作
mlir 基本操作
在 MLIR 中,mlir::Operation 类提供了丰富的方法,用于操作和分析 IR 中的操作(Operation)。以下是 mlir::Operation 的常用方法及其功能的详细说明:
MLIR Operation API 参考
1. 官方文档
2. Operation 相关 API 参考
3. 代码示例与指南
4. 相关工具
MLIR Operation API 参考
MLIR中的Operation是核心抽象之一,以下是MLIR Operation API的详细参考,包括基本信息、操作数、结果、区域、块、属性等方面。
1. 基本信息
getName()
- 功能: 获取操作的名称。
- 返回值:
mlir::OperationName对象。 - 示例:
mlir::OperationName name = op->getName();
getLoc()
- 功能: 获取操作的位置信息(Location)。
- 返回值: mlir::Location 对象。
- 示例:
mlir::Location loc = op->getLoc();
getContext()
- 功能: 获取操作所属的 MLIR 上下文。
- 返回值: mlir::MLIRContext* 指针。
- 示例:
mlir::MLIRContext* context = op->getContext();
2. 操作数(Operands)
getOperands()
- 功能: 获取操作的所有操作数。
- 返回值: mlir::Operation::operand_range 对象。
- 示例:
for (mlir::Value operand : op->getOperands()) { ... }
getOperand(unsigned index)
- 功能: 获取操作的第 index 个操作数。
- 返回值: mlir::Value 对象。
- 示例:
mlir::Value operand = op->getOperand(0);
getNumOperands()
- 功能: 获取操作的操作数数量。
- 返回值: unsigned 值。
- 示例:
unsigned numOperands = op->getNumOperands();
setOperand(unsigned index, mlir::Value value)
- 功能: 设置操作的第 index 个操作数。
- 示例:
op->setOperand(0, newValue);
3. 结果(Results)
getResults()
- 功能: 获取操作的所有结果。
- 返回值: mlir::Operation::result_range 对象。
- 示例:
for (mlir::Value result : op->getResults()) { ... }
getResult(unsigned index)
- 功能: 获取操作的第 index 个结果。
- 返回值: mlir::Value 对象。
- 示例:
mlir::Value result = op->getResult(0);
getNumResults()
- 功能: 获取操作的结果数量。
- 返回值: unsigned 值。
- 示例:
unsigned numResults = op->getNumResults();
4. 区域(Regions)
getRegions()
- 功能: 获取操作的所有区域。
- 返回值: mlir::Operation::region_range 对象。
- 示例:
for (mlir::Region& region : op->getRegions()) { ... }
getRegion(unsigned index)
- 功能: 获取操作的第 index 个区域。
- 返回值: mlir::Region 对象。
- 示例:
mlir::Region& region = op->getRegion(0);
getNumRegions()
- 功能: 获取操作的区域数量。
- 返回值: unsigned 值。
- 示例:
unsigned numRegions = op->getNumRegions();
5. 块(Blocks)
getBlocks()
- 功能: 获取操作的所有块。
- 返回值: mlir::Operation::block_range 对象。
- 示例:
for (mlir::Block& block : op->getBlocks()) { ... }
getBlock(unsigned index)
- 功能: 获取操作的第 index 个块。
- 返回值: mlir::Block 对象。
- 示例:
mlir::Block& block = op->getBlock(0);
getNumBlocks()
- 功能: 获取操作的块数量。
- 返回值: unsigned 值。
- 示例:
unsigned numBlocks = op->getNumBlocks();
6. 属性(Attributes)
getAttrs()
- 功能: 获取操作的所有属性。
- 返回值: mlir::DictionaryAttr 对象。
- 示例:
mlir::DictionaryAttr attrs = op->getAttrs();
getAttr(StringRef name)
- 功能: 获取操作的指定属性。
- 返回值: mlir::Attribute 对象。
- 示例:
mlir::Attribute attr = op->getAttr("my_attr");
setAttr(StringRef name, mlir::Attribute attr)
- 功能: 设置操作的指定属性。
- 示例:
op->setAttr("my_attr", newAttr);
removeAttr(StringRef name)
- 功能: 移除操作的指定属性。
- 示例:
op->removeAttr("my_attr");
7. 其他方法
dump()
- 功能: 将操作的信息打印到标准错误流(stderr)。
- 示例:
op->dump();
print(raw_ostream &os)
- 功能: 将操作的信息打印到指定的输出流。
- 示例:
op->print(llvm::outs());
clone()
- 功能: 克隆操作。
- 返回值: mlir::Operation* 指针。
- 示例:
mlir::Operation* clonedOp = op->clone();
erase()
- 功能: 删除操作。
- 示例:
op->erase();
8. 总结
| 类别 | 方法 | 功能 |
|---|---|---|
| 基本信息 | getName() | 获取操作名称。 |
| getLoc() | 获取操作位置信息。 | |
| getContext() | 获取操作所属的 MLIR 上下文。 | |
| 操作数 | getOperands() | 获取所有操作数。 |
| getOperand(index) | 获取特定操作数。 | |
| getNumOperands() | 获取操作数数量。 | |
| setOperand(index, value) | 设置特定操作数。 | |
| 结果 | getResults() | 获取所有结果。 |
| getResult(index) | 获取特定结果。 | |
| getNumResults() | 获取结果数量。 | |
| 区域 | getRegions() | 获取所有区域。 |
| getRegion(index) | 获取特定区域。 | |
| getNumRegions() | 获取区域数量。 | |
| 块 | getBlocks() | 获取所有块。 |
| getBlock(index) | 获取特定块。 | |
| getNumBlocks() | 获取块数量。 | |
| 属性 | getAttrs() | 获取所有属性。 |
| getAttr(name) | 获取特定属性。 | |
| setAttr(name, attr) | 设置特定属性。 | |
| removeAttr(name) | 移除特定属性。 | |
| 其他 | dump() | 打印操作信息到 stderr。 |
| print(os) | 打印操作信息到指定输出流。 | |
| clone() | 克隆操作。 | |
| erase() | 删除操作。 |
// 获取操作名称
llvm::outs() << "Operation name: " << op->getName() << "\n";
llvm::outs() << "Operation name: " << op->getName() << "\n";
// 获取操作数
for (mlir::Value operand : op->getOperands()) {
llvm::outs() << "Operand: " << operand << "\n";
}
// 获取结果
for (mlir::Value result : op->getResults()) {
llvm::outs() << "Result: " << result << "\n";
}
// 获取属性
for (mlir::NamedAttribute attr : op->getAttrs()) {
llvm::outs() << "Attribute: " << attr.getName() << " = " << attr.getValue() << "\n";
}
Location loc = op.getLoc();
// 获取block 入口参数,参数类型,父block, 行号
mlir::Block *block = op->getBlock();
for (mlir::BlockArgument arg : block->getArguments()) {
llvm::outs() << " Argument: " << arg << "\n";
llvm::outs() << " Type: " << arg.getType() << "\n";
llvm::outs() << " Parent Block: " << arg.getParentBlock() << "\n";
llvm::outs() << " Location: " << arg.getLoc() << "\n";
auto custom_tensor =
rewriter.create<custom::ConstLikeOp>(loc, arg.getType(), arg);
}
以下是获取块输入参数的主要方法:
| 方法 | 说明 |
|---|---|
| getArguments() | 获取块的所有输入参数。 |
| getArgument(index) | 获取块的特定输入参数。 |
| getNumArguments() | 获取块的输入参数数量。 |
| getType() | 获取输入参数的类型。 |
| getParentBlock() | 获取输入参数所属的块。 |
| getLoc() | 获取输入参数的位置信息。 |
| getUses() | 获取输入参数的使用。 |
// op-->operand-->op
if (mlir::Operation *op = operand.getDefiningOp()) {
llvm::outs() << "Defining operation: " << op->getName() << "\n";
op->dump(); // 打印操作的详细信息
for (mlir::Value operand : op->getOperands()) {
llvm::outs() << " Operand: " << operand << "\n";
if (mlir::Operation* op = operand.getDefiningOp()) {
llvm::outs() << "Defining operation: " << op->getName() << "\n";
} else {
llvm::outs() << " Constant: " << operand << "\n";
}
}
}
// op-->operand-->op
if (mlir::Operation *optmpout = padTensor.getDefiningOp()) {
// llvm::outs() << "Defining operation: " << op->getName() << "\n";
for (mlir::Value operand : optmpout->getOperands()) {
// llvm::outs() << " Operand: " << operand << "\n";
// if (mlir::Operation *op = operand.getDefiningOp<stc::ConstOp>()) {
if (auto constantOp = operand.getDefiningOp<mlir::arith::ConstantOp>()) {
llvm::outs() << "Defining operation: "<< constantOp.getValue()->getValues<float>()[0] << "\n";
}
}
}
SmallVector<Value> customTensors;
SmallVector<Value> resultTensors;
void getcustomTensors(const llvm::SmallVector<mlir::Value> &values,
SmallVector<mlir::Value> &results) {
for (size_t i = 0; i < values.size(); ++i) {
mlir::Type type = values[i].getType();
if (auto tensorType = type.dyn_cast<mlir::TensorType>()) {
ele_type = tensorType.getElementType();
auto const_k =
stc::getConstTensor<float>(rewriter, op, vec, newShape).value();
llvm::outs() << "TensorTorchType: " << tensorType << "\n";
llvm::outs() << " Element Type: " << tensorType.getElementType()
<< "\n"; // 打印元素类型
llvm::outs() << " Shape: ";
for (int64_t dim : tensorType.getShape()) { // 打印形状
llvm::outs() << dim << " ";
}
llvm::outs() << "\n";
} else {
llvm::outs() << "Not a TensorTorchType: " << type << "\n";
}
}
}
void ToBuiltinTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](ToBuiltinTensorOp op, PatternRewriter &rewriter) {
auto operand = op.getOperand();
auto fromBuiltinOp = operand.getDefiningOp<FromBuiltinTensorOp>();
if (!fromBuiltinOp)
return failure();
rewriter.replaceOp(op, fromBuiltinOp.getOperand());
return success();
});
}

浙公网安备 33010602011771号