Triton SPIR-V Backend Development: Adding a New Pass

The original post is at https://www.cnblogs.com/BobHuang/p/18916629, where the reading experience is better.

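I cover the introduction to MLIR and how to write a Pass in more depth in my earlier post 从零开始教你写一个MLIR Pass (Teaching You to Write an MLIR Pass from Scratch). This post serves as development documentation for OpenMLIR/triton-spirv and is meant to help newcomers get started.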

This post corresponds to commit 08b0e35, which changes code under the include and lib folders and initializes the root CMakeLists.txt.

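I. Changes in the include folder

1. Pass definition

third_party/spirv/include/Conversion/TritonToLinalg/Passes.td. The Pass is defined in a TableGen file, which the build turns into a .inc file; in other words, TableGen acts as a code generator here.

The Pass we define is named TritonToLinalg; it converts the Triton dialect to the Linalg dialect. Linalg is a dialect that the MLIR project has been promoting, and there was once a "Linalg at the Center" effort, although Triton itself does not use it. I copied the definition below directly from microsoft/triton-shared. triton-to-linalg is the option a binary tool uses to enable this Pass, mlir::ModuleOp means the Pass runs on ModuleOp, the summary text is what you see with -help, and constructor names the function that creates our Pass: it is declared in the .h file and defined in the .cpp file.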

//===----------------------------------------------------------------------===//
//
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//

#ifndef TRITON_TO_LINALG_CONVERSION_PASSES
#define TRITON_TO_LINALG_CONVERSION_PASSES

include "mlir/Pass/PassBase.td"

def TritonToLinalg : Pass<"triton-to-linalg", "mlir::ModuleOp"> {
  let summary = "Convert Triton to Linalg dialect";
  let constructor = "mlir::triton::spirv::createTritonToLinalgPass()";
}

#endif

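2. Pass header file

third_party/spirv/include/Conversion/TritonToLinalg/Passes.h. We only need to declare the function above inside the mlir::triton::spirv namespace. The code also contains #define GEN_PASS_DECL and #define GEN_PASS_REGISTRATION, each followed by an include of the same .inc file; they expand the Pass declarations and the Pass registration code, respectively. From now on, this file only needs new function declarations added to it. The code is as follows.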

#ifndef TRITON_TO_LINALG_CONVERSION_PASSES_H
#define TRITON_TO_LINALG_CONVERSION_PASSES_H

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;

namespace triton::spirv {

#define GEN_PASS_DECL
#include "spirv/include/Conversion/TritonToLinalg/Passes.h.inc"

std::unique_ptr<OperationPass<ModuleOp>> createTritonToLinalgPass();

#define GEN_PASS_REGISTRATION
#include "spirv/include/Conversion/TritonToLinalg/Passes.h.inc"

} // namespace triton::spirv
} // namespace mlir

#endif
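
Including the same .inc file twice works because the generated file is split into sections, each guarded by its own macro. The following is a minimal plain C++ sketch of that idea, not MLIR code: the "file" contents are pasted inline twice to stand in for the two #include lines, and the names GEN_DEMO_DECL, GEN_DEMO_DEF, getDemoArgument and demoArgumentIs are hypothetical.

```cpp
#include <cassert>
#include <cstring>

// First "include": only the declaration section is requested.
#define GEN_DEMO_DECL
// ---- begin hypothetical Demo.h.inc ----
#ifdef GEN_DEMO_DECL
const char *getDemoArgument(); // declaration section
#undef GEN_DEMO_DECL
#endif
#ifdef GEN_DEMO_DEF
const char *getDemoArgument() { return "demo-pass"; } // definition section
#undef GEN_DEMO_DEF
#endif
// ---- end hypothetical Demo.h.inc ----

// Second "include": the text is identical, but a different macro is
// defined, so only the definition section survives preprocessing.
#define GEN_DEMO_DEF
// ---- begin hypothetical Demo.h.inc ----
#ifdef GEN_DEMO_DECL
const char *getDemoArgument(); // declaration section
#undef GEN_DEMO_DECL
#endif
#ifdef GEN_DEMO_DEF
const char *getDemoArgument() { return "demo-pass"; } // definition section
#undef GEN_DEMO_DEF
#endif
// ---- end hypothetical Demo.h.inc ----

// Small helper so callers can check which argument was generated.
bool demoArgumentIs(const char *expected) {
  return std::strcmp(getDemoArgument(), expected) == 0;
}
```

Each section undefines its own macro after expanding, so a later include of the same file does not expand that section again by accident.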

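3. Building with tablegen

triton-spirv/third_party/spirv/include/Conversion/TritonToLinalg/CMakeLists.txt. We use -gen-pass-decls to generate the Pass declarations, and we also need add_public_tablegen_target to expose the generated .inc as a public target that other CMake targets can depend on.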

set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToLinalg)
add_public_tablegen_target(TritonToLinalgConversionPassIncGen)

After building, we get the following code. It matches what the .h file expects, and we will use it again later.

/* Autogenerated by mlir-tblgen; don't manually edit */

#ifdef GEN_PASS_DECL
// Generate declarations for all passes.
#define GEN_PASS_DECL_TRITONTOLINALG
#undef GEN_PASS_DECL
#endif // GEN_PASS_DECL

//===----------------------------------------------------------------------===//
// TritonToLinalg
//===----------------------------------------------------------------------===//
#ifdef GEN_PASS_DECL_TRITONTOLINALG
#undef GEN_PASS_DECL_TRITONTOLINALG
#endif // GEN_PASS_DECL_TRITONTOLINALG
#ifdef GEN_PASS_DEF_TRITONTOLINALG
namespace impl {

template <typename DerivedT>
class TritonToLinalgBase : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = TritonToLinalgBase;

  TritonToLinalgBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  TritonToLinalgBase(const TritonToLinalgBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
  TritonToLinalgBase& operator=(const TritonToLinalgBase &) = delete;
  TritonToLinalgBase(TritonToLinalgBase &&) = delete;
  TritonToLinalgBase& operator=(TritonToLinalgBase &&) = delete;
  ~TritonToLinalgBase() = default;

  /// Returns the command-line argument attached to this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("triton-to-linalg");
  }
  ::llvm::StringRef getArgument() const override { return "triton-to-linalg"; }

  ::llvm::StringRef getDescription() const override { return "Convert Triton to Linalg dialect"; }

  /// Returns the derived pass name.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("TritonToLinalg");
  }
  ::llvm::StringRef getName() const override { return "TritonToLinalg"; }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) {
    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
  }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }

  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {

  }

  /// Explicitly declare the TypeID for this class. We declare an explicit private
  /// instantiation because Pass classes should only be visible by the current
  /// library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TritonToLinalgBase<DerivedT>)

protected:
private:
};
} // namespace impl
#undef GEN_PASS_DEF_TRITONTOLINALG
#endif // GEN_PASS_DEF_TRITONTOLINALG
#ifdef GEN_PASS_REGISTRATION

//===----------------------------------------------------------------------===//
// TritonToLinalg Registration
//===----------------------------------------------------------------------===//

inline void registerTritonToLinalg() {
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
    return mlir::triton::spirv::createTritonToLinalgPass();
  });
}

// Old registration code, kept for temporary backwards compatibility.
inline void registerTritonToLinalgPass() {
  ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
    return mlir::triton::spirv::createTritonToLinalgPass();
  });
}

//===----------------------------------------------------------------------===//
// TritonToLinalg Registration
//===----------------------------------------------------------------------===//

inline void registerTritonToLinalgPasses() {
  registerTritonToLinalg();
}
#undef GEN_PASS_REGISTRATION
#endif // GEN_PASS_REGISTRATION
// Deprecated. Please use the new per-pass macros.
#ifdef GEN_PASS_CLASSES

template <typename DerivedT>
class TritonToLinalgBase : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
  using Base = TritonToLinalgBase;

  TritonToLinalgBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
  TritonToLinalgBase(const TritonToLinalgBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
  TritonToLinalgBase& operator=(const TritonToLinalgBase &) = delete;
  TritonToLinalgBase(TritonToLinalgBase &&) = delete;
  TritonToLinalgBase& operator=(TritonToLinalgBase &&) = delete;
  ~TritonToLinalgBase() = default;

  /// Returns the command-line argument attached to this pass.
  static constexpr ::llvm::StringLiteral getArgumentName() {
    return ::llvm::StringLiteral("triton-to-linalg");
  }
  ::llvm::StringRef getArgument() const override { return "triton-to-linalg"; }

  ::llvm::StringRef getDescription() const override { return "Convert Triton to Linalg dialect"; }

  /// Returns the derived pass name.
  static constexpr ::llvm::StringLiteral getPassName() {
    return ::llvm::StringLiteral("TritonToLinalg");
  }
  ::llvm::StringRef getName() const override { return "TritonToLinalg"; }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) {
    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
  }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override {
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }

  /// Register the dialects that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {

  }

  /// Explicitly declare the TypeID for this class. We declare an explicit private
  /// instantiation because Pass classes should only be visible by the current
  /// library.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TritonToLinalgBase<DerivedT>)

protected:
};
#undef GEN_PASS_CLASSES
#endif // GEN_PASS_CLASSES
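
The GEN_PASS_REGISTRATION section above hands a factory lambda to ::mlir::registerPass. As a rough mental model only, registration can be pictured as storing that factory under the Pass's command-line argument so a tool can create the Pass by name later. The sketch below is plain C++ under that assumption; PassLike, DemoPass, passRegistry, registerPassLike and createByArgument are hypothetical names and do not reflect how MLIR implements its registry.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-in for mlir::Pass.
struct PassLike {
  virtual ~PassLike() = default;
  virtual std::string getArgument() const = 0;
};

struct DemoPass : PassLike {
  std::string getArgument() const override { return "demo-pass"; }
};

// Plays the role of createTritonToLinalgPass().
std::unique_ptr<PassLike> createDemoPass() {
  return std::make_unique<DemoPass>();
}

using PassFactory = std::function<std::unique_ptr<PassLike>()>;

// Global table: command-line argument -> factory.
std::map<std::string, PassFactory> &passRegistry() {
  static std::map<std::string, PassFactory> registry;
  return registry;
}

// Builds one throwaway instance to learn the argument, then stores the factory.
void registerPassLike(const PassFactory &factory) {
  std::unique_ptr<PassLike> probe = factory();
  passRegistry()[probe->getArgument()] = factory;
}

// Same shape as the generated registerTritonToLinalg().
inline void registerDemoPass() {
  registerPassLike(
      []() -> std::unique_ptr<PassLike> { return createDemoPass(); });
}

// What a tool could do when it sees the option on its command line.
std::unique_ptr<PassLike> createByArgument(const std::string &argument) {
  auto it = passRegistry().find(argument);
  if (it == passRegistry().end())
    return nullptr;
  return it->second();
}
```

In this model, a lookup for "demo-pass" returns nullptr until registerDemoPass() has been called, which is why a tool has to call the generated register function before it parses its options.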

4. CMakeLists.txt files required by the directory layout

Our code lives in third_party/spirv/include/Conversion/TritonToLinalg, so we also need two more files: third_party/spirv/include/Conversion/CMakeLists.txt and third_party/spirv/include/CMakeLists.txt.

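II. Changes in the lib folder

1. Writing the Pass

third_party/spirv/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp. Besides including the header, we also need to include the .inc file and define the GEN_PASS_DEF_TRITONTOLINALG macro to expand the base class; TRITONTOLINALG is the Pass name in upper case. The entry point of every Pass is runOnOperation(), so we override that method. MLIR requires a Pass to be created and handed around as a std::unique_ptr<...>, which makes RAII and PassManager.addPass(...) convenient, so the factory function uses return std::make_unique. The code is as follows.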

#include "spirv/include/Conversion/TritonToLinalg/Passes.h"

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "triton-to-linalg"

namespace mlir::triton::spirv {
#define GEN_PASS_DEF_TRITONTOLINALG
#include "spirv/include/Conversion/TritonToLinalg/Passes.h.inc"
} // namespace mlir::triton::spirv

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::spirv;

namespace {

struct TritonToLinalg
    : public mlir::triton::spirv::impl::TritonToLinalgBase<TritonToLinalg> {

  void runOnOperation() override {
    auto moduleOp = getOperation();
    moduleOp.dump();
  }
};

} // namespace

namespace mlir::triton::spirv {

std::unique_ptr<OperationPass<ModuleOp>> createTritonToLinalgPass() {
  return std::make_unique<TritonToLinalg>();
}

} // namespace mlir::triton::spirv
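
The structure above (a generated CRTP base class, a concrete Pass kept in an anonymous namespace, and a public factory that returns std::unique_ptr) can be tried out without MLIR. The following is a minimal plain C++ sketch of that shape; PassLike, DemoPassBase, DemoPass and createDemoPass are hypothetical stand-ins, not MLIR classes.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-in for mlir::Pass.
class PassLike {
public:
  virtual ~PassLike() = default;
  virtual std::string getArgument() const = 0;
  virtual std::unique_ptr<PassLike> clonePass() const = 0;
  virtual void runOnOperation() = 0;
};

// Plays the role of the generated TritonToLinalgBase<DerivedT>: the
// boilerplate is written once and parameterized on the concrete Pass.
template <typename DerivedT>
class DemoPassBase : public PassLike {
public:
  std::string getArgument() const override { return "demo-pass"; }

  // Copies the concrete type, which the base knows through DerivedT.
  std::unique_ptr<PassLike> clonePass() const override {
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }
};

namespace {
// The concrete Pass only supplies runOnOperation(); it stays private to
// this translation unit.
struct DemoPass : DemoPassBase<DemoPass> {
  int runs = 0;
  void runOnOperation() override { ++runs; }
};
} // namespace

// The only symbol other files need: a factory returning an owning pointer.
std::unique_ptr<PassLike> createDemoPass() {
  return std::make_unique<DemoPass>();
}
```

Keeping the concrete struct in an anonymous namespace means other files can only reach it through the factory, which is the same reason Passes.h declares nothing but the create function.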

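2. Building the Pass

third_party/spirv/lib/Conversion/TritonToLinalg/CMakeLists.txt. This sets up the build target: compile the .cpp file, link against MLIRPass, and use DEPENDS to depend on TritonToLinalgConversionPassIncGen, the target that processed the .td file earlier. The code is as follows; from now on, new .cpp files in the same folder only need to be added to this list.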

add_triton_library(TritonToLinalg
  TritonToLinalg.cpp

  DEPENDS
  TritonToLinalgConversionPassIncGen

  LINK_LIBS PUBLIC
  MLIRPass
)

3. CMakeLists.txt files required by the directory layout

Our code lives in third_party/spirv/lib/Conversion/TritonToLinalg, so we also need two more files: third_party/spirv/lib/Conversion/CMakeLists.txt and third_party/spirv/lib/CMakeLists.txt.

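III. Initializing the root CMakeLists.txt

Because this is the first Pass we add, the root CMakeLists.txt needs include_directories and add_subdirectory: the former makes the files under the include folder visible, and the latter brings include and lib into the build.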

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
add_subdirectory(include)
add_subdirectory(lib)
posted @ 2025-06-07 13:43  暴力都不会的蒟蒻