Files
Twice 2874480930 [MLIR][Python] Migrate result(infer_type=True) to a new field specifier (#191849)
Currrently the signature of `result(..)` is:
```python
result(*, infer_type: bool = False, default_factory: Callable[[], Any] | None = None, kw_only: bool = False) -> Result
```

so when users use `result(infer_type=True)`, the type checkers will
still get `kw_only=False` (from the signature), but actually the
`kw_only` should be `True` (it should follow the value of `infer_type`).
users can use `result(infer_type=True, kw_only=True)` but it's
unnecessarily verbose.

So it may introduce an incompatibility when we start to use
`dataclass_transform`. currently it's fine because we just don't use
`dataclass_transform`. But when we use, we may require a breaking
change.

This PR migrates such use to a new field specifier named
`infer_result()`.
2026-04-14 10:45:28 +08:00

897 lines
32 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# RUN: %PYTHON %s 2>&1 | FileCheck %s
from mlir.ir import *
from mlir.dialects import arith
from mlir.dialects.ext import *
from mlir import ir
from typing import Any, Optional, Sequence, TypeVar, Union
import sys
def run(f):
print("\nTEST:", f.__name__)
f()
# CHECK: TEST: testMyInt
@run
def testMyInt():
class MyInt(Dialect, name="myint"):
pass
i32 = IntegerType[32]
class ConstantOp(MyInt.Operation, name="constant"):
value: IntegerAttr
cst: Result[i32] = infer_result()
class AddOp(Operation, dialect=MyInt, name="add"):
lhs: Operand[i32]
rhs: Operand[i32]
res: Result[i32] = infer_result()
# CHECK: irdl.dialect @myint {
# CHECK: irdl.operation @constant {
# CHECK: %0 = irdl.base "#builtin.integer"
# CHECK: irdl.attributes {"value" = %0}
# CHECK: %1 = irdl.is i32
# CHECK: irdl.results(cst: %1)
# CHECK: }
# CHECK: irdl.operation @add {
# CHECK: %0 = irdl.is i32
# CHECK: irdl.operands(lhs: %0, rhs: %0)
# CHECK: irdl.results(res: %0)
# CHECK: }
# CHECK: }
with Context():
MyInt.load()
print(MyInt._mlir_module)
# CHECK: ['constant', 'add']
print([i._op_name for i in MyInt.operations])
i32 = IntegerType.get_signless(32)
with Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
two = ConstantOp(IntegerAttr.get(i32, 2))
three = ConstantOp(IntegerAttr.get(i32, 3))
add1 = AddOp(two, three)
add2 = AddOp(add1, two)
add3 = AddOp(add2, three)
# CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
# CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
# CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
# CHECK: %3 = "myint.add"(%2, %0) : (i32, i32) -> i32
# CHECK: %4 = "myint.add"(%3, %1) : (i32, i32) -> i32
print(module)
assert module.operation.verify()
# CHECK: AddOp
print(type(add1).__name__)
# CHECK: ConstantOp
print(type(two).__name__)
# CHECK: myint.add
print(add1.OPERATION_NAME)
# CHECK: None
print(add1._ODS_OPERAND_SEGMENTS)
# CHECK: None
print(add1._ODS_RESULT_SEGMENTS)
# CHECK: (0, True)
print(add1._ODS_REGIONS)
# CHECK: %0 = "myint.constant"() {value = 2 : i32} : () -> i32
print(add1.lhs.owner)
# CHECK: %1 = "myint.constant"() {value = 3 : i32} : () -> i32
print(add1.rhs.owner)
# CHECK: 2 : i32
print(two.value)
# CHECK: OpResult(%0
print(two.cst)
# CHECK: (self, /, lhs, rhs, *, res=None, loc=None, ip=None)
print(AddOp.__init__.__signature__)
# CHECK: (self, /, value, *, cst=None, loc=None, ip=None)
print(ConstantOp.__init__.__signature__)
# CHECK: True
print(issubclass(AddOp.Adaptor, OpAdaptor))
adaptor1 = AddOp.Adaptor(list(add1.operands), add1)
# CHECK: myint.add
print(adaptor1.OPERATION_NAME)
# CHECK: OpResult(%0 = "myint.constant"() {value = 2 : i32} : () -> i32)
print(adaptor1.lhs)
# CHECK: OpResult(%1 = "myint.constant"() {value = 3 : i32} : () -> i32)
print(adaptor1.rhs)
# CHECK: TEST: testExtDialect
@run
def testExtDialect():
class Test(Dialect, name="ext_test"):
pass
i32 = IntegerType[32]
class ConstraintOp(Test.Operation, name="constraint"):
a: Operand[i32 | IntegerType[64]]
b: Operand[Any]
# Here we use `F32Type[()]` instead of just `F32Type`
# because of an existing issue in IRDL implementation
# where `irdl.base` cannot exist in `irdl.any_of`.
c: Operand[F32Type[()] | i32]
d: Operand[Any]
x: IntegerAttr
y: FloatAttr
class OptionalOp(Test.Operation, name="optional"):
out1: Result[i32]
out3: Result[i32]
a: Operand[i32]
out2: Result[i32] | None = None
b: Optional[Operand[i32]] = None
class Optional2Op(Test.Operation, name="optional2"):
b: Optional[Result[i32]] = None
a: Optional[Operand[i32]] = None
class VariadicOp(Test.Operation, name="variadic"):
out1: Sequence[Result[i32]]
out2: Sequence[Result[i32]]
out4: Result[i32]
a: Operand[i32]
c: Sequence[Operand[i32]]
out3: Optional[Result[i32]] = None
b: Optional[Operand[i32]] = None
class Variadic2Op(Test.Operation, name="variadic2"):
b: Sequence[Result[i32]]
a: Sequence[Operand[i32]]
class MixedOpBase(Test.Operation):
out: Result[i32]
in1: Operand[i32]
class MixedOp(MixedOpBase, name="mixed"):
in2: IntegerAttr
in4: IntegerAttr
in5: Operand[i32]
in3: Optional[Operand[i32]] = None
T = TypeVar("T")
U = TypeVar("U", bound=IntegerType[32] | IntegerType[64])
V = TypeVar("V", bound=Union[IntegerType[8], IntegerType[16]])
class TypeVarOp(Test.Operation, name="type_var"):
in1: Operand[T]
in2: Operand[T]
in3: Operand[U]
in4: Operand[U | V]
in5: Operand[V]
class OptionalButNotKeywordOp(Test.Operation, name="optional_but_not_keyword"):
a: Operand[i32]
b: Optional[Operand[i32]]
c: Operand[i32]
# CHECK: irdl.dialect @ext_test {
# CHECK: irdl.operation @constraint {
# CHECK: %0 = irdl.is i32
# CHECK: %1 = irdl.is i64
# CHECK: %2 = irdl.any_of(%0, %1)
# CHECK: %3 = irdl.any
# CHECK: %4 = irdl.is f32
# CHECK: %5 = irdl.any_of(%4, %0)
# CHECK: %6 = irdl.any
# CHECK: irdl.operands(a: %2, b: %3, c: %5, d: %6)
# CHECK: %7 = irdl.base "#builtin.integer"
# CHECK: %8 = irdl.base "#builtin.float"
# CHECK: irdl.attributes {"x" = %7, "y" = %8}
# CHECK: }
# CHECK: irdl.operation @optional {
# CHECK: %0 = irdl.is i32
# CHECK: irdl.operands(a: %0, b: optional %0)
# CHECK: irdl.results(out1: %0, out3: %0, out2: optional %0)
# CHECK: }
# CHECK: irdl.operation @optional2 {
# CHECK: %0 = irdl.is i32
# CHECK: irdl.operands(a: optional %0)
# CHECK: irdl.results(b: optional %0)
# CHECK: }
# CHECK: irdl.operation @variadic {
# CHECK: %0 = irdl.is i32
# CHECK: irdl.operands(a: %0, c: variadic %0, b: optional %0)
# CHECK: irdl.results(out1: variadic %0, out2: variadic %0, out4: %0, out3: optional %0)
# CHECK: }
# CHECK: irdl.operation @variadic2 {
# CHECK: %0 = irdl.is i32
# CHECK: irdl.operands(a: variadic %0)
# CHECK: irdl.results(b: variadic %0)
# CHECK: }
# CHECK: irdl.operation @mixed {
# CHECK: %0 = irdl.is i32
# CHECK: irdl.operands(in1: %0, in5: %0, in3: optional %0)
# CHECK: %1 = irdl.base "#builtin.integer"
# CHECK: %2 = irdl.base "#builtin.integer"
# CHECK: irdl.attributes {"in2" = %1, "in4" = %2}
# CHECK: irdl.results(out: %0)
# CHECK: }
# CHECK: irdl.operation @type_var {
# CHECK: %0 = irdl.any
# CHECK: %1 = irdl.is i32
# CHECK: %2 = irdl.is i64
# CHECK: %3 = irdl.any_of(%1, %2)
# CHECK: %4 = irdl.is i8
# CHECK: %5 = irdl.is i16
# CHECK: %6 = irdl.any_of(%4, %5)
# CHECK: %7 = irdl.any_of(%3, %6)
# CHECK: irdl.operands(in1: %0, in2: %0, in3: %3, in4: %7, in5: %6)
# CHECK: }
# CHECK: }
with Context(), Location.unknown():
Test.load()
print(Test._mlir_module)
# CHECK: (self, /, a, b, c, d, x, y, *, loc=None, ip=None)
print(ConstraintOp.__init__.__signature__)
# CHECK: (self, /, out1, out3, a, *, out2=None, b=None, loc=None, ip=None)
print(OptionalOp.__init__.__signature__)
# CHECK: (self, /, *, b=None, a=None, loc=None, ip=None)
print(Optional2Op.__init__.__signature__)
# CHECK: (self, /, out1, out2, out4, a, c, *, out3=None, b=None, loc=None, ip=None)
print(VariadicOp.__init__.__signature__)
# CHECK: (self, /, b, a, *, loc=None, ip=None)
print(Variadic2Op.__init__.__signature__)
# CHECK: (self, /, out, in1, in2, in4, in5, *, in3=None, loc=None, ip=None)
print(MixedOp.__init__.__signature__)
# CHECK: (self, /, in1, in2, in3, in4, in5, *, loc=None, ip=None)
print(TypeVarOp.__init__.__signature__)
# CHECK: (self, /, a, b, c, *, loc=None, ip=None)
print(OptionalButNotKeywordOp.__init__.__signature__)
# CHECK: None None
print(ConstraintOp._ODS_OPERAND_SEGMENTS, ConstraintOp._ODS_RESULT_SEGMENTS)
# CHECK: [1, 0] [1, 1, 0]
print(OptionalOp._ODS_OPERAND_SEGMENTS, OptionalOp._ODS_RESULT_SEGMENTS)
# CHECK: [0] [0]
print(Optional2Op._ODS_OPERAND_SEGMENTS, Optional2Op._ODS_RESULT_SEGMENTS)
# CHECK: [1, -1, 0] [-1, -1, 1, 0]
print(VariadicOp._ODS_OPERAND_SEGMENTS, VariadicOp._ODS_RESULT_SEGMENTS)
# CHECK: [-1] [-1]
print(Variadic2Op._ODS_OPERAND_SEGMENTS, Variadic2Op._ODS_RESULT_SEGMENTS)
i32 = IntegerType.get_signless(32)
i64 = IntegerType.get_signless(64)
f32 = F32Type.get()
iattr = IntegerAttr.get(i32, 2)
fattr = FloatAttr.get_f32(2.3)
module = Module.create()
with InsertionPoint(module.body):
ione = arith.constant(i32, 1)
fone = arith.constant(f32, 1.2)
# CHECK: "ext_test.constraint"(%c1_i32, %c1_i32, %cst, %c1_i32) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, i32, f32, i32) -> ()
c1 = ConstraintOp(ione, ione, fone, ione, iattr, fattr)
# CHECK: "ext_test.constraint"(%c1_i32, %cst, %cst, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, f32, f32) -> ()
ConstraintOp(ione, fone, fone, fone, iattr, fattr)
# CHECK: ext_test.constraint"(%c1_i32, %cst, %c1_i32, %cst) {x = 2 : i32, y = 2.300000e+00 : f32} : (i32, f32, i32, f32) -> ()
ConstraintOp(ione, fone, ione, fone, iattr, fattr)
# CHECK: %0:2 = "ext_test.optional"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0>, resultSegmentSizes = array<i32: 1, 1, 0>} : (i32) -> (i32, i32)
o1 = OptionalOp(i32, i32, ione)
# CHECK: %1:3 = "ext_test.optional"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1>, resultSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32) -> (i32, i32, i32)
o2 = OptionalOp(i32, i32, ione, out2=i32, b=ione)
# CHECK: ext_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
o3 = Optional2Op()
# CHECK: %2 = "ext_test.optional2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 1>} : () -> i32
o4 = Optional2Op(b=i32)
# CHECK: "ext_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 0>} : (i32) -> ()
o5 = Optional2Op(a=ione)
# CHECK: %3 = "ext_test.optional2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 1>} : (i32) -> i32
o6 = Optional2Op(b=i32, a=ione)
# CHECK: %4:4 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 2, 0>, resultSegmentSizes = array<i32: 1, 2, 1, 0>} : (i32, i32, i32) -> (i32, i32, i32, i32)
v1 = VariadicOp([i32], [i32, i32], i32, ione, [ione, ione])
# CHECK: %5:5 = "ext_test.variadic"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 1, 1>, resultSegmentSizes = array<i32: 1, 2, 1, 1>} : (i32, i32, i32) -> (i32, i32, i32, i32, i32)
v2 = VariadicOp([i32], [i32, i32], i32, ione, [ione], out3=i32, b=ione)
# CHECK: %6:4 = "ext_test.variadic"(%c1_i32) {operandSegmentSizes = array<i32: 1, 0, 0>, resultSegmentSizes = array<i32: 2, 1, 1, 0>} : (i32) -> (i32, i32, i32, i32)
v3 = VariadicOp([i32, i32], [i32], i32, ione, [])
# CHECK: "ext_test.variadic2"() {operandSegmentSizes = array<i32: 0>, resultSegmentSizes = array<i32: 0>} : () -> ()
v4 = Variadic2Op([], [])
# CHECK: "ext_test.variadic2"(%c1_i32, %c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 3>, resultSegmentSizes = array<i32: 0>} : (i32, i32, i32) -> ()
v5 = Variadic2Op([], [ione, ione, ione])
# CHECK: %7:2 = "ext_test.variadic2"(%c1_i32) {operandSegmentSizes = array<i32: 1>, resultSegmentSizes = array<i32: 2>} : (i32) -> (i32, i32)
v6 = Variadic2Op([i32, i32], [ione])
# CHECK: %8 = "ext_test.mixed"(%c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 0>} : (i32, i32) -> i32
m1 = MixedOp(i32, ione, iattr, iattr, ione)
# CHECK: %9 = "ext_test.mixed"(%c1_i32, %c1_i32, %c1_i32) {in2 = 2 : i32, in4 = 2 : i32, operandSegmentSizes = array<i32: 1, 1, 1>} : (i32, i32, i32) -> i32
m2 = MixedOp(i32, ione, iattr, iattr, ione, in3=ione)
print(module)
assert module.operation.verify()
# CHECK: OpResult(%c1_i32
print(c1.a)
# CHECK: 2 : i32
print(c1.x)
# CHECK: OpResult(%c1_i32
print(o1.a)
# CHECK: None
print(o1.b)
# CHECK: OpResult(%c1_i32
print(o2.b)
# CHECK: 0
print(o1.out1.result_number)
# CHECK: None
print(o1.out2)
# CHECK: 0
print(o2.out1.result_number)
# CHECK: 2
print(o2.out2.result_number)
# CHECK: None
print(o3.a)
# CHECK: OpResult(%c1_i32
print(o5.a)
# CHECK: ['OpResult(%c1_i32 = arith.constant 1 : i32)', 'OpResult(%c1_i32 = arith.constant 1 : i32)']
print([str(i) for i in v1.c])
# CHECK: ['OpResult(%c1_i32 = arith.constant 1 : i32)']
print([str(i) for i in v2.c])
# CHECK: []
print([str(i) for i in v3.c])
# CHECK: 0 0
print(len(v4.a), len(v4.b))
# CHECK: 3 0
print(len(v5.a), len(v5.b))
# CHECK: 1 2
print(len(v6.a), len(v6.b))
# cases to violate constraits
module = Module.create()
with InsertionPoint(module.body):
try:
c1 = ConstraintOp(ione, ione, fone, ione, iattr)
except TypeError as e:
# CHECK: missing a required argument: 'y'
print(e)
try:
c2 = ConstraintOp(ione, ione, fone, ione, iattr, fattr, ione)
except TypeError as e:
# CHECKtoo many positional arguments
print(e)
# CHECK: TEST: testExtDialectWithRegion
@run
def testExtDialectWithRegion():
class ParentIsIfTrait(DynamicOpTrait):
@staticmethod
def verify_invariants(op) -> bool:
if not isinstance(op.parent.opview, IfOp):
op.location.emit_error(
f"{op.name} should be put inside {IfOp.OPERATION_NAME}"
)
return False
return True
class TestRegion(Dialect, name="ext_region"):
pass
class IfOp(TestRegion.Operation, name="if"):
result: Result[Any]
cond: Operand[IntegerType[1]]
then: Region
else_: Region
class YieldOp(
TestRegion.Operation, name="yield", traits=[IsTerminatorTrait, ParentIsIfTrait]
):
value: Operand[Any]
def verify_invariants(self) -> bool:
if self.parent.results[0].type != self.value.type:
self.location.emit_error(
"result type mismatch between YieldOp and its parent IfOp"
)
return False
return True
class NoTermOp(TestRegion.Operation, name="no_term", traits=[NoTerminatorTrait]):
body: Region
with Context(), Location.unknown():
TestRegion.load()
# CHECK: irdl.dialect @ext_region {
# CHECK: irdl.operation @if {
# CHECK: %0 = irdl.is i1
# CHECK: irdl.operands(cond: %0)
# CHECK: %1 = irdl.any
# CHECK: irdl.results(result: %1)
# CHECK: %2 = irdl.region
# CHECK: %3 = irdl.region
# CHECK: irdl.regions(then: %2, else_: %3)
# CHECK: }
# CHECK: irdl.operation @yield {
# CHECK: %0 = irdl.any
# CHECK: irdl.operands(value: %0)
# CHECK: }
# CHECK: irdl.operation @no_term {
# CHECK: %0 = irdl.region
# CHECK: irdl.regions(body: %0)
# CHECK: }
# CHECK: }
print(TestRegion._mlir_module)
# CHECK: (self, /, result, cond, *, loc=None, ip=None)
print(IfOp.__init__.__signature__)
# CHECK: None None
print(IfOp._ODS_OPERAND_SEGMENTS, IfOp._ODS_RESULT_SEGMENTS)
# CHECK: (2, True)
print(IfOp._ODS_REGIONS)
module = Module.create()
with InsertionPoint(module.body):
i1 = IntegerType.get_signless(1)
i32 = IntegerType.get_signless(32)
cond = arith.constant(i1, 1)
if_ = IfOp(i32, cond)
if_.then.blocks.append()
if_.else_.blocks.append()
with InsertionPoint(if_.then.blocks[0]):
v = arith.constant(i32, 2)
yield_ = YieldOp(v)
with InsertionPoint(if_.else_.blocks[0]):
v = arith.constant(i32, 3)
YieldOp(v)
nt = NoTermOp()
nt.body.blocks.append()
with InsertionPoint(nt.body.blocks[0]):
arith.constant(i32, 4)
# No terminator here
assert module.operation.verify()
# CHECK: module {
# CHECK: %true = arith.constant true
# CHECK: %0 = "ext_region.if"(%true) ({
# CHECK: %c2_i32 = arith.constant 2 : i32
# CHECK: "ext_region.yield"(%c2_i32) : (i32) -> ()
# CHECK: }, {
# CHECK: %c3_i32 = arith.constant 3 : i32
# CHECK: "ext_region.yield"(%c3_i32) : (i32) -> ()
# CHECK: }) : (i1) -> i32
# CHECK: "ext_region.no_term"() ({
# CHECK: %c4_i32 = arith.constant 4 : i32
# CHECK: }) : () -> ()
# CHECK: }
print(module)
# CHECK: True
print(yield_.has_trait(IsTerminatorTrait))
# CHECK: False
print(yield_.has_trait(NoTerminatorTrait))
# CHECK: True
print(yield_.has_trait(ParentIsIfTrait))
# CHECK: False
print(nt.operation.has_trait(IsTerminatorTrait))
# CHECK: True
print(nt.operation.has_trait(NoTerminatorTrait))
# CHECK: False
print(nt.operation.has_trait(ParentIsIfTrait))
# CHECK: False
print(NoTermOp.has_trait(IsTerminatorTrait))
# CHECK: True
print(NoTermOp.has_trait(NoTerminatorTrait))
# CHECK: False
print(NoTermOp.has_trait(ParentIsIfTrait))
# CHECK: %c2_i32 = arith.constant 2 : i32
print(if_.then.blocks[0])
# CHECK: %c3_i32 = arith.constant 3 : i32
print(if_.else_.blocks[0])
# CHECK-LABEL: Testing violation cases
print("Testing violation cases:")
module = Module.create()
with InsertionPoint(module.body):
i1 = IntegerType.get_signless(1)
i32 = IntegerType.get_signless(32)
cond = arith.constant(i1, 1)
if_ = IfOp(i32, cond)
if_.then.blocks.append()
if_.else_.blocks.append()
with InsertionPoint(if_.then.blocks[0]):
v = arith.constant(i32, 2)
with InsertionPoint(if_.else_.blocks[0]):
v = arith.constant(i32, 3)
try:
module.operation.verify()
except Exception as e:
# CHECK: Verification failed:
# CHECK: block with no terminator
print(e)
module = Module.create()
with InsertionPoint(module.body):
v = arith.constant(i32, 2)
YieldOp(v)
try:
module.operation.verify()
except Exception as e:
# CHECK: Verification failed:
# CHECK: ext_region.yield should be put inside ext_region.if
print(e)
module = Module.create()
with InsertionPoint(module.body):
i1 = IntegerType.get_signless(1)
i32 = IntegerType.get_signless(32)
cond = arith.constant(i1, 1)
if_ = IfOp(i1, cond)
if_.then.blocks.append()
if_.else_.blocks.append()
with InsertionPoint(if_.then.blocks[0]):
v = arith.constant(i32, 2)
YieldOp(v)
with InsertionPoint(if_.else_.blocks[0]):
v = arith.constant(i32, 3)
YieldOp(v)
try:
module.operation.verify()
except Exception as e:
# CHECK: Verification failed:
# CHECK: result type mismatch
print(e)
# CHECK: TEST: testExtDialectWithType
@run
def testExtDialectWithType():
class TestType(Dialect, name="ext_type"):
pass
class Array(TestType.Type, name="array"):
elem_type: IntegerType[32] | IntegerType[64]
length: IntegerAttr
class MakeArrayOp(TestType.Operation, name="make_array"):
arr: Result[Array]
class MakeArray3Op(TestType.Operation, name="make_array3"):
arr: Result[
Array[IntegerType[32], IntegerAttr[IntegerType[32], 3]]
] = infer_result()
with Context(), Location.unknown():
TestType.load()
# CHECK: irdl.dialect @ext_type {
# CHECK: irdl.type @array {
# CHECK: %0 = irdl.is i32
# CHECK: %1 = irdl.is i64
# CHECK: %2 = irdl.any_of(%0, %1)
# CHECK: %3 = irdl.base "#builtin.integer"
# CHECK: irdl.parameters(elem_type: %2, length: %3)
# CHECK: }
# CHECK: irdl.operation @make_array {
# CHECK: %0 = irdl.base @ext_type::@array
# CHECK: irdl.results(arr: %0)
# CHECK: }
# CHECK: irdl.operation @make_array3 {
# CHECK: %0 = irdl.is i32
# CHECK: %1 = irdl.is 3 : i32
# CHECK: %2 = irdl.parametric @ext_type::@array<%0, %1>
# CHECK: irdl.results(arr: %2)
# CHECK: }
# CHECK: }
print(TestType._mlir_module)
# CHECK: ext_type.array
print(Array.type_name)
i32 = IntegerType.get_signless(32)
i64 = IntegerType.get_signless(64)
a4 = Array.get(i32, IntegerAttr.get(i32, 4))
a6 = Array.get(i64, IntegerAttr.get(i32, 6))
# CHECK: !ext_type.array<i32, 4 : i32>
print(a4)
# CHECK: !ext_type.array<i64, 6 : i32>
print(a6)
# CHECK: i32
print(a4.elem_type)
# CHECK: 4 : i32
print(a4.length)
# CHECK: i64
print(a6.elem_type)
# CHECK: 6 : i32
print(a6.length)
# CHECK: <locals>.Array
print(type(Type(a4).maybe_downcast()))
module = Module.create()
with InsertionPoint(module.body):
MakeArrayOp(a4)
MakeArrayOp(a6)
MakeArray3Op()
# CHECK: %0 = "ext_type.make_array"() : () -> !ext_type.array<i32, 4 : i32>
# CHECK: %1 = "ext_type.make_array"() : () -> !ext_type.array<i64, 6 : i32>
# CHECK: %2 = "ext_type.make_array3"() : () -> !ext_type.array<i32, 3 : i32>
assert module.operation.verify()
print(module)
# CHECK: TEST: testExtDialectWithAttr
@run
def testExtDialectWithAttr():
class TestAttr(Dialect, name="ext_attr"):
pass
class IntPair(TestAttr.Attribute, name="pair"):
first: IntegerAttr
second: IntegerAttr
class StrPair(TestAttr.Attribute, name="str_pair"):
first: StringAttr
second: StringAttr
class Op1(TestAttr.Operation, name="op1"):
pair: IntPair
class Op2(TestAttr.Operation, name="op2"):
pair: StrPair
pair2: StrPair[StringAttr["a"], StringAttr["b"]]
with Context(), Location.unknown():
TestAttr.load()
# CHECK: irdl.dialect @ext_attr {
# CHECK: irdl.attribute @pair {
# CHECK: %0 = irdl.base "#builtin.integer"
# CHECK: %1 = irdl.base "#builtin.integer"
# CHECK: irdl.parameters(first: %0, second: %1)
# CHECK: }
# CHECK: irdl.attribute @str_pair {
# CHECK: %0 = irdl.base "#builtin.string"
# CHECK: %1 = irdl.base "#builtin.string"
# CHECK: irdl.parameters(first: %0, second: %1)
# CHECK: }
# CHECK: irdl.operation @op1 {
# CHECK: %0 = irdl.base @ext_attr::@pair
# CHECK: irdl.attributes {"pair" = %0}
# CHECK: }
# CHECK: irdl.operation @op2 {
# CHECK: %0 = irdl.base @ext_attr::@str_pair
# CHECK: %1 = irdl.is "a"
# CHECK: %2 = irdl.is "b"
# CHECK: %3 = irdl.parametric @ext_attr::@str_pair<%1, %2>
# CHECK: irdl.attributes {"pair" = %0, "pair2" = %3}
# CHECK: }
# CHECK: }
print(TestAttr._mlir_module)
# CHECK: ext_attr.pair
print(IntPair.attr_name)
# CHECK: ext_attr.str_pair
print(StrPair.attr_name)
ip = IntPair.get(
IntegerAttr.get(IntegerType.get_signless(32), 1),
IntegerAttr.get(IntegerType.get_signless(32), 2),
)
sp = StrPair.get(StringAttr.get("hello"), StringAttr.get("world"))
# CHECK: #ext_attr.pair<1 : i32, 2 : i32>
print(ip)
# CHECK: #ext_attr.str_pair<"hello", "world">
print(sp)
# CHECK: "hello"
print(sp.first)
# CHECK: "world"
print(sp.second)
sp2 = ir.Attribute(sp).maybe_downcast()
# CHECK: <locals>.StrPair
print(type(sp2))
# CHECK: #ext_attr.str_pair<"hello", "world">
print(str(sp2))
module = Module.create()
with InsertionPoint(module.body):
Op1(ip)
p2 = StrPair.get(StringAttr.get("a"), StringAttr.get("b"))
Op2(sp, p2)
assert module.operation.verify()
# CHECK: "ext_attr.op1"() {pair = #ext_attr.pair<1 : i32, 2 : i32>} : () -> ()
# CHECK: "ext_attr.op2"() {pair = #ext_attr.str_pair<"hello", "world">, pair2 = #ext_attr.str_pair<"a", "b">} : () -> ()
print(module)
# CHECK: TEST: testExtDialectWithInvalidOp
@run
def testExtDialectWithInvalidOp():
class TestInvalid(Dialect, name="ext_invalid"):
pass
try:
class InferTypeBeforePositionalOp(
TestInvalid.Operation, name="infer_before_pos"
):
res: Result[IntegerType[32]] = infer_result()
a: Operand[IntegerType[32]]
except ValueError as e:
# CHECK: wrong parameter order
print(e)
try:
class AssignNoneOnNonOptionalOp(
TestInvalid.Operation, name="assign_none_on_non_optional"
):
a: Operand[IntegerType[32]] = None
except ValueError as e:
# CHECK: only optional operand can be set to None
print(e)
try:
class AssignNoneOnnAttributeOp(
TestInvalid.Operation, name="assign_none_on_attribute"
):
a: IntegerAttr = None
except ValueError as e:
# CHECK: only optional attribute can be set to None
print(e)
try:
class CannotInferTypeOp(TestInvalid.Operation, name="cannot_infer_type"):
a: Result[IntegerType] = infer_result()
except TypeError as e:
# CHECK: unsupported type for inferring
print(e)
try:
class WrongFieldSpecifierOp(
TestInvalid.Operation, name="wrong_field_specifier"
):
a: Result[IntegerType] = operand()
except TypeError as e:
# CHECK: only `result` field specifier can be used for result fields
print(e)
try:
class WrongFieldSpecifierOp2(
TestInvalid.Operation, name="wrong_field_specifier2"
):
a: IntegerAttr = operand()
except TypeError as e:
# CHECK: only `attribute` field specifier can be used for attribute fields
print(e)
# CHECK: TEST: testExtDialectWithAttrInOp
@run
def testExtDialectWithAttrInOp():
class TestAttrInOp(Dialect, name="ext_attr_in_op"):
pass
class OpWithAttr(TestAttrInOp.Operation, name="op_with_attr"):
a: IntegerAttr | StringAttr
b: IntegerType[32] | IntegerType[64]
with Context(), Location.unknown():
TestAttrInOp.load()
# CHECK: irdl.dialect @ext_attr_in_op {
# CHECK: irdl.operation @op_with_attr {
# CHECK: %0 = irdl.base "#builtin.integer"
# CHECK: %1 = irdl.base "#builtin.string"
# CHECK: %2 = irdl.any_of(%0, %1)
# CHECK: %3 = irdl.is i32
# CHECK: %4 = irdl.is i64
# CHECK: %5 = irdl.any_of(%3, %4)
# CHECK: irdl.attributes {"a" = %2, "b" = %5}
# CHECK: }
# CHECK: }
print(TestAttrInOp._mlir_module)
i32 = IntegerType.get_signless(32)
i64 = IntegerType.get_signless(64)
iattr = IntegerAttr.get(i32, 42)
sattr = StringAttr.get("hello")
module = Module.create()
with InsertionPoint(module.body):
OpWithAttr(iattr, TypeAttr.get(i32))
OpWithAttr(sattr, TypeAttr.get(i64))
assert module.operation.verify()
# CHECK: "ext_attr_in_op.op_with_attr"() {a = 42 : i32, b = i32} : () -> ()
# CHECK: "ext_attr_in_op.op_with_attr"() {a = "hello", b = i64} : () -> ()
print(module)
@run
def testExtDialectFieldSpecifiers():
class TestFieldSpecifiers(Dialect, name="ext_field_specifiers"):
pass
class OperandSpecifierOp(TestFieldSpecifiers.Operation, name="operand_specifier"):
a: Operand[IntegerType[32]] = operand()
b: Optional[Operand[IntegerType[32]]] = None
c: Operand[IntegerType[32]] = operand(kw_only=True)
class ResultSpecifierOp(TestFieldSpecifiers.Operation, name="result_specifier"):
a: Result[IntegerType[32]] = result()
b: Result[IntegerType[16]] = infer_result()
c: Result[IntegerType] = result(
default_factory=lambda: IntegerType.get_signless(8)
)
d: Sequence[Result[IntegerType]] = result(default_factory=list)
e: Result[IntegerType[32]] = result(kw_only=True)
class AttributeSpecifierOp(
TestFieldSpecifiers.Operation, name="attribute_specifier"
):
a: IntegerAttr = attribute()
b: IntegerAttr = attribute(
default_factory=lambda: IntegerAttr.get(IntegerType.get_signless(32), 42)
)
c: StringAttr["a"] | StringAttr["b"] = attribute(
default_factory=lambda: StringAttr.get("a"), kw_only=True
)
d: IntegerAttr = attribute(kw_only=True)
with Context(), Location.unknown():
TestFieldSpecifiers.load()
# CHECK: (self, /, a, *, b=None, c, loc=None, ip=None)
print(OperandSpecifierOp.__init__.__signature__)
# CHECK: (self, /, a, *, b=None, c=None, d=None, e, loc=None, ip=None)
print(ResultSpecifierOp.__init__.__signature__)
# CHECK: (self, /, a, *, b=None, c=None, d, loc=None, ip=None)
print(AttributeSpecifierOp.__init__.__signature__)
module = Module.create()
i32 = IntegerType.get_signless(32)
with InsertionPoint(module.body):
one = arith.constant(i32, 1)
OperandSpecifierOp(one, c=one)
ResultSpecifierOp(i32, e=i32)
AttributeSpecifierOp(IntegerAttr.get(i32, 43), d=IntegerAttr.get(i32, 100))
assert module.operation.verify()
# CHECK: "ext_field_specifiers.operand_specifier"(%c1_i32, %c1_i32) {operandSegmentSizes = array<i32: 1, 0, 1>} : (i32, i32) -> ()
# CHECK: %0:4 = "ext_field_specifiers.result_specifier"() {resultSegmentSizes = array<i32: 1, 1, 1, 0, 1>} : () -> (i32, i16, i8, i32)
# CHECK: "ext_field_specifiers.attribute_specifier"() {a = 43 : i32, b = 42 : i32, c = "a", d = 100 : i32} : () -> ()
print(module)