mlir 基本操作

mlir 基本操作

在 MLIR 中,mlir::Operation 类提供了丰富的方法,用于操作和分析 IR 中的操作(Operation)。以下是 mlir::Operation 的常用方法及其功能的详细说明:

MLIR Operation API 参考

1. 官方文档

2. Operation 相关 API 参考

3. 代码示例与指南

4. 相关工具

MLIR Operation API 参考

MLIR中的Operation是核心抽象之一,以下是MLIR Operation API的详细参考,包括基本信息、操作数、结果、区域、块、属性等方面。


1. 基本信息

getName()

  • 功能: 获取操作的名称。
  • 返回值: mlir::OperationName 对象。
  • 示例:
    mlir::OperationName name = op->getName();
    

getLoc()

  • 功能: 获取操作的位置信息(Location)。
  • 返回值: mlir::Location 对象。
  • 示例:
    mlir::Location loc = op->getLoc();
    

getContext()

  • 功能: 获取操作所属的 MLIR 上下文。
  • 返回值: mlir::MLIRContext* 指针。
  • 示例:
    mlir::MLIRContext* context = op->getContext();
    

2. 操作数(Operands)

getOperands()

  • 功能: 获取操作的所有操作数。
  • 返回值: mlir::Operation::operand_range 对象。
  • 示例:
    for (mlir::Value operand : op->getOperands()) { ... }
    

getOperand(unsigned index)

  • 功能: 获取操作的第 index 个操作数。
  • 返回值: mlir::Value 对象。
  • 示例:
    mlir::Value operand = op->getOperand(0);
    

getNumOperands()

  • 功能: 获取操作的操作数数量。
  • 返回值: unsigned 值。
  • 示例:
    unsigned numOperands = op->getNumOperands();
    

setOperand(unsigned index, mlir::Value value)

  • 功能: 设置操作的第 index 个操作数。
  • 示例:
    op->setOperand(0, newValue);
    

3. 结果(Results)

getResults()

  • 功能: 获取操作的所有结果。
  • 返回值: mlir::Operation::result_range 对象。
  • 示例:
    for (mlir::Value result : op->getResults()) { ... }
    

getResult(unsigned index)

  • 功能: 获取操作的第 index 个结果。
  • 返回值: mlir::Value 对象。
  • 示例:
    mlir::Value result = op->getResult(0);
    

getNumResults()

  • 功能: 获取操作的结果数量。
  • 返回值: unsigned 值。
  • 示例:
    unsigned numResults = op->getNumResults();
    

4. 区域(Regions)

getRegions()

  • 功能: 获取操作的所有区域。
  • 返回值: mlir::Operation::region_range 对象。
  • 示例:
    for (mlir::Region& region : op->getRegions()) { ... }
    

getRegion(unsigned index)

  • 功能: 获取操作的第 index 个区域。
  • 返回值: mlir::Region 对象。
  • 示例:
    mlir::Region& region = op->getRegion(0);
    

getNumRegions()

  • 功能: 获取操作的区域数量。
  • 返回值: unsigned 值。
  • 示例:
    unsigned numRegions = op->getNumRegions();
    

5. 块(Blocks)

getBlocks()

  • 功能: 获取操作的所有块。
  • 返回值: mlir::Operation::block_range 对象。
  • 示例:
    for (mlir::Block& block : op->getBlocks()) { ... }
    

getBlock(unsigned index)

  • 功能: 获取操作的第 index 个块。
  • 返回值: mlir::Block 对象。
  • 示例:
    mlir::Block& block = op->getBlock(0);
    

getNumBlocks()

  • 功能: 获取操作的块数量。
  • 返回值: unsigned 值。
  • 示例:
    unsigned numBlocks = op->getNumBlocks();
    

6. 属性(Attributes)

getAttrs()

  • 功能: 获取操作的所有属性。
  • 返回值: mlir::DictionaryAttr 对象。
  • 示例:
    mlir::DictionaryAttr attrs = op->getAttrs();
    

getAttr(StringRef name)

  • 功能: 获取操作的指定属性。
  • 返回值: mlir::Attribute 对象。
  • 示例:
    mlir::Attribute attr = op->getAttr("my_attr");
    

setAttr(StringRef name, mlir::Attribute attr)

  • 功能: 设置操作的指定属性。
  • 示例:
    op->setAttr("my_attr", newAttr);
    

removeAttr(StringRef name)

  • 功能: 移除操作的指定属性。
  • 示例:
    op->removeAttr("my_attr");
    

7. 其他方法

dump()

  • 功能: 将操作的信息打印到标准错误流(stderr)。
  • 示例:
    op->dump();
    

print(raw_ostream &os)

  • 功能: 将操作的信息打印到指定的输出流。
  • 示例:
    op->print(llvm::outs());
    

clone()

  • 功能: 克隆操作。
  • 返回值: mlir::Operation* 指针。
  • 示例:
    mlir::Operation* clonedOp = op->clone();
    

erase()

  • 功能: 删除操作。
  • 示例:
    op->erase();
    

8. 总结

类别 方法 功能
基本信息 getName() 获取操作名称。
getLoc() 获取操作位置信息。
getContext() 获取操作所属的 MLIR 上下文。
操作数 getOperands() 获取所有操作数。
getOperand(index) 获取特定操作数。
getNumOperands() 获取操作数数量。
setOperand(index, value) 设置特定操作数。
结果 getResults() 获取所有结果。
getResult(index) 获取特定结果。
getNumResults() 获取结果数量。
区域 getRegions() 获取所有区域。
getRegion(index) 获取特定区域。
getNumRegions() 获取区域数量。
getBlocks() 获取所有块。
getBlock(index) 获取特定块。
getNumBlocks() 获取块数量。
属性 getAttrs() 获取所有属性。
getAttr(name) 获取特定属性。
setAttr(name, attr) 设置特定属性。
removeAttr(name) 移除特定属性。
其他 dump() 打印操作信息到 stderr。
print(os) 打印操作信息到指定输出流。
clone() 克隆操作。
erase() 删除操作。

  // 获取操作名称
  llvm::outs() << "Operation name: " << op->getName() << "\n";
	
  llvm::outs() << "Operation name: " << op->getName() << "\n";
  // 获取操作数
  for (mlir::Value operand : op->getOperands()) {
    llvm::outs() << "Operand: " << operand << "\n";
  }
  // 获取结果
  for (mlir::Value result : op->getResults()) {
    llvm::outs() << "Result: " << result << "\n";
  }
  // 获取属性
  for (mlir::NamedAttribute attr : op->getAttrs()) {
    llvm::outs() << "Attribute: " << attr.getName() << " = " << attr.getValue() << "\n";
  }
  
  Location loc = op.getLoc();
// 获取block 入口参数,参数类型,父block, 行号
mlir::Block *block = op->getBlock();
for (mlir::BlockArgument arg : block->getArguments()) {
    llvm::outs() << "  Argument: " << arg << "\n";
    llvm::outs() << "    Type: " << arg.getType() << "\n";
    llvm::outs() << "    Parent Block: " << arg.getParentBlock() << "\n";
    llvm::outs() << "    Location: " << arg.getLoc() << "\n";
    auto custom_tensor =
        rewriter.create<custom::ConstLikeOp>(loc, arg.getType(), arg);
  }

以下是获取块输入参数的主要方法:

方法 说明
getArguments() 获取块的所有输入参数。
getArgument(index) 获取块的特定输入参数。
getNumArguments() 获取块的输入参数数量。
getType() 获取输入参数的类型。
getParentBlock() 获取输入参数所属的块。
getLoc() 获取输入参数的位置信息。
getUses() 获取输入参数的使用。
// op-->operand-->op
if (mlir::Operation *op = operand.getDefiningOp()) {
	llvm::outs() << "Defining operation: " << op->getName() << "\n";
	op->dump(); // 打印操作的详细信息
	for (mlir::Value operand : op->getOperands()) {
		llvm::outs() << "  Operand: " << operand << "\n";
		if (mlir::Operation* op = operand.getDefiningOp()) {
			llvm::outs() << "Defining operation: " << op->getName() << "\n";
		} else {
			llvm::outs() << "  Constant: " << operand << "\n";
		}
	}
}

// op-->operand-->op
  if (mlir::Operation *optmpout = padTensor.getDefiningOp()) {
    // llvm::outs() << "Defining operation: " << op->getName() << "\n";
    for (mlir::Value operand : optmpout->getOperands()) {
      // llvm::outs() << "  Operand: " << operand << "\n";
      // if (mlir::Operation *op = operand.getDefiningOp<stc::ConstOp>()) {
      if (auto constantOp = operand.getDefiningOp<mlir::arith::ConstantOp>()) {
        llvm::outs() << "Defining operation: "<< constantOp.getValue()->getValues<float>()[0] << "\n";
		}
	}
}

SmallVector<Value> customTensors;
SmallVector<Value> resultTensors;

void getcustomTensors(const llvm::SmallVector<mlir::Value> &values,
                   SmallVector<mlir::Value> &results) {
  for (size_t i = 0; i < values.size(); ++i) {
    mlir::Type type = values[i].getType();
    if (auto tensorType = type.dyn_cast<mlir::TensorType>()) {
      ele_type = tensorType.getElementType();

      auto const_k =
          stc::getConstTensor<float>(rewriter, op, vec, newShape).value();
      llvm::outs() << "TensorTorchType: " << tensorType << "\n";
      llvm::outs() << "    Element Type: " << tensorType.getElementType()
                   << "\n"; // 打印元素类型
      llvm::outs() << "    Shape: ";
      for (int64_t dim : tensorType.getShape()) { // 打印形状
        llvm::outs() << dim << " ";
      }
      llvm::outs() << "\n";
    } else {
      llvm::outs() << "Not a TensorTorchType: " << type << "\n";
    }
  }
}


void ToBuiltinTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add(+[](ToBuiltinTensorOp op, PatternRewriter &rewriter) {
    auto operand = op.getOperand();

    auto fromBuiltinOp = operand.getDefiningOp<FromBuiltinTensorOp>();
    if (!fromBuiltinOp)
      return failure();

    rewriter.replaceOp(op, fromBuiltinOp.getOperand());
    return success();
  });
}
posted @ 2025-03-12 11:45  michaelchengjl  阅读(266)  评论(0)    收藏  举报