MLIR (google/heir): attributes and types

 https://mlir.llvm.org/docs/DefiningDialects/AttributesAndTypes/

(1) What is an attribute?

Attributes

Attributes are the mechanism for specifying constant data on operations in places where a variable is never allowed - e.g. the comparison predicate of an arith.cmpi operation, or the underlying value of an arith.constant operation. Each operation has an attribute dictionary, which associates a set of attribute names to attribute values.

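Understanding: when an operation needs a constant, that constant has to be supplied as an attribute, fixed when the operation is created, rather than as an operand.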

Example 1: the $predicate of the operation arith.cmpi is an attribute.

operation ::= `arith.cmpi` $predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)

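Description of arith.cmpi:

Compares two values of the same type according to $predicate. Both operands must have the same type, which may be an integer, a vector, or a tensor (of integers); the result is correspondingly an i1 scalar, an i1 vector, or an i1 tensor. $predicate must be one of the following: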

  • equal (mnemonic: "eq"; integer value: 0)
  • not equal (mnemonic: "ne"; integer value: 1)
  • signed less than (mnemonic: "slt"; integer value: 2)
  • signed less than or equal (mnemonic: "sle"; integer value: 3)
  • signed greater than (mnemonic: "sgt"; integer value: 4)
  • signed greater than or equal (mnemonic: "sge"; integer value: 5)
  • unsigned less than (mnemonic: "ult"; integer value: 6)
  • unsigned less than or equal (mnemonic: "ule"; integer value: 7)
  • unsigned greater than (mnemonic: "ugt"; integer value: 8)
  • unsigned greater than or equal (mnemonic: "uge"; integer value: 9)
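To make the predicate semantics concrete, here is a small standalone C++ model (no MLIR dependency; the namespace sketch, the enum and the helper evalCmpI are hypothetical names, not MLIR's generated API). It reuses the integer values listed above and shows why the same bit pattern can compare differently under signed and unsigned predicates.

```cpp
#include <cassert>
#include <cstdint>

namespace sketch {
// Mirrors the integer values listed for arith.cmpi's $predicate.
// Standalone model for illustration; not MLIR's generated enum.
enum class CmpIPredicate : uint64_t {
  eq = 0, ne = 1,
  slt = 2, sle = 3, sgt = 4, sge = 5,
  ult = 6, ule = 7, ugt = 8, uge = 9,
};

// Evaluates "lhs <pred> rhs" on 64-bit operands. The operands are raw bit
// patterns; signed predicates reinterpret them as two's-complement values.
// bool stands in for MLIR's i1 result.
bool evalCmpI(CmpIPredicate pred, uint64_t lhs, uint64_t rhs) {
  const int64_t sl = static_cast<int64_t>(lhs);
  const int64_t sr = static_cast<int64_t>(rhs);
  switch (pred) {
    case CmpIPredicate::eq:  return lhs == rhs;
    case CmpIPredicate::ne:  return lhs != rhs;
    case CmpIPredicate::slt: return sl < sr;
    case CmpIPredicate::sle: return sl <= sr;
    case CmpIPredicate::sgt: return sl > sr;
    case CmpIPredicate::sge: return sl >= sr;
    case CmpIPredicate::ult: return lhs < rhs;
    case CmpIPredicate::ule: return lhs <= rhs;
    case CmpIPredicate::ugt: return lhs > rhs;
    case CmpIPredicate::uge: return lhs >= rhs;
  }
  return false;  // not reachable for valid predicates
}
}  // namespace sketch
```

For example, the all-ones bit pattern is -1 under slt (so it is less than 1) but the largest value under ult (so it is not).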

Example 2: the $value of the operation arith.constant is an attribute.

operation ::= `arith.constant` attr-dict $value

Description of arith.constant:

Creates an integer or floating-point constant whose value and type are given by $value, e.g. "%1 = arith.constant 42 : i32".

(2) What is a type?

Types

Every SSA value, such as operation results or block arguments, in MLIR has a type defined by the type system. MLIR has an open type system with no fixed list of types, and there are no restrictions on the abstractions they represent. 

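Understanding: operation results, block arguments and other values must all be given a type; a type may be built into MLIR or defined by a dialect.

Example: the operands (%lhs, %rhs) and the result (%result) all have type i64; i64 is the type.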

%result = arith.addi %lhs, %rhs : i64

(3) Attributes and types are defined similarly in TableGen

Attributes and Types

The structure for defining Attributes and Types is nearly identical, with only a few differences depending on the context. As such, a majority of this document describes the process for defining both Attributes and Types side-by-side with examples for both. If necessary, a section will explicitly call out any distinct differences.

Understanding: Attributes and Types are defined in almost exactly the same way.

(4) Defining a new attribute or type class in TableGen

Adding a new Attribute or Type definition

As described above, C++ Attribute and Type objects in MLIR are value-typed and essentially function as helpful wrappers around an internal storage object that holds the actual data for the type. Similarly to Operations, Attributes and Types are defined declaratively via TableGen; a generic language with tooling to maintain records of domain-specific information. It is highly recommended that users review the TableGen Programmer’s Reference for an introduction to its syntax and constructs.

Starting the definition of a new attribute or type simply requires adding a specialization for either the AttrDef or TypeDef class respectively. Instances of the classes correspond to unique Attribute or Type classes.

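Understanding: like Operations, Attributes and Types are defined through TableGen, and the docs strongly recommend reading the TableGen Programmer's Reference first. In practice, the definitions of Operations, Attributes and Types are usually kept in separate .td files (HEIR's LWE dialect, for instance, uses LWEAttributes.td, NewLWETypes.td and LWEOps.td). A new attribute is defined by specializing the AttrDef class, and a new type by specializing the TypeDef class; each such def corresponds to one unique C++ Attribute or Type class.

Example 1: defining a new Type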

// Include the definition of the necessary tablegen constructs for defining
// our types.
include "mlir/IR/AttrTypeBase.td"

// It's common to define a base class for types in the same dialect. This
// removes the need to pass in the dialect for each type, and can also be used
// to define a few fields ahead of time.
class MyDialect_Type<string name, string typeMnemonic, list<Trait> traits = []>
    : TypeDef<My_Dialect, name, traits> {
  let mnemonic = typeMnemonic;
}

// Here is a simple definition of an "integer" type, with a width parameter.
def My_IntegerType : MyDialect_Type<"Integer", "int"> {
  let summary = "Integer type with arbitrary precision up to a fixed limit";
  let description = [{
    Integer types have a designated bit width.
  }];
  /// Here we defined a single parameter for the type, which is the bitwidth.
  let parameters = (ins "unsigned":$width);

  /// Here we define the textual format of the type declaratively, which will
  /// automatically generate parser and printer logic. This will allow for
  /// instances of the type to be output as, for example:
  ///
  ///    !my.int<10> // a 10-bit integer.
  ///
  let assemblyFormat = "`<` $width `>`";

  /// Indicate that our type will add additional verification to the parameters.
  let genVerifyDecl = 1;
}
MyDialect_Type

Example 2: defining a new Attribute

// Include the definition of the necessary tablegen constructs for defining
// our attributes.
include "mlir/IR/AttrTypeBase.td"

// It's common to define a base class for attributes in the same dialect. This
// removes the need to pass in the dialect for each attribute, and can also be used
// to define a few fields ahead of time.
class MyDialect_Attr<string name, string attrMnemonic, list<Trait> traits = []>
    : AttrDef<My_Dialect, name, traits> {
  let mnemonic = attrMnemonic;
}

// Here is a simple definition of an "integer" attribute, with a type and value parameter.
def My_IntegerAttr : MyDialect_Attr<"Integer", "int"> {
  let summary = "An Attribute containing an integer value";
  let description = [{
    An integer attribute is a literal attribute that represents an integral
    value of the specified integer type.
  }];
  /// Here we've defined two parameters, one is a "self" type parameter, and the
  /// other is the integer value of the attribute. The self type parameter is
  /// specially handled by the assembly format.
  let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APInt":$value);

  /// Here we've defined a custom builder for the attribute, that removes the need to pass
  /// in an MLIRContext instance; as it can be inferred from the `type`.
  let builders = [
    AttrBuilderWithInferredContext<(ins "Type":$type,
                                        "const APInt &":$value), [{
      return $_get(type.getContext(), type, value);
    }]>
  ];

  /// Here we define the textual format of the attribute declaratively, which will
  /// automatically generate parser and printer logic. This will allow for
  /// instances of the attribute to be output as, for example:
  ///
  ///    #my.int<50> : !my.int<32> // a 32-bit integer of value 50.
  ///
  /// Note that the self type parameter is not included in the assembly format.
  /// Its value is derived from the optional trailing type on all attributes.
  let assemblyFormat = "`<` $value `>`";

  /// Indicate that our attribute will add additional verification to the parameters.
  let genVerifyDecl = 1;

  /// Indicate to the ODS generator that we do not want the default builders,
  /// as we have defined our own simpler ones.
  let skipDefaultBuilders = 1;
}
MyDialect_Attr

(5) Class names of TableGen attribute and type classes

Class Name

The name of the C++ class which gets generated defaults to <classParamName>Attr or <classParamName>Type for attributes and types respectively. In the examples above, this was the name template parameter that was provided to MyDialect_Attr and MyDialect_Type. For the definitions we added above, we would get C++ classes named IntegerType and IntegerAttr respectively. This can be explicitly overridden via the cppClassName field.

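Understanding: the default names of the generated attribute and type classes are <classParamName>Attr and <classParamName>Type. In the examples above, <classParamName> is the template argument "string name" passed to "class MyDialect_Attr" and "class MyDialect_Type"; the C++ class name is that string with the suffix Attr or Type appended. So def My_IntegerAttr : MyDialect_Attr<"Integer", "int">{...} produces a new attribute class whose C++ name is IntegerAttr, and def My_IntegerType : MyDialect_Type<"Integer", "int">{...} produces a new type class whose C++ name is IntegerType. The C++ class name can also be changed by overriding it with "let cppClassName = ..." in the TableGen (*.td) file.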
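The naming rule is plain string concatenation, like TableGen's name # "Type". The sketch below (standalone C++; DefKind and cppClassName are hypothetical helpers, not part of mlir-tblgen) spells it out, including the cppClassName override.

```cpp
#include <cassert>
#include <string>

namespace sketch {
enum class DefKind { Attr, Type };

// Default C++ class name: the `name` template argument plus "Attr"/"Type".
std::string defaultCppClassName(const std::string &name, DefKind kind) {
  return name + (kind == DefKind::Attr ? "Attr" : "Type");
}

// Models `let cppClassName = ...`: an explicit override wins; an empty
// string here means "no override given".
std::string cppClassName(const std::string &name, DefKind kind,
                         const std::string &overrideName = "") {
  return overrideName.empty() ? defaultCppClassName(name, kind) : overrideName;
}
}  // namespace sketch
```

With the examples above, "Integer" becomes IntegerAttr for MyDialect_Attr and IntegerType for MyDialect_Type.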

(6) Documentation of TableGen attribute and type classes

Documentation 

The summary and description fields allow for providing user documentation for the attribute or type. The summary field expects a simple single-line string, with the description field used for long and extensive documentation. This documentation can be used to generate markdown documentation for the dialect and is used by upstream MLIR dialects.

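Understanding: in a new Attribute or Type definition, the summary field is a single-line string and the description field is a multi-line string. Both document the new class, and mlir-tblgen can generate readable Markdown (.md) documentation from them.

(7) Mnemonics of TableGen attribute and type classes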

Mnemonic 

The mnemonic field, i.e. the template parameters attrMnemonic and typeMnemonic we specified above, are used to specify a name for use during parsing. This allows for more easily dispatching to the current attribute or type class when parsing IR. This field is generally optional, and custom parsing/printing logic can be added without defining it, though most classes will want to take advantage of the convenience it provides. This is why we added it as a template parameter in the examples above.

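Understanding: the mnemonic field of a new Attribute or Type definition is the keyword used when parsing MLIR text. In the examples above, the attribute's mnemonic is the template argument "string attrMnemonic" passed to class MyDialect_Attr<string name, string attrMnemonic, list<Trait> traits = []>, and the type's mnemonic is the template argument "string typeMnemonic" passed to class MyDialect_Type<string name, string typeMnemonic, list<Trait> traits = []>. So IntegerAttr is written with the mnemonic "int" in MLIR, and so is IntegerType; the two do not clash because attributes are printed with a leading # (#my.int<50>) and types with a leading ! (!my.int<10>).

(8) Parameters of TableGen attribute and type classes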
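The point of a mnemonic is dispatch: the dialect reads the keyword after the dialect prefix and hands the rest of the text to the matching class. The following is a standalone C++ model of that idea only (the Handler table and dispatchByMnemonic are hypothetical; MLIR's generated dispatch code looks different).

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>

namespace sketch {
// Hypothetical handler: receives the text after the mnemonic, e.g. "<10>".
using Handler = std::function<std::string(const std::string &body)>;

// Reads the mnemonic (everything before the first '<', or the whole text)
// and forwards the remainder to the handler registered for it.
std::string dispatchByMnemonic(const std::map<std::string, Handler> &table,
                               const std::string &text) {
  const auto pos = text.find('<');
  const std::string mnemonic = text.substr(0, pos);
  const std::string body = pos == std::string::npos ? "" : text.substr(pos);
  const auto it = table.find(mnemonic);
  if (it == table.end()) return "error: unknown mnemonic '" + mnemonic + "'";
  return it->second(body);
}
}  // namespace sketch
```

Registering "int" in a type table and a separate attribute table is how the same mnemonic can serve both IntegerType and IntegerAttr.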

Parameters 

The parameters field is a variable length list containing the attribute or type’s parameters. If no parameters are specified (the default), this type is considered a singleton type (meaning there is only one possible instance). Parameters in this list take the form: "c++Type":$paramName.

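Understanding: the parameter list of an attribute or type class is variable-length and may be empty (an empty list makes it a singleton with only one possible instance). Each element has the form "c++Type":$paramName; as the examples below show, an element can also use a parameter class such as DefaultValuedParameter instead of a plain C++ type string.

Example: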

code/heir-private/lib/Dialect/LWE/IR/LWEAttributes.td

let parameters = (ins "unsigned":$cleartext_start, "unsigned":$cleartext_bitwidth);
let parameters = (ins "unsigned":$cleartext_bitwidth);
let parameters = (ins "IntegerAttr": $cmod, "unsigned":$dimension);
let parameters = (ins DefaultValuedParameter<"unsigned", "2">:$dimension,
 "::mlir::heir::polynomial::RingAttr":$ring);

code/heir-private/lib/Dialect/Polynomial/IR/PolynomialAttributes.td

class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
    : AttrDef<Polynomial_Dialect, name, traits> {
  let mnemonic = attrMnemonic;
}
def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring", [OpAsmAttrInterface]> {
  let summary = "...";
  let description = "...";
  ...
}

8.1. Default-valued parameters: parameters that may be omitted, in which case they take their default value

DefaultValuedParameter

An optional parameter can also be specified with DefaultValuedParameter, which specifies that a parameter should be omitted when it is equal to some given value.

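Understanding: if an element of an attribute's or type's parameter list is a DefaultValuedParameter, that parameter has a default value. It may be left out when writing the IR, and a left-out parameter takes the default value; symmetrically, the printer leaves it out when it equals the default.

Example: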

let parameters = (ins DefaultValuedParameter<"Optional<int>", "5">:$a);
let mnemonic = "default_valued";
let assemblyFormat = "(`<` $a^ `>`)?";

!test.default_valued     // a = 5
!test.default_valued<10> // a = 10
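A standalone C++ model of what the generated parser and printer do for this example (hypothetical function names; the real parameter is Optional<int>, simplified here to int). The optional group "(`<` $a^ `>`)?" is elided on print when a equals 5, and a missing group parses as 5.

```cpp
#include <cassert>
#include <optional>
#include <string>

namespace sketch {
constexpr int kDefaultA = 5;
const std::string kPrefix = "!test.default_valued";

// Printer: omit the `<...>` group when the parameter equals its default.
std::string printDefaultValued(int a) {
  if (a == kDefaultA) return kPrefix;
  return kPrefix + "<" + std::to_string(a) + ">";
}

// Parser: no `<...>` group means the parameter takes its default value.
std::optional<int> parseDefaultValued(const std::string &text) {
  if (text.rfind(kPrefix, 0) != 0) return std::nullopt;  // wrong keyword
  const std::string rest = text.substr(kPrefix.size());
  if (rest.empty()) return kDefaultA;
  if (rest.size() < 3 || rest.front() != '<' || rest.back() != '>')
    return std::nullopt;
  const std::string digits = rest.substr(1, rest.size() - 2);
  try {
    std::size_t used = 0;
    const int value = std::stoi(digits, &used);
    if (used != digits.size()) return std::nullopt;  // trailing junk
    return value;
  } catch (...) {
    return std::nullopt;  // not a number
  }
}
}  // namespace sketch
```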

(9) Assembly format: the MLIR syntax of a TableGen attribute/type class

Using assemblyFormat

Attributes and types defined in ODS with a mnemonic can define an assemblyFormat to declaratively describe custom parsers and printers. The assembly format consists of literals, variables, and directives.

  • A literal is a keyword or valid punctuation enclosed in backticks, e.g. `keyword` or `<`.
  •  A variable is a parameter name preceded by a dollar sign, e.g. $param0, which captures one attribute or type parameter.
  •  A directive is a keyword followed by an optional argument list that defines special parser and printer behaviour.

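Understanding: if a new Attribute or Type definition in a .td file sets the mnemonic field, it can also set the assemblyFormat field to declaratively define its custom parser and printer. The value of assemblyFormat consists of literals, variables and directives. A literal is a keyword or valid punctuation enclosed in backticks; a variable is a parameter name prefixed with $ (the name comes from the attribute's or type's parameter list); directives (see 9.1) cover special cases such as capturing all parameters at once.

Example: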

// An example type with an assembly format.
def MyType : TypeDef<My_Dialect, "MyType"> {
  // Define a mnemonic to allow the dialect's parser hook to call into the
  // generated parser.
  let mnemonic = "my_type";

  // Define two parameters whose C++ types are indicated in string literals.
  let parameters = (ins "int":$count, "AffineMap":$map);

  // Define the assembly format. Surround the format with less `<` and greater
  // `>` so that MLIR's printer uses the pretty format.
  let assemblyFormat = "`<` $count `,` `map` `=` $map `>`";
}

!my_dialect.my_type<42, map = affine_map<(i, j) -> (j, i)>>
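To see how literals and variables map to output, here is a hand-written C++ equivalent of the printer for this format (standalone sketch; MyTypeParams and printMyType are hypothetical, the AffineMap is modeled as its already-printed text, and the spacing is copied from the sample output above).

```cpp
#include <cassert>
#include <string>

namespace sketch {
// Stand-in for MyType's parameters ($count, $map).
struct MyTypeParams {
  int count;
  std::string map;  // AffineMap modeled as its printed form
};

// Mirrors "`<` $count `,` `map` `=` $map `>`": literals are emitted as-is,
// each $variable is replaced by the corresponding parameter value.
std::string printMyType(const MyTypeParams &p) {
  std::string out = "!my_dialect.my_type";
  out += "<";                       // literal `<`
  out += std::to_string(p.count);   // variable $count
  out += ", map = ";                // literals `,` `map` `=`
  out += p.map;                     // variable $map
  out += ">";                       // literal `>`
  return out;
}
}  // namespace sketch
```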

9.1. TableGen directives in the assembly format

Assembly Format Directives 

Attribute and type assembly formats have the following directives:

  • params: capture all parameters of an attribute or type.
  • qualified: mark a parameter to be printed with its leading dialect and mnemonic.
  • struct: generate a “struct-like” parser and printer for a list of key-value pairs.
  • custom: dispatch a call to user-defined parser and printer functions.
  • ref: in a custom directive, references a previously bound variable.

params Directive 

This directive is used to refer to all parameters of an attribute or type, except for the attribute self type (which is handled separately from normal parameters). When used as a top-level directive, params generates a parser and printer for a comma-separated list of the parameters.

struct Directive 

The struct directive accepts a list of variables to capture and will generate a parser and printer for a comma-separated list of key-value pairs. If an optional parameter is included in the struct, it can be elided. The variables are printed in the order they are specified in the argument list but can be parsed in any order.

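Understanding: in the .td definition of an Attribute or Type, the available directives are params, qualified, struct, custom and ref. The params directive refers to the attribute's or type's whole parameter list; as a top-level directive it generates a parser and printer for the comma-separated parameter list. params can also be passed as the argument of another directive (e.g. struct(params) in the LWE example later), where it stands for the full, variable-length list of parameters. The struct directive takes a list of variables as its argument and generates a parser and printer for them as comma-separated key-value pairs: the parser accepts the pairs in any order, while the printer always prints them in the order of the argument list.

Example 1: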

def MyPairType : TypeDef<My_Dialect, "MyPairType"> {
  let parameters = (ins "int":$a, "int":$b);
  let mnemonic = "pair";
  let assemblyFormat = "`<` params `>`";
}

IR中,打印器输出的MLIR代码和解析器能接受的MLIR代码的形式如下:

!my_dialect.pair<42, 24>
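A standalone C++ sketch of what "`<` params `>`" amounts to for MyPairType: all parameters, in declaration order, separated by commas (Pair, printPair and parsePairBody are hypothetical names; the real generated code uses MLIR's AsmParser/AsmPrinter).

```cpp
#include <cassert>
#include <istream>
#include <optional>
#include <sstream>
#include <string>

namespace sketch {
struct Pair { int a; int b; };  // parameters $a, $b in declaration order

std::string printPair(const Pair &p) {
  return "!my_dialect.pair<" + std::to_string(p.a) + ", " +
         std::to_string(p.b) + ">";
}

// Parses the text between `<` and `>`, e.g. "42, 24".
std::optional<Pair> parsePairBody(const std::string &body) {
  std::istringstream in(body);
  Pair p{};
  char comma = 0;
  if (!(in >> p.a >> comma >> p.b) || comma != ',') return std::nullopt;
  in >> std::ws;                       // allow trailing whitespace only
  if (!in.eof()) return std::nullopt;  // anything else is an error
  return p;
}
}  // namespace sketch
```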

Example 2:

def MyStructType : TypeDef<My_Dialect, "MyStructType"> {
  let parameters = (ins StringRefParameter<>:$sym_name, "int":$a, "int":$b, "int":$c);
  let mnemonic = "struct";
  let assemblyFormat = "`<` $sym_name `->` struct($a, $b, $c) `>`";
}

IR中,打印器输出的MLIR代码是第1种,能解释器能接受的MLIR代码可以是第1种和第2种:

第1种:!my_dialect.struct<"foo" -> a = 1, b = 2, c = 3>
第2种:!my_dialect.struct<"foo" -> b = 2, c = 3, a = 1>
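The struct($a, $b, $c) behaviour (parse in any order, print in argument order) can be modeled in a few lines of standalone C++. This is a sketch of the semantics only: parseStruct and printStruct are hypothetical, the "foo" -> prefix is left out, and all three keys are treated as required.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <sstream>
#include <string>
#include <vector>

namespace sketch {
// Argument-list order of struct($a, $b, $c); the printer follows it.
const std::vector<std::string> kKeys = {"a", "b", "c"};

// Parses "b = 2, c = 3, a = 1" in any order; rejects unknown, duplicate
// or missing keys.
std::optional<std::map<std::string, int>> parseStruct(const std::string &body) {
  std::map<std::string, int> values;
  std::istringstream in(body);
  std::string entry;
  while (std::getline(in, entry, ',')) {
    std::istringstream kv(entry);
    std::string key;
    char eq = 0;
    int value = 0;
    if (!(kv >> key >> eq >> value) || eq != '=') return std::nullopt;
    bool known = false;
    for (const auto &k : kKeys) known = known || k == key;
    if (!known || !values.emplace(key, value).second) return std::nullopt;
  }
  if (values.size() != kKeys.size()) return std::nullopt;
  return values;
}

// Prints the pairs in argument-list order, regardless of parse order.
std::string printStruct(const std::map<std::string, int> &values) {
  std::string out;
  for (const auto &k : kKeys) {
    if (!out.empty()) out += ", ";
    out += k + " = " + std::to_string(values.at(k));
  }
  return out;
}
}  // namespace sketch
```

Round-tripping form 2 through this model yields the key order of form 1.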

(10) Adding native C++ code to a TableGen attribute/type class definition

Extra declarations

The declarative Attribute and Type definitions try to auto-generate as much logic and methods as possible. With that said, there will always be long-tail cases that won’t be covered. For such cases, extraClassDeclaration and extraClassDefinition can be used. Code within the extraClassDeclaration field will be copied literally to the generated C++ Attribute or Type class. Code within extraClassDefinition will be added to the generated source file inside the class’s C++ namespace. The substitution $cppClass will be replaced by the Attribute or Type’s C++ class name. Note that these are mechanisms intended for long-tail cases by power users; for not-yet-implemented widely-applicable cases, improving the infrastructure is preferable.

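Understanding: when Attributes and Types are defined declaratively with TableGen, the C++ for most of their logic and methods is generated automatically, but there will always be exceptions that need hand-written C++ added to the new attribute or type in the .td source. For those, use extraClassDeclaration (copied verbatim into the generated class) or extraClassDefinition (emitted into the generated source file). In the example below, extraClassDeclaration adds an inline getAlias method for OpAsmAttrInterface.

Example: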

class LWE_EncodingAttr<string attrName, string attrMnemonic, list<Trait> traits = []>
    : AttrDef<LWE_Dialect, attrName, traits # [
    // All encoding attributes are required to be compatible with a tensor
    // with an element type relevant to that encoding.
    DeclareAttrInterfaceMethods<VerifiableTensorEncoding>,
    OpAsmAttrInterface
]> {
  let mnemonic = attrMnemonic;
  let assemblyFormat = "`<` struct(params) `>`";

  let extraClassDeclaration = [{
    // OpAsmAttrInterface methods.
    ::mlir::OpAsmDialectInterface::AliasResult getAlias(::llvm::raw_ostream &os) const {
      os << "}] # attrMnemonic # [{";
      return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias;
    }
  }];
}

class LWE_EncodingAttrWithScalingFactor<string attrName, string attrMnemonic, list<Trait> traits = []>
    : LWE_EncodingAttr<attrName, attrMnemonic, traits> {
  let parameters = (ins
    "unsigned":$cleartext_start,
    "unsigned":$cleartext_bitwidth
  );
}

def LWE_BitFieldEncoding
  : LWE_EncodingAttrWithScalingFactor<"BitFieldEncoding", "bit_field_encoding"> {
  let summary = "An attribute describing encoded LWE plaintexts using bit fields.";
  let description = [{
    A bit field encoding of an integer describes which contiguous region
    of bits a small integer occupies within a larger integer.
  }];
}
lib/Dialect/LWE/IR/LWEAttributes.td

Use mlir-tblgen to generate the corresponding C++ header from the .td file:

mlir-tblgen -I ./llvm-project/mlir/include/ -I ./heir-private -gen-attrdef-decls ./heir-private/lib/Dialect/LWE/IR/LWEAttributes.td -o LWEAttributes.td.h.inc

Part of the generated C++ file LWEAttributes.td.h.inc:

namespace mlir {
namespace heir {
namespace lwe {

class BitFieldEncodingAttr;

class BitFieldEncodingAttr : public ::mlir::Attribute::AttrBase<BitFieldEncodingAttr, ::mlir::Attribute, detail::BitFieldEncodingAttrStorage, ::mlir::VerifiableTensorEncoding::Trait, ::mlir::OpAsmAttrInterface::Trait> {
public:
  using Base::Base;
  // OpAsmAttrInterface methods.
  ::mlir::OpAsmDialectInterface::AliasResult getAlias(::llvm::raw_ostream &os) const {
    os << "bit_field_encoding";
    return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias;
  }
  static constexpr ::llvm::StringLiteral name = "lwe.bit_field_encoding";
  static constexpr ::llvm::StringLiteral dialectName = "lwe";
  static BitFieldEncodingAttr get(::mlir::MLIRContext *context, unsigned cleartext_start, unsigned cleartext_bitwidth);
  static constexpr ::llvm::StringLiteral getMnemonic() {
    return {"bit_field_encoding"};
  }

  static ::mlir::Attribute parse(::mlir::AsmParser &odsParser, ::mlir::Type odsType);
  void print(::mlir::AsmPrinter &odsPrinter) const;
  unsigned getCleartextStart() const;
  unsigned getCleartextBitwidth() const;
  ::llvm::LogicalResult verifyEncoding(::mlir::ArrayRef<int64_t> shape, ::mlir::Type elementType, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) const;
};

} // namespace lwe
} // namespace heir
} // namespace mlir
LWEAttributes.td.h.inc
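Why does the generated getAlias contain os << "bit_field_encoding";? In the .td, the line os << "}] # attrMnemonic # [{"; closes the [{ ... }] code string, pastes the mnemonic with TableGen's # operator, and reopens the code string. The C++ sketch below only imitates that concatenation (pasteGetAlias is a hypothetical helper, not something mlir-tblgen exposes).

```cpp
#include <cassert>
#include <string>

namespace sketch {
// Imitates how LWE_EncodingAttr builds extraClassDeclaration: code text,
// then the mnemonic string, then more code text, joined like TableGen's `#`.
std::string pasteGetAlias(const std::string &attrMnemonic) {
  const std::string before =
      "::mlir::OpAsmDialectInterface::AliasResult getAlias("
      "::llvm::raw_ostream &os) const {\n"
      "  os << \"";
  const std::string after =
      "\";\n"
      "  return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias;\n"
      "}\n";
  return before + attrMnemonic + after;
}
}  // namespace sketch
```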

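Some C++ background for reading LWEAttributes.td.h.inc:

Note 1: the line below is a forward declaration. It tells the compiler that a class named BitFieldEncodingAttr exists without giving its full definition; the full definition comes elsewhere (here, right after it in the same file).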

class BitFieldEncodingAttr;
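A plain C++ illustration of what a forward declaration allows (no MLIR involved; the class body and isSet are made-up stand-ins): pointers and references to the incomplete class are fine, and objects can only be created once the full definition has been seen.

```cpp
#include <cassert>

namespace sketch {
class BitFieldEncodingAttr;  // forward declaration: name known, layout not

// Only a pointer is used, so the incomplete type is enough here.
bool isSet(const BitFieldEncodingAttr *attr) { return attr != nullptr; }

// The full definition comes later, as in the generated .h.inc.
class BitFieldEncodingAttr {
 public:
  unsigned cleartextStart = 0;
};
}  // namespace sketch
```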

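Note 2: in the code below, public means public inheritance: the public members of the base class stay public in the derived class. ::mlir::Attribute::AttrBase is a class template, and the "<...>" holds its template arguments: the derived class itself (the CRTP pattern), the C++ base class ::mlir::Attribute, the storage class detail::BitFieldEncodingAttrStorage, and the traits of the interfaces the attribute implements.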

class BitFieldEncodingAttr : public ::mlir::Attribute::AttrBase<BitFieldEncodingAttr, ::mlir::Attribute, detail::BitFieldEncodingAttrStorage, ::mlir::VerifiableTensorEncoding::Trait, ::mlir::OpAsmAttrInterface::Trait> {....}
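A toy version of this pattern in plain C++ (AttrBaseModel and BitFieldStorage are hypothetical; the real AttrBase also takes interface traits and uniques storage in the MLIRContext). Passing the derived class as the first template argument lets the base call the concrete class's static members, and public inheritance keeps the base's public API public.

```cpp
#include <cassert>
#include <string>

namespace sketch {
// CRTP base: ConcreteT is the derived class, StorageT holds the data.
template <typename ConcreteT, typename StorageT>
class AttrBaseModel {
 public:
  explicit AttrBaseModel(const StorageT *storage) : impl(storage) {}
  // Uses a static member of the concrete (derived) class.
  static std::string describe() { return "attr " + ConcreteT::getMnemonic(); }

 protected:
  const StorageT *impl;
};

struct BitFieldStorage { unsigned start; unsigned width; };

// Public inheritance: describe() is still public on the derived class.
class BitFieldEncodingAttr
    : public AttrBaseModel<BitFieldEncodingAttr, BitFieldStorage> {
 public:
  explicit BitFieldEncodingAttr(const BitFieldStorage *s) : AttrBaseModel(s) {}
  static std::string getMnemonic() { return "bit_field_encoding"; }
  unsigned getCleartextStart() const { return impl->start; }
};
}  // namespace sketch
```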

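Note 3: in the code below, the using declaration makes the class inherit the constructors of its base class Base (Base::Base names Base's constructor). Base is a type alias that MLIR's StorageUserBase template (which AttrBase refers to) defines for itself.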

{
public:
  using Base::Base;
}
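A minimal plain-C++ sketch of the same idiom (StorageUserModel and MyAttr are hypothetical): the base class provides a Base alias, and the derived class inherits its constructors with using Base::Base instead of re-declaring each one.

```cpp
#include <cassert>

namespace sketch {
class StorageUserModel {
 public:
  using Base = StorageUserModel;  // alias that derived classes can name
  StorageUserModel() = default;
  explicit StorageUserModel(int storage) : storage_(storage) {}
  int getStorage() const { return storage_; }

 private:
  int storage_ = 0;
};

class MyAttr : public StorageUserModel {
 public:
  using Base::Base;  // inherit StorageUserModel's constructors
};
}  // namespace sketch
```

MyAttr(42) now works without MyAttr declaring any constructor itself.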

(11) Verification

 Verification 

If the genVerifyDecl field is set, additional verification methods are generated on the class.

static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError, parameters...)

These methods are used to verify the parameters provided to the attribute or type class on construction, and emit any necessary diagnostics. This method is automatically invoked from the builders of the attribute or type class.

AttrOrType getChecked(function_ref<InFlightDiagnostic()> emitError, parameters...)

As noted in the Builders section, these methods are companions to get builders that are failable. If the verify invocation fails when these methods are called, they return nullptr instead of asserting.

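Understanding: if the genVerifyDecl field of a TableGen attribute/type class is set to 1, the C++ that TableGen generates for that class declares "static LogicalResult MyNameAttr::verify(function_ref<InFlightDiagnostic()> emitError, parameters...)", whose arguments after emitError are the parameters listed in "let parameters = ...". Only the declaration is generated; the body is written by hand in a .cpp file. In the example below, the generated verifyInvariants first calls verifyInvariantsImpl and then verify, and the hand-written verify lives in lib\Dialect\LWE\IR\LWEAttributes.cpp.

Example: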
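The get/getChecked/verify split can be modeled without MLIR. In this standalone sketch, WidthAttr, Diagnostic and EmitErrorFn are hypothetical stand-ins (an ostringstream plays the role of InFlightDiagnostic, and std::optional plays the role of a null attribute); it shows the control flow only, not MLIR's actual classes.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <sstream>

namespace sketch {
using Diagnostic = std::ostringstream;  // stand-in for InFlightDiagnostic
using EmitErrorFn = std::function<Diagnostic &()>;

class WidthAttr {
 public:
  // Hand-written body (genVerifyDecl only declares it): width must be > 0.
  static bool verify(const EmitErrorFn &emitError, unsigned width) {
    if (width == 0) {
      emitError() << "width must be positive";
      return false;
    }
    return true;
  }

  // Like the generated verifyInvariants: structural checks (none in this
  // model) followed by the user's verify.
  static bool verifyInvariants(const EmitErrorFn &emitError, unsigned width) {
    return verify(emitError, width);
  }

  // get(): invalid parameters are a programming error, so it asserts.
  static WidthAttr get(unsigned width) {
    Diagnostic sink;
    const bool ok =
        verifyInvariants([&]() -> Diagnostic & { return sink; }, width);
    assert(ok && "invalid parameters passed to get()");
    (void)ok;
    return WidthAttr(width);
  }

  // getChecked(): reports through emitError and returns an empty result
  // instead of asserting.
  static std::optional<WidthAttr> getChecked(const EmitErrorFn &emitError,
                                             unsigned width) {
    if (!verifyInvariants(emitError, width)) return std::nullopt;
    return WidthAttr(width);
  }

  unsigned getWidth() const { return width_; }

 private:
  explicit WidthAttr(unsigned width) : width_(width) {}
  unsigned width_;
};
}  // namespace sketch
```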

def LWE_PlaintextSpaceAttr : AttrDef<LWE_Dialect, "PlaintextSpace"> {
  let mnemonic = "plaintext_space";
  let description = [{
    An attribute describing the plaintext space and the transformation 
  }];

  let parameters = (ins
    "::mlir::heir::polynomial::RingAttr":$ring,
    AnyPlaintextEncodingAttr:$encoding
  );

  let assemblyFormat = "`<` struct(params) `>`";

  let genVerifyDecl = 1;
}
lib\Dialect\LWE\IR\NewLWEAttributes.td
#ifdef GET_ATTRDEF_CLASSES
#undef GET_ATTRDEF_CLASSES
namespace mlir {
namespace heir {
namespace lwe {
......
class PlaintextSpaceAttr;
......
class PlaintextSpaceAttr : public ::mlir::Attribute::AttrBase<PlaintextSpaceAttr, ::mlir::Attribute, detail::PlaintextSpaceAttrStorage> {
public:
  using Base::Base;
  static constexpr ::llvm::StringLiteral name = "lwe.plaintext_space";
  static constexpr ::llvm::StringLiteral dialectName = "lwe";
  using Base::getChecked;
  static PlaintextSpaceAttr get(::mlir::MLIRContext *context, ::mlir::heir::polynomial::RingAttr ring, Attribute encoding);
  static PlaintextSpaceAttr getChecked(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, ::mlir::heir::polynomial::RingAttr ring, Attribute encoding);
  static ::llvm::LogicalResult verifyInvariantsImpl(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::heir::polynomial::RingAttr ring, Attribute encoding);
  static ::llvm::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::heir::polynomial::RingAttr ring, Attribute encoding);
  static ::llvm::LogicalResult verifyInvariants(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::heir::polynomial::RingAttr ring, Attribute encoding);
  static constexpr ::llvm::StringLiteral getMnemonic() {
    return {"plaintext_space"};
  }

  static ::mlir::Attribute parse(::mlir::AsmParser &odsParser, ::mlir::Type odsType);
  void print(::mlir::AsmPrinter &odsPrinter) const;
  ::mlir::heir::polynomial::RingAttr getRing() const;
  Attribute getEncoding() const;
};
#endif  // GET_ATTRDEF_CLASSES
LWEAttributes.td.h.inc
#ifdef GET_ATTRDEF_LIST
#undef GET_ATTRDEF_LIST

::mlir::heir::lwe::PlaintextSpaceAttr,
......
#endif  // GET_ATTRDEF_LIST

#ifdef GET_ATTRDEF_CLASSES
#undef GET_ATTRDEF_CLASSES
......
::llvm::LogicalResult PlaintextSpaceAttr::verifyInvariants(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::heir::polynomial::RingAttr ring, Attribute encoding) {
  if (::mlir::failed(verifyInvariantsImpl(emitError, ring, encoding)))
    return ::mlir::failure();
  if (::mlir::failed(verify(emitError, ring, encoding)))
    return ::mlir::failure();
  return ::mlir::success();
}

#endif  // GET_ATTRDEF_CLASSES
LWEAttributes.td.cpp.inc
LogicalResult PlaintextSpaceAttr::verify(
    ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
    mlir::heir::polynomial::RingAttr ring, Attribute encoding) {
  if (mlir::isa<FullCRTPackingEncodingAttr>(encoding)) {
    // For full CRT packing, the ring must be of the form x^n + 1 and the
    // modulus must be 1 mod n.
    auto polyMod = ring.getPolynomialModulus();
    auto poly = polyMod.getPolynomial();
    auto polyTerms = poly.getTerms();
    if (polyTerms.size() != 2) {
      return emitError() << "polynomial modulus must be of the form x^n + 1, "
                         << "but found " << polyMod << "\n";
    }
    const auto& constantTerm = polyTerms[0];
    const auto& constantCoeff = constantTerm.getCoefficient();
    if (!(constantTerm.getExponent().isZero() && constantCoeff.isOne() &&
          polyTerms[1].getCoefficient().isOne())) {
      return emitError() << "polynomial modulus must be of the form x^n + 1, "
                         << "but found " << polyMod << "\n";
    }
    // Check that the modulus is 1 mod n.
    auto modCoeffTy =
        llvm::dyn_cast<mod_arith::ModArithType>(ring.getCoefficientType());
    if (modCoeffTy) {
      APInt modulus = modCoeffTy.getModulus().getValue();
      unsigned n = poly.getDegree();
      if (!modulus.urem(APInt(modulus.getBitWidth(), n)).isOne()) {
        return emitError()
               << "modulus must be 1 mod n for full CRT packing, mod = "
               << modulus.getZExtValue() << " n = " << n << "\n";
      }
    }
  }

  return success();
}
lib\Dialect\LWE\IR\LWEAttributes.cpp
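The arithmetic behind the hand-written verify above, as a standalone C++ sketch using plain 64-bit integers instead of APInt (Term and checkFullCrtPacking are hypothetical; terms are assumed sorted by exponent, and unlike the original, which only checks the modulus when the coefficient type is a mod_arith type, this sketch always checks it).

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

namespace sketch {
struct Term { uint64_t coefficient; unsigned exponent; };  // coeff * x^exp

// Returns "" when the polynomial is x^n + 1 and q mod n == 1,
// otherwise an error message in the spirit of the original verifier.
std::string checkFullCrtPacking(const std::vector<Term> &polyTerms, uint64_t q) {
  const std::string formError = "polynomial modulus must be of the form x^n + 1";
  if (polyTerms.size() != 2) return formError;
  const Term &constantTerm = polyTerms[0];
  if (!(constantTerm.exponent == 0 && constantTerm.coefficient == 1 &&
        polyTerms[1].coefficient == 1))
    return formError;
  const unsigned n = polyTerms[1].exponent;  // degree of x^n + 1
  if (n == 0) return formError;              // guards the division below
  if (q % n != 1) return "modulus must be 1 mod n for full CRT packing";
  return "";
}
}  // namespace sketch
```

For x^8 + 1, q = 17 passes (17 = 2 * 8 + 1) while q = 7 fails.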

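Filling in the gaps:

(1) We saw earlier that new attributes are defined through the TableGen class AttrDef (and new types through TypeDef). Both classes are defined in code/llvm-project/mlir/include/mlir/IR/AttrTypeBase.td. The TypeDef definition is shown below; note cppClassName = name # "Type", which is the class-naming rule from (5). AttrDef is defined the same way in that file, with the suffix "Attr".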

// Define a new type, named `name`, belonging to `dialect` that inherits from
// the given C++ base class.
class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
              string baseCppClass = "::mlir::Type">
    : DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
      AttrOrTypeDef<"Type", name, traits, baseCppClass> {
  // The name of the C++ Type class.
  string cppClassName = name # "Type";

  // Make it possible to use such type as parameters for other types.
  string cppType = dialect.cppNamespace # "::" # cppClassName;

  // The unique type name.
  string typeName = dialect.name # "." # mnemonic;

  // A constant builder provided when the type has no parameters.
  let builderCall = !if(!empty(parameters),
                           "$_builder.getType<" # cppType # ">()",
                           "");

  // The predicate for when this def is used as a constraint.
  let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
}
code/llvm-project/mlir/include/mlir/IR/AttrTypeBase.td:TypeDef

(2) Reading code/heir-private/lib/Dialect/LWE/IR/NewLWETypes.td, you will find uses of AnyTypeOf. Its definition shows that it is a type constraint: it derives from the TableGen class Type, which in turn derives from TypeConstraint.

//===----------------------------------------------------------------------===//
// Type definitions
//===----------------------------------------------------------------------===//

// A type, carries type constraints.
class Type<Pred condition, string descr = "",
           string cppType = "::mlir::Type"> :
    TypeConstraint<condition, descr, cppType> {
  string description = "";
  string builderCall = "";
}

// Any type from the given list
class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
                string cppType = "::mlir::Type"> : Type<
    // Satisfy any of the allowed types' conditions.
    Or<!foreach(allowedtype, allowedTypeList, allowedtype.predicate)>,
    !if(!eq(summary, ""),
        !interleave(!foreach(t, allowedTypeList, t.summary), " or "),
        summary),
    cppType> {
  list<Type> allowedTypes = allowedTypeList;
}
code/llvm-project/mlir/include/mlir/IR/CommonTypeConstraints.td:AnyTypeOf
// A base class for all types in this dialect
class LWE_Type<string name, string typeMnemonic, list<Trait> traits = []>
    : TypeDef<LWE_Dialect, name, traits # [OpAsmTypeInterface]> {
  let mnemonic = typeMnemonic;
  let assemblyFormat = "`<` struct(params) `>`";

  string asmName = ?;
  string aliasName = "";
  string aliasSuffix = "";
  let extraClassDeclaration = [{
    // OpAsmTypeInterface method
    void getAsmName(::mlir::OpAsmSetNameFn setNameFn) const {
      setNameFn("}] # asmName # [{");
    }

  }] # !if(!ne(aliasName, ""), [{
    ::mlir::OpAsmDialectInterface::AliasResult getAlias(::llvm::raw_ostream &os) const {
      os << "}] # aliasName # [{";
      }] # aliasSuffix # [{
      return ::mlir::OpAsmDialectInterface::AliasResult::FinalAlias;
    }
  }], "");
}

def NewLWESecretKey : LWE_Type<"NewLWESecretKey", "new_lwe_secret_key"> {
  let summary = "A secret key for LWE";
  let parameters = (ins
    "KeyAttr":$key,
    "::mlir::heir::polynomial::RingAttr":$ring
  );
  let asmName = "sk";
  let aliasName= "skey";
  let aliasSuffix = [{ getRing().getAliasSuffix(os); }];
}

def NewLWEPublicKey : LWE_Type<"NewLWEPublicKey", "new_lwe_public_key"> {
  let summary = "A public key for LWE";
  let parameters = (ins
    "KeyAttr":$key,
    "::mlir::heir::polynomial::RingAttr":$ring
  );
  let asmName = "pk";
  let aliasName = "pkey";
  let aliasSuffix = [{ getRing().getAliasSuffix(os); }];
}

def NewLWESecretOrPublicKey : AnyTypeOf<[NewLWESecretKey, NewLWEPublicKey]>;
code/heir-private/lib/Dialect/LWE/IR/NewLWETypes.td
// LWE Operations are always Pure by design
class LWE_Op<string mnemonic, list<Trait> traits = []> :
        Op<LWE_Dialect, mnemonic,  traits # [Pure]> {
  let cppNamespace = "::mlir::heir::lwe";
  let assemblyFormat = [{
    operands attr-dict `:`  functional-type(operands, results)
  }];
}

def LWE_RLWEEncryptOp : LWE_Op<"rlwe_encrypt", [
    NewEncodingsMatch<"input", "NewLWEPlaintextType", "output", "NewLWECiphertextType">]> {
  let summary = "Encrypt an RLWE plaintext to a RLWE ciphertext";
  let description = [{
    Encrypt an RLWE plaintext to yield a RLWE ciphertext.
  }];

  let arguments = (ins
    NewLWEPlaintext:$input,
    NewLWESecretOrPublicKey:$key
  );
  let results = (outs NewLWECiphertext:$output);
  let hasVerifier = 1;
}
code/heir-private/lib/Dialect/LWE/IR/LWEOps.td
#ifdef GET_OP_LIST
#undef GET_OP_LIST

......
::mlir::heir::lwe::RLWEEncryptOp,
......
#endif  // GET_OP_LIST

#ifdef GET_OP_CLASSES
#undef GET_OP_CLASSES

......
static ::llvm::LogicalResult __mlir_ods_local_type_constraint_LWEOps9(
    ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind,
    unsigned valueIndex) {
  if (!(((::llvm::isa<::mlir::heir::lwe::NewLWESecretKeyType>(type))) || ((::llvm::isa<::mlir::heir::lwe::NewLWEPublicKeyType>(type))))) {
    return op->emitOpError(valueKind) << " #" << valueIndex
        << " must be A secret key for LWE or A public key for LWE, but got " << type;
  }
  return ::mlir::success();
}

::llvm::LogicalResult RLWEEncryptOp::verifyInvariantsImpl() {
  {
    unsigned index = 0; (void)index;
    auto valueGroup0 = getODSOperands(0);

    for (auto v : valueGroup0) {
      if (::mlir::failed(__mlir_ods_local_type_constraint_LWEOps7(*this, v.getType(), "operand", index++)))
        return ::mlir::failure();
    }
    auto valueGroup1 = getODSOperands(1);

    for (auto v : valueGroup1) {
      if (::mlir::failed(__mlir_ods_local_type_constraint_LWEOps9(*this, v.getType(), "operand", index++)))
        return ::mlir::failure();
    }
  }
  {
    unsigned index = 0; (void)index;
    auto valueGroup0 = getODSResults(0);

    for (auto v : valueGroup0) {
      if (::mlir::failed(__mlir_ods_local_type_constraint_LWEOps6(*this, v.getType(), "result", index++)))
        return ::mlir::failure();
    }
  }
  if (!((std::equal_to<>()(::llvm::cast<lwe::NewLWEPlaintextType>((*this->getODSOperands(0).begin()).getType()).getPlaintextSpace().getEncoding(), ::llvm::cast<lwe::NewLWECiphertextType>((*this->getODSResults(0).begin()).getType()).getPlaintextSpace().getEncoding()))))
    return emitOpError("failed to verify that the first arg's type's encoding matches the given encoding");
  return ::mlir::success();
}

::llvm::LogicalResult RLWEEncryptOp::verifyInvariants() {
  if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify()))
    return ::mlir::success();
  return ::mlir::failure();
}


......

#endif  // GET_OP_CLASSES
LWEOps.td.cpp.inc
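Putting the pieces together: AnyTypeOf ORs the predicates of its allowed types and, when no summary is given, joins their summaries with " or ". That is where the check and the message "must be A secret key for LWE or A public key for LWE" in the generated __mlir_ods_local_type_constraint_LWEOps9 come from. Below is a standalone C++ model of that composition (TypeModel, Constraint, anyTypeOf and isKind are hypothetical names).

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

namespace sketch {
struct TypeModel { std::string kind; };  // stand-in for ::mlir::Type

// A type constraint: a predicate plus a human-readable summary.
struct Constraint {
  std::function<bool(const TypeModel &)> predicate;
  std::string summary;
};

// AnyTypeOf: OR of the allowed predicates; summaries joined with " or ".
Constraint anyTypeOf(const std::vector<Constraint> &allowed) {
  Constraint result;
  result.predicate = [allowed](const TypeModel &t) {
    for (const auto &c : allowed)
      if (c.predicate(t)) return true;
    return false;
  };
  for (std::size_t i = 0; i < allowed.size(); ++i)
    result.summary += (i ? " or " : "") + allowed[i].summary;
  return result;
}

// Like a TypeDef's predicate ::llvm::isa<T>($_self), by kind name here.
Constraint isKind(const std::string &kind, const std::string &summary) {
  return {[kind](const TypeModel &t) { return t.kind == kind; }, summary};
}
}  // namespace sketch
```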
namespace mlir {
namespace heir {
namespace lwe {
class RLWEEncryptOp;
} // namespace lwe
} // namespace heir
} // namespace mlir

#ifdef GET_OP_CLASSES
#undef GET_OP_CLASSES

namespace mlir {
namespace heir {
namespace lwe {

//===----------------------------------------------------------------------===//
// ::mlir::heir::lwe::RLWEEncryptOp declarations
//===----------------------------------------------------------------------===//

namespace detail {
class RLWEEncryptOpGenericAdaptorBase {
public:
protected:
  ::mlir::DictionaryAttr odsAttrs;
  ::std::optional<::mlir::OperationName> odsOpName;
  ::mlir::RegionRange odsRegions;
public:
  RLWEEncryptOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = {}, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : odsAttrs(attrs), odsRegions(regions) {  if (odsAttrs)
      odsOpName.emplace("lwe.rlwe_encrypt", odsAttrs.getContext());
  }

  RLWEEncryptOpGenericAdaptorBase(::mlir::Operation *op) : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()), odsRegions(op->getRegions()) {}

  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize) {
    return {index, 1};
  }

  ::mlir::DictionaryAttr getAttributes() {
    return odsAttrs;
  }

};
} // namespace detail
template <typename RangeT>
class RLWEEncryptOpGenericAdaptor : public detail::RLWEEncryptOpGenericAdaptorBase {
  using ValueT = ::llvm::detail::ValueOfRange<RangeT>;
  using Base = detail::RLWEEncryptOpGenericAdaptorBase;
public:
  RLWEEncryptOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = {}, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {}

  RLWEEncryptOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : RLWEEncryptOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {}

  RLWEEncryptOpGenericAdaptor(RangeT values, const RLWEEncryptOpGenericAdaptorBase &base) : Base(base), odsOperands(values) {}

  template <typename LateInst = RLWEEncryptOp, typename = std::enable_if_t<std::is_same_v<LateInst, RLWEEncryptOp>>>
  RLWEEncryptOpGenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {}

  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index) {
    return Base::getODSOperandIndexAndLength(index, odsOperands.size());
  }

  RangeT getODSOperands(unsigned index) {
    auto valueRange = getODSOperandIndexAndLength(index);
    return {std::next(odsOperands.begin(), valueRange.first),
             std::next(odsOperands.begin(), valueRange.first + valueRange.second)};
  }

  ValueT getInput() {
    return (*getODSOperands(0).begin());
  }

  ValueT getKey() {
    return (*getODSOperands(1).begin());
  }

  RangeT getOperands() {
    return odsOperands;
  }

private:
  RangeT odsOperands;
};
class RLWEEncryptOpAdaptor : public RLWEEncryptOpGenericAdaptor<::mlir::ValueRange> {
public:
  using RLWEEncryptOpGenericAdaptor::RLWEEncryptOpGenericAdaptor;
  RLWEEncryptOpAdaptor(RLWEEncryptOp op);

  ::llvm::LogicalResult verify(::mlir::Location loc);
};
class RLWEEncryptOp : public ::mlir::Op<RLWEEncryptOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::heir::lwe::NewLWECiphertextType>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> {
public:
  using Op::Op;
  using Op::print;
  using Adaptor = RLWEEncryptOpAdaptor;
  template <typename RangeT>
  using GenericAdaptor = RLWEEncryptOpGenericAdaptor<RangeT>;
  using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>;
  static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
    return {};
  }

  static constexpr ::llvm::StringLiteral getOperationName() {
    return ::llvm::StringLiteral("lwe.rlwe_encrypt");
  }

  std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index) {
    return {index, 1};
  }

  ::mlir::Operation::operand_range getODSOperands(unsigned index) {
    auto valueRange = getODSOperandIndexAndLength(index);
    return {std::next(getOperation()->operand_begin(), valueRange.first),
             std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)};
  }

  ::mlir::TypedValue<::mlir::heir::lwe::NewLWEPlaintextType> getInput() {
    return ::llvm::cast<::mlir::TypedValue<::mlir::heir::lwe::NewLWEPlaintextType>>(*getODSOperands(0).begin());
  }

  ::mlir::TypedValue<::mlir::Type> getKey() {
    return ::llvm::cast<::mlir::TypedValue<::mlir::Type>>(*getODSOperands(1).begin());
  }

  ::mlir::OpOperand &getInputMutable() {
    auto range = getODSOperandIndexAndLength(0);
    return getOperation()->getOpOperand(range.first);
  }

  ::mlir::OpOperand &getKeyMutable() {
    auto range = getODSOperandIndexAndLength(1);
    return getOperation()->getOpOperand(range.first);
  }

  std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index) {
    return {index, 1};
  }

  ::mlir::Operation::result_range getODSResults(unsigned index) {
    auto valueRange = getODSResultIndexAndLength(index);
    return {std::next(getOperation()->result_begin(), valueRange.first),
             std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)};
  }

  ::mlir::TypedValue<::mlir::heir::lwe::NewLWECiphertextType> getOutput() {
    return ::llvm::cast<::mlir::TypedValue<::mlir::heir::lwe::NewLWECiphertextType>>(*getODSResults(0).begin());
  }

  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type output, ::mlir::Value input, ::mlir::Value key);
  static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input, ::mlir::Value key);
  static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
  ::llvm::LogicalResult verifyInvariantsImpl();
  ::llvm::LogicalResult verifyInvariants();
  ::llvm::LogicalResult verify();
  static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
  void print(::mlir::OpAsmPrinter &_odsPrinter);
  void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
public:
};
} // namespace lwe
} // namespace heir
} // namespace mlir
MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::heir::lwe::RLWEEncryptOp)

#endif  // GET_OP_CLASSES
LWEOps.td.h.inc
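In the header, every operand of `lwe.rlwe_encrypt` is a single, non-variadic value. The operand accessors therefore reduce to simple indexing:

- `getODSOperandIndexAndLength(index)` returns `{index, 1}`.
- `getODSOperands(index)` cuts that window out of the operand list.
- `getInput()` and `getKey()` return the first element of windows 0 and 1.

These accessors are written once in the `RLWEEncryptOpGenericAdaptor<RangeT>` template. That is why `RLWEEncryptOpAdaptor` (over `::mlir::ValueRange`) and `FoldAdaptor` (over `::llvm::ArrayRef<::mlir::Attribute>`) share them.

The sketch below reproduces that pattern without MLIR. `EncryptGenericAdaptor` is a hypothetical name, and `std::vector` stands in for the real range types. As a result, the sketch copies where the generated code only takes views.

```cpp
#include <cstddef>
#include <iterator>
#include <utility>
#include <vector>

// Hypothetical, simplified version of the generated GenericAdaptor template.
// RangeT is a std::vector-like container standing in for ValueRange or
// ArrayRef<Attribute>; the real adaptor stores a non-owning range instead.
template <typename RangeT>
class EncryptGenericAdaptor {
public:
  using ValueT = typename RangeT::value_type;

  explicit EncryptGenericAdaptor(RangeT values) : odsOperands(std::move(values)) {}

  // Both operands are non-variadic, so operand group `index` starts at
  // position `index` and has length 1.
  std::pair<std::size_t, std::size_t> getODSOperandIndexAndLength(std::size_t index) const {
    return {index, 1};
  }

  // Copy out the window [start, start + length) of the operand list.
  RangeT getODSOperands(std::size_t index) const {
    auto [start, length] = getODSOperandIndexAndLength(index);
    auto first = std::next(odsOperands.begin(), static_cast<std::ptrdiff_t>(start));
    auto last = std::next(first, static_cast<std::ptrdiff_t>(length));
    return RangeT(first, last);
  }

  // Named accessors: the single element of operand groups 0 and 1.
  ValueT getInput() const { return *getODSOperands(0).begin(); }
  ValueT getKey() const { return *getODSOperands(1).begin(); }

private:
  RangeT odsOperands;
};
```

For example, over `std::vector<int>{7, 0}`, `getInput()` returns 7 and `getKey()` returns 0. Instantiating the template over a different element type needs no new accessor code. This mirrors why the generated header makes the adaptor a template instead of writing one adaptor per range type.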

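The long base-class list of `RLWEEncryptOp` (`ZeroRegions`, `OneResult`, `NOperands<2>::Impl`, ...) is a set of CRTP trait mix-ins. Each trait is a class template parameterized on the concrete op, and `::mlir::Op` inherits from all of them. In MLIR, traits can provide a static `verifyTrait` hook that checks structural properties of an op carrying the trait.

The following is a heavily simplified, MLIR-free sketch of that idea. Note the assumptions:

- `FakeOperation`, `Op`, `OneResult`, `NOperands` and `FakeEncryptOp` here are hypothetical stand-ins that only borrow the MLIR names.
- Returning `bool` instead of `LogicalResult` is a simplification.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::Operation: just operand and result names.
struct FakeOperation {
  std::vector<std::string> operands;
  std::vector<std::string> results;
};

// A trait is a class template over the concrete op type (CRTP style).
template <typename ConcreteOp>
struct OneResult {
  static bool verifyTrait(const FakeOperation &op) { return op.results.size() == 1; }
};

// A parameterized trait wraps the template in an outer struct, which is why
// the generated code spells it NOperands<2>::Impl.
template <unsigned N>
struct NOperands {
  template <typename ConcreteOp>
  struct Impl {
    static bool verifyTrait(const FakeOperation &op) { return op.operands.size() == N; }
  };
};

// The op base inherits every trait and verifies them all (C++17 fold expression).
template <typename ConcreteOp, template <typename> class... Traits>
struct Op : public Traits<ConcreteOp>... {
  static bool verifyTraits(const FakeOperation &op) {
    return (Traits<ConcreteOp>::verifyTrait(op) && ...);
  }
};

// Hypothetical op mirroring two of RLWEEncryptOp's traits.
struct FakeEncryptOp : public Op<FakeEncryptOp, OneResult, NOperands<2>::Impl> {};
```

With these definitions, `FakeEncryptOp::verifyTraits` succeeds for an operation with two operands and one result, and fails if the key operand is missing. This loosely corresponds to what `NOperands<2>` and `OneResult` enforce for `lwe.rlwe_encrypt`.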
 

Posted on 2025-04-20 15:00 by LiveWithACat