Triton SPIR-V Backend Development: Adding a New Pass
Original post: https://www.cnblogs.com/BobHuang/p/18916629 (it reads better there).
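For a fuller and more detailed introduction to MLIR and to writing Passes, see my earlier post Writing an MLIR Pass from Scratch (从零开始教你写一个MLIR Pass). This article serves as development documentation for OpenMLIR/triton-spirv and aims to help newcomers get started.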
This article corresponds to commit 08b0e35, which contains the code changes in the include and lib folders plus the initialization of the root CMakeLists.txt.
I. Changes to the include folder
1. Pass definition
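third_party/spirv/include/Conversion/TritonToLinalg/Passes.td. Passes are defined in TableGen files, which are compiled into .inc files; in other words, TableGen acts as a code generator here. The Pass we define is named TritonToLinalg, i.e. it converts the Triton dialect to the Linalg dialect. Linalg is a dialect that upstream MLIR actively promotes (see "Linalg at the Center"), although Triton itself does not use it. The definition below is copied directly from microsoft/triton-shared. In it, triton-to-linalg is the option that command-line tools use to run this Pass; mlir::ModuleOp means the Pass runs on a ModuleOp; the summary text is what you see with -help; and constructor names the entry point of the Pass, a factory function declared in the .h and defined in the .cpp.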
//===----------------------------------------------------------------------===//
//
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
//===----------------------------------------------------------------------===//
#ifndef TRITON_TO_LINALG_CONVERSION_PASSES
#define TRITON_TO_LINALG_CONVERSION_PASSES
include "mlir/Pass/PassBase.td"
def TritonToLinalg : Pass<"triton-to-linalg", "mlir::ModuleOp"> {
let summary = "Convert Triton to Linalg dialect";
let constructor = "mlir::triton::spirv::createTritonToLinalgPass()";
}
#endif
2. Pass header
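third_party/spirv/include/Conversion/TritonToLinalg/Passes.h. Here we just declare the factory function named above inside the mlir::triton::spirv namespace. The file also contains #define GEN_PASS_DECL and #define GEN_PASS_REGISTRATION, each followed by an include of the generated .inc file; they expand the Pass declarations and the Pass registration code, respectively. From now on, adding a Pass only requires adding one more factory declaration to this file. The code is as follows.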
#ifndef TRITON_TO_LINALG_CONVERSION_PASSES_H
#define TRITON_TO_LINALG_CONVERSION_PASSES_H
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
namespace triton::spirv {
#define GEN_PASS_DECL
#include "spirv/include/Conversion/TritonToLinalg/Passes.h.inc"
std::unique_ptr<OperationPass<ModuleOp>> createTritonToLinalgPass();
#define GEN_PASS_REGISTRATION
#include "spirv/include/Conversion/TritonToLinalg/Passes.h.inc"
} // namespace triton::spirv
} // namespace mlir
#endif
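To see what the registration half of this header buys us, here is a stdlib-only sketch of the idea behind the generated registerTritonToLinalg() (listed in the next section): a factory lambda is stored in a registry keyed by the Pass's command-line argument, and a driver tool later looks it up by that name. Pass, PassAllocator, registerPass, passRegistry and lookupPass are hypothetical, much-simplified stand-ins, not the MLIR API; the one detail borrowed from MLIR is that mlir::registerPass instantiates the Pass once to read its argument.

```cpp
// Hypothetical, stdlib-only model of pass registration; NOT the MLIR API.
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-in for mlir::Pass: only what the registry needs.
struct Pass {
  virtual ~Pass() = default;
  virtual std::string getArgument() const = 0;
  virtual std::string getDescription() const = 0;
};

using PassAllocator = std::function<std::unique_ptr<Pass>()>;

// Hypothetical global table: command-line argument -> factory.
std::map<std::string, PassAllocator> &passRegistry() {
  static std::map<std::string, PassAllocator> registry;
  return registry;
}

// Much like mlir::registerPass: build one instance only to read its argument.
void registerPass(const PassAllocator &allocator) {
  std::unique_ptr<Pass> probe = allocator();
  std::string argument = probe->getArgument();
  assert(!argument.empty() && "a registered pass needs a command-line argument");
  passRegistry()[argument] = allocator;
}

struct TritonToLinalgPass : Pass {
  std::string getArgument() const override { return "triton-to-linalg"; }
  std::string getDescription() const override {
    return "Convert Triton to Linalg dialect";
  }
};

// Plays the role of createTritonToLinalgPass() declared in Passes.h.
std::unique_ptr<Pass> createTritonToLinalgPass() {
  return std::make_unique<TritonToLinalgPass>();
}

// Same shape as the generated registerTritonToLinalg().
inline void registerTritonToLinalg() {
  registerPass([]() -> std::unique_ptr<Pass> {
    return createTritonToLinalgPass();
  });
}

// What a driver does for --triton-to-linalg: look the factory up by name.
std::unique_ptr<Pass> lookupPass(const std::string &argument) {
  auto it = passRegistry().find(argument);
  if (it == passRegistry().end())
    return nullptr;
  return it->second();
}
```

A driver would call registerTritonToLinalg() once at startup and then lookupPass("triton-to-linalg") when it sees that flag; the assert stands in for MLIR's refusal to register a Pass that has no argument.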
3. TableGen build rules
triton-spirv/third_party/spirv/include/Conversion/TritonToLinalg/CMakeLists.txt. We use -gen-pass-decls to generate the Pass declarations, and add_public_tablegen_target to make the generated .inc public so that other CMake targets can depend on it.
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToLinalg)
add_public_tablegen_target(TritonToLinalgConversionPassIncGen)
After the build we get the following generated code, which matches the macros used in the .h and will come up again later.
/* Autogenerated by mlir-tblgen; don't manually edit */
#ifdef GEN_PASS_DECL
// Generate declarations for all passes.
#define GEN_PASS_DECL_TRITONTOLINALG
#undef GEN_PASS_DECL
#endif // GEN_PASS_DECL
//===----------------------------------------------------------------------===//
// TritonToLinalg
//===----------------------------------------------------------------------===//
#ifdef GEN_PASS_DECL_TRITONTOLINALG
#undef GEN_PASS_DECL_TRITONTOLINALG
#endif // GEN_PASS_DECL_TRITONTOLINALG
#ifdef GEN_PASS_DEF_TRITONTOLINALG
namespace impl {
template <typename DerivedT>
class TritonToLinalgBase : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
using Base = TritonToLinalgBase;
TritonToLinalgBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
TritonToLinalgBase(const TritonToLinalgBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
TritonToLinalgBase& operator=(const TritonToLinalgBase &) = delete;
TritonToLinalgBase(TritonToLinalgBase &&) = delete;
TritonToLinalgBase& operator=(TritonToLinalgBase &&) = delete;
~TritonToLinalgBase() = default;
/// Returns the command-line argument attached to this pass.
static constexpr ::llvm::StringLiteral getArgumentName() {
return ::llvm::StringLiteral("triton-to-linalg");
}
::llvm::StringRef getArgument() const override { return "triton-to-linalg"; }
::llvm::StringRef getDescription() const override { return "Convert Triton to Linalg dialect"; }
/// Returns the derived pass name.
static constexpr ::llvm::StringLiteral getPassName() {
return ::llvm::StringLiteral("TritonToLinalg");
}
::llvm::StringRef getName() const override { return "TritonToLinalg"; }
/// Support isa/dyn_cast functionality for the derived pass class.
static bool classof(const ::mlir::Pass *pass) {
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
}
/// A clone method to create a copy of this pass.
std::unique_ptr<::mlir::Pass> clonePass() const override {
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}
/// Return the dialect that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
}
/// Explicitly declare the TypeID for this class. We declare an explicit private
/// instantiation because Pass classes should only be visible by the current
/// library.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TritonToLinalgBase<DerivedT>)
protected:
private:
};
} // namespace impl
#undef GEN_PASS_DEF_TRITONTOLINALG
#endif // GEN_PASS_DEF_TRITONTOLINALG
#ifdef GEN_PASS_REGISTRATION
//===----------------------------------------------------------------------===//
// TritonToLinalg Registration
//===----------------------------------------------------------------------===//
inline void registerTritonToLinalg() {
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::triton::spirv::createTritonToLinalgPass();
});
}
// Old registration code, kept for temporary backwards compatibility.
inline void registerTritonToLinalgPass() {
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::triton::spirv::createTritonToLinalgPass();
});
}
//===----------------------------------------------------------------------===//
// TritonToLinalg Registration
//===----------------------------------------------------------------------===//
inline void registerTritonToLinalgPasses() {
registerTritonToLinalg();
}
#undef GEN_PASS_REGISTRATION
#endif // GEN_PASS_REGISTRATION
// Deprecated. Please use the new per-pass macros.
#ifdef GEN_PASS_CLASSES
template <typename DerivedT>
class TritonToLinalgBase : public ::mlir::OperationPass<mlir::ModuleOp> {
public:
using Base = TritonToLinalgBase;
TritonToLinalgBase() : ::mlir::OperationPass<mlir::ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
TritonToLinalgBase(const TritonToLinalgBase &other) : ::mlir::OperationPass<mlir::ModuleOp>(other) {}
TritonToLinalgBase& operator=(const TritonToLinalgBase &) = delete;
TritonToLinalgBase(TritonToLinalgBase &&) = delete;
TritonToLinalgBase& operator=(TritonToLinalgBase &&) = delete;
~TritonToLinalgBase() = default;
/// Returns the command-line argument attached to this pass.
static constexpr ::llvm::StringLiteral getArgumentName() {
return ::llvm::StringLiteral("triton-to-linalg");
}
::llvm::StringRef getArgument() const override { return "triton-to-linalg"; }
::llvm::StringRef getDescription() const override { return "Convert Triton to Linalg dialect"; }
/// Returns the derived pass name.
static constexpr ::llvm::StringLiteral getPassName() {
return ::llvm::StringLiteral("TritonToLinalg");
}
::llvm::StringRef getName() const override { return "TritonToLinalg"; }
/// Support isa/dyn_cast functionality for the derived pass class.
static bool classof(const ::mlir::Pass *pass) {
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
}
/// A clone method to create a copy of this pass.
std::unique_ptr<::mlir::Pass> clonePass() const override {
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}
/// Register the dialects that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
}
/// Explicitly declare the TypeID for this class. We declare an explicit private
/// instantiation because Pass classes should only be visible by the current
/// library.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TritonToLinalgBase<DerivedT>)
protected:
};
#undef GEN_PASS_CLASSES
#endif // GEN_PASS_CLASSES
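The generated TritonToLinalgBase relies on the curiously recurring template pattern (CRTP): because the base class is parameterized by the derived class, it can implement clonePass() and classof() once for every Pass. The sketch below reproduces just that mechanism with the standard library; Pass, typeIdOf and the address-of-a-static-variable TypeID are hypothetical simplifications of mlir::Pass and mlir::TypeID, not their real definitions.

```cpp
// Hypothetical, stdlib-only sketch of the CRTP base generated by mlir-tblgen.
#include <cassert>
#include <memory>
#include <string>

// Hypothetical TypeID: the address of a per-type static variable is unique per type.
using TypeID = const void *;
template <typename T> TypeID typeIdOf() {
  static char id;
  return &id;
}

// Hypothetical, much reduced stand-in for mlir::Pass.
class Pass {
public:
  explicit Pass(TypeID id) : passID(id) {}
  Pass(const Pass &) = default;
  virtual ~Pass() = default;
  virtual void runOnOperation() = 0;
  virtual std::unique_ptr<Pass> clonePass() const = 0;
  TypeID getTypeID() const { return passID; }

private:
  TypeID passID;
};

// Same shape as the generated impl::TritonToLinalgBase<DerivedT>.
template <typename DerivedT> class TritonToLinalgBase : public Pass {
public:
  TritonToLinalgBase() : Pass(typeIdOf<DerivedT>()) {}
  TritonToLinalgBase(const TritonToLinalgBase &other) : Pass(other) {}

  static constexpr const char *getArgumentName() { return "triton-to-linalg"; }

  // isa/dyn_cast support: a Pass "is" a DerivedT if the stored TypeIDs match.
  static bool classof(const Pass *pass) {
    return pass->getTypeID() == typeIdOf<DerivedT>();
  }

  // The base knows DerivedT statically, so it can copy the full derived object.
  std::unique_ptr<Pass> clonePass() const override {
    assert(classof(this) && "base instantiated with the wrong DerivedT");
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }
};

// The concrete pass only writes the transformation itself.
struct TritonToLinalg : TritonToLinalgBase<TritonToLinalg> {
  int runs = 0; // stands in for real work on the ModuleOp
  void runOnOperation() override { ++runs; }
};
```

The design choice is that DerivedT only implements runOnOperation(); everything that can be derived from the .td record (argument, name, cloning, casting) lives in the generated base.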
4. Extra CMakeLists.txt files required by the directory layout
Our code lives in third_party/spirv/include/Conversion/TritonToLinalg, so two more files are needed: third_party/spirv/include/Conversion/CMakeLists.txt and third_party/spirv/include/CMakeLists.txt.
II. Changes to the lib folder
1. Writing the Pass
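third_party/spirv/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp. Besides including the header, we also include the .inc file with the GEN_PASS_DEF_TRITONTOLINALG macro defined, which expands the impl::TritonToLinalgBase class template; TRITONTOLINALG is simply the Pass name. The entry point of every Pass is runOnOperation(), so we override that method; for now it only dumps the module. In MLIR, Passes are created and handed around as std::unique_ptr<...>, which gives RAII ownership and matches PassManager::addPass(...), so the factory function returns std::make_unique<TritonToLinalg>(). The code is as follows.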
#include "spirv/include/Conversion/TritonToLinalg/Passes.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "triton-to-linalg"
namespace mlir::triton::spirv {
#define GEN_PASS_DEF_TRITONTOLINALG
#include "spirv/include/Conversion/TritonToLinalg/Passes.h.inc"
} // namespace mlir::triton::spirv
using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::spirv;
namespace {
struct TritonToLinalg
: public mlir::triton::spirv::impl::TritonToLinalgBase<TritonToLinalg> {
void runOnOperation() override {
auto moduleOp = getOperation();
moduleOp.dump();
}
};
} // namespace
namespace mlir::triton::spirv {
std::unique_ptr<OperationPass<ModuleOp>> createTritonToLinalgPass() {
return std::make_unique<TritonToLinalg>();
}
} // namespace mlir::triton::spirv
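The unique_ptr requirement is easiest to see from the consumer side. Below is a toy, stdlib-only pass manager; Module, Pass and PassManager are hypothetical stand-ins for mlir::ModuleOp, mlir::OperationPass and mlir::PassManager, heavily simplified. addPass takes ownership through std::unique_ptr, so createTritonToLinalgPass() can hand over a heap-allocated Pass while the concrete TritonToLinalg struct stays hidden in an anonymous namespace, as in the file above.

```cpp
// Hypothetical, stdlib-only sketch: why the factory returns std::unique_ptr.
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the IR a pass runs on (mlir::ModuleOp in the article).
struct Module {
  std::vector<std::string> log;
};

// Hypothetical, much reduced stand-in for mlir::OperationPass<ModuleOp>.
class Pass {
public:
  virtual ~Pass() = default;
  virtual void runOnOperation() = 0;
  void setOperation(Module *module) { op = module; }

protected:
  // Like MLIR's getOperation(): the op this run is anchored on.
  Module &getOperation() {
    assert(op && "pass is not attached to an operation");
    return *op;
  }

private:
  Module *op = nullptr;
};

// Toy pass manager: it owns its passes, hence addPass takes a unique_ptr.
class PassManager {
public:
  void addPass(std::unique_ptr<Pass> pass) { passes.push_back(std::move(pass)); }
  void run(Module &module) {
    for (auto &pass : passes) {
      pass->setOperation(&module);
      pass->runOnOperation();
    }
  }

private:
  std::vector<std::unique_ptr<Pass>> passes;
};

namespace {
// Mirrors the article's struct: hidden in an anonymous namespace.
struct TritonToLinalg : Pass {
  // The real pass calls moduleOp.dump(); here we just record that it ran.
  void runOnOperation() override {
    getOperation().log.push_back("triton-to-linalg");
  }
};
} // namespace

// From other translation units, the factory is the only way to obtain the pass.
std::unique_ptr<Pass> createTritonToLinalgPass() {
  return std::make_unique<TritonToLinalg>();
}
```

Usage is just pm.addPass(createTritonToLinalgPass()) followed by pm.run(module); once the manager owns the Pass, nobody else has to free it.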
2. Building the Pass
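third_party/spirv/lib/Conversion/TritonToLinalg/CMakeLists.txt. We define a build target that compiles the .cpp file, links against MLIRPass, and DEPENDS on TritonToLinalgConversionPassIncGen, the target that runs TableGen on the .td file. The code is below; from now on, new source files in this folder only need to be added to the list.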
add_triton_library(TritonToLinalg
TritonToLinalg.cpp
DEPENDS
TritonToLinalgConversionPassIncGen
LINK_LIBS PUBLIC
MLIRPass
)
3. Extra CMakeLists.txt files required by the directory layout
Our code lives in third_party/spirv/lib/Conversion/TritonToLinalg, so two more files are needed: third_party/spirv/lib/Conversion/CMakeLists.txt and third_party/spirv/lib/CMakeLists.txt.
III. Initializing the root CMakeLists.txt
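Since this is the first Pass we add, the root CMakeLists.txt needs include_directories and add_subdirectory: the former puts the include folder on the header search path (the second line covers the build directory, where generated .inc files such as Passes.h.inc are written), and the latter brings include and lib into the build.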
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
add_subdirectory(include)
add_subdirectory(lib)
This article is from cnblogs. Author: 暴力都不会的蒟蒻. When reposting, please credit the original link: https://www.cnblogs.com/BobHuang/p/18916629