Files
llvm-project/mlir/test/python/dialects/python_test.py
Ingo Müller 4440e87bae [mlir:python] Fix crash in from_python in type casters. (#191764)
This PR fixes a crash due to a failed assertion in the `from_python`
implementations of the type casters. The assertion obviously only
triggers if assertions are enabled, which isn't the case for many Python
installations, *and* if a Python capsule of the wrong type is attempted
to be used, so this isn't triggered easily. The problem is that the
conversion from Python capsules may set the Python error indicator but
the callers of the type casters do not expect that. In fact, if there
are several overloads of a function, the first may cause the error
indicator to be set and the second runs into the assertion. The fix is
to unset the error indicator after a failed capsule conversion, which is
indicated by the return value of the function anyway.

An alternative fix would be to unset the error indicator *inside* the
`mlirPythonCapsuleTo*` functions; however, their documentation does say
that the Python error indicator is set, so I assume that some callers
may *want* to see the indicator and that the responsibility to handle it
is on them.

Signed-off-by: Ingo Müller <ingomueller@google.com>
2026-04-13 20:40:32 +02:00

1071 lines
44 KiB
Python

# RUN: %PYTHON %s | FileCheck %s
import sys
import typing
from typing import Union, Optional
from mlir.ir import *
import mlir.dialects.func as func
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith
from mlir._mlir_libs._mlirPythonTestNanobind import (
TestAttr,
TestType,
TestTensorValue,
TestIntegerRankedTensorType,
take_module_or_operation,
)
# Register the python_test dialect into the registry returned by
# get_dialect_registry(). The tests below create plain Context() objects and
# never register the dialect themselves, so they presumably rely on this
# module-level registration to make python_test ops available.
test.register_python_test_dialect(get_dialect_registry())
def run(f):
    """Decorator that announces and immediately executes a test function.

    Emits a blank line followed by "TEST: <function name>" on stdout (the text
    the FileCheck label directives in this file anchor on), invokes *f* with
    no arguments, and hands *f* back unchanged so the module-level name keeps
    referring to the original function.
    """
    test_name = f.__name__
    print(f"\nTEST: {test_name}")
    f()
    return f
# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
    """Exercise attribute handling on python_test.attributed_op.

    Covers construction with mandatory, optional and unit attributes, generic
    access and mutation through ``op.attributes``, and the generated
    per-attribute properties (read, assign, assign None, and ``del``).
    """
    with Context() as ctx, Location.unknown():
        #
        # Check op construction with attributes.
        #
        i32 = IntegerType.get_signless(32)
        one = IntegerAttr.get(i32, 1)
        two = IntegerAttr.get(i32, 2)
        unit = UnitAttr.get()
        # CHECK: python_test.attributed_op {
        # CHECK-DAG: mandatory_i32 = 1 : i32
        # CHECK-DAG: optional_i32 = 2 : i32
        # CHECK-DAG: unit
        # CHECK: }
        op = test.AttributedOp(one, optional_i32=two, unit=unit)
        print(f"{op}")
        # CHECK: python_test.attributed_op {
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        # Only the mandatory attribute is passed; optional_i32 and unit are
        # left unset on op2.
        op2 = test.AttributedOp(two)
        print(f"{op2}")
        #
        # Check generic "attributes" access and mutation.
        #
        assert "additional" not in op.attributes
        # CHECK: python_test.attributed_op {
        # CHECK-DAG: additional = 1 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        # Set an extra attribute purely through the generic mapping.
        op2.attributes["additional"] = one
        print(f"{op2}")
        # CHECK: python_test.attributed_op {
        # CHECK-DAG: additional = 2 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        # Reassigning an existing key overwrites its value.
        op2.attributes["additional"] = two
        print(f"{op2}")
        # CHECK: python_test.attributed_op {
        # CHECK-NOT: additional = 2 : i32
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        del op2.attributes["additional"]
        print(f"{op2}")
        # Looking up a key that is not present must raise KeyError, like a dict.
        try:
            print(op.attributes["additional"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError on unknown attribute key"
        #
        # Check accessors to defined attributes.
        #
        # CHECK: Mandatory: 1
        # CHECK: Optional: 2
        # CHECK: Unit: True
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32.value}")
        print(f"Unit: {op.unit}")
        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        # On op2 the unset optional attribute reads as None and the unset unit
        # attribute reads as False.
        print(f"Mandatory: {op2.mandatory_i32.value}")
        print(f"Optional: {op2.optional_i32}")
        print(f"Unit: {op2.unit}")
        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        # Assigning None / False through the properties clears the optional and
        # unit attributes; the asserts below confirm they leave op.attributes.
        op.mandatory_i32 = two
        op.optional_i32 = None
        op.unit = False
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32}")
        print(f"Unit: {op.unit}")
        assert "optional_i32" not in op.attributes
        assert "unit" not in op.attributes
        # A mandatory attribute cannot be cleared.
        try:
            op.mandatory_i32 = None
        except ValueError:
            pass
        else:
            assert False, "expected ValueError on setting a mandatory attribute to None"
        # CHECK: Optional: 2
        op.optional_i32 = two
        print(f"Optional: {op.optional_i32.value}")
        # CHECK: Optional: None
        # `del` on an optional attribute property removes it.
        del op.optional_i32
        print(f"Optional: {op.optional_i32}")
        # CHECK: Unit: False
        # None is accepted for a unit attribute and behaves like False.
        op.unit = None
        print(f"Unit: {op.unit}")
        assert "unit" not in op.attributes
        # CHECK: Unit: True
        op.unit = True
        print(f"Unit: {op.unit}")
        # CHECK: Unit: False
        del op.unit
        print(f"Unit: {op.unit}")
# CHECK-LABEL: TEST: attrBuilder
@run
def attrBuilder():
    """Build python_test.attributes_op from plain Python values.

    Every keyword argument goes through the registered attribute builders.
    The inline FileCheck directives (unordered, DAG-style) pin the printed form
    of each resulting attribute. Afterwards, for each generated property, the
    setter/getter type hints and the runtime type of the value read back are
    checked against the expected attribute class.
    """
    with Context() as ctx, Location.unknown():
        # CHECK: python_test.attributes_op
        op = test.AttributesOp(
            # CHECK-DAG: x_affinemap = affine_map<() -> (2)>
            x_affinemap=AffineMap.get_constant(2),
            # CHECK-DAG: x_affinemaparr = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
            x_affinemaparr=[AffineMap.get_identity(3)],
            # CHECK-DAG: x_arr = [true, "x"]
            x_arr=[BoolAttr.get(True), StringAttr.get("x")],
            x_boolarr=[False, True],  # CHECK-DAG: x_boolarr = [false, true]
            x_bool=True,  # CHECK-DAG: x_bool = true
            x_dboolarr=[True, False],  # CHECK-DAG: x_dboolarr = array<i1: true, false>
            x_df16arr=[21, 22],  # CHECK-DAG: x_df16arr = array<i16: 21, 22>
            # CHECK-DAG: x_df32arr = array<f32: 2.300000e+01, 2.400000e+01>
            x_df32arr=[23, 24],
            # CHECK-DAG: x_df64arr = array<f64: 2.500000e+01, 2.600000e+01>
            x_df64arr=[25, 26],
            x_di32arr=[0, 1],  # CHECK-DAG: x_di32arr = array<i32: 0, 1>
            # CHECK-DAG: x_di64arr = array<i64: 1, 2>
            x_di64arr=[1, 2],
            x_di8arr=[2, 3],  # CHECK-DAG: x_di8arr = array<i8: 2, 3>
            # CHECK-DAG: x_dictarr = [{a = false}]
            x_dictarr=[{"a": BoolAttr.get(False)}],
            x_dict={"b": BoolAttr.get(True)},  # CHECK-DAG: x_dict = {b = true}
            x_f32=-2.25,  # CHECK-DAG: x_f32 = -2.250000e+00 : f32
            # CHECK-DAG: x_f32arr = [2.000000e+00 : f32, 3.000000e+00 : f32]
            x_f32arr=[2.0, 3.0],
            x_f64=4.25,  # CHECK-DAG: x_f64 = 4.250000e+00 : f64
            x_f64arr=[4.0, 8.0],  # CHECK-DAG: x_f64arr = [4.000000e+00, 8.000000e+00]
            # CHECK-DAG: x_f64elems = dense<[8.000000e+00, 1.600000e+01]> : tensor<2xf64>
            x_f64elems=[8.0, 16.0],
            # CHECK-DAG: x_flatsymrefarr = [@symbol1, @symbol2]
            x_flatsymrefarr=["symbol1", "symbol2"],
            x_flatsymref="symbol3",  # CHECK-DAG: x_flatsymref = @symbol3
            x_i1=0,  # CHECK-DAG: x_i1 = false
            x_i16=42,  # CHECK-DAG: x_i16 = 42 : i16
            x_i32=6,  # CHECK-DAG: x_i32 = 6 : i32
            x_i32arr=[4, 5],  # CHECK-DAG: x_i32arr = [4 : i32, 5 : i32]
            x_i32elems=[5, 6],  # CHECK-DAG: x_i32elems = dense<[5, 6]> : tensor<2xi32>
            x_i64=9,  # CHECK-DAG: x_i64 = 9 : i64
            x_i64arr=[7, 8],  # CHECK-DAG: x_i64arr = [7, 8]
            x_i64elems=[8, 9],  # CHECK-DAG: x_i64elems = dense<[8, 9]> : tensor<2xi64>
            x_i64svecarr=[10, 11],  # CHECK-DAG: x_i64svecarr = [10, 11]
            x_i8=11,  # CHECK-DAG: x_i8 = 11 : i8
            x_idx=10,  # CHECK-DAG: x_idx = 10 : index
            # CHECK-DAG: x_idxelems = dense<[11, 12]> : tensor<2xindex>
            x_idxelems=[11, 12],
            # CHECK-DAG: x_idxlistarr = [{{\[}}13], [14, 15]]
            x_idxlistarr=[[13], [14, 15]],
            x_si1=-1,  # CHECK-DAG: x_si1 = -1 : si1
            x_si16=-2,  # CHECK-DAG: x_si16 = -2 : si16
            x_si32=-3,  # CHECK-DAG: x_si32 = -3 : si32
            x_si64=-123,  # CHECK-DAG: x_si64 = -123 : si64
            x_si8=-4,  # CHECK-DAG: x_si8 = -4 : si8
            x_strarr=["hello", "world"],  # CHECK-DAG: x_strarr = ["hello", "world"]
            x_str="hello world!",  # CHECK-DAG: x_str = "hello world!"
            # CHECK-DAG: x_symrefarr = [@flatsym, @deep::@sym]
            x_symrefarr=["flatsym", ["deep", "sym"]],
            x_symref=["deep", "sym2"],  # CHECK-DAG: x_symref = @deep::@sym2
            x_sym="symbol",  # CHECK-DAG: x_sym = "symbol"
            x_typearr=[F32Type.get()],  # CHECK-DAG: x_typearr = [f32]
            x_type=F64Type.get(),  # CHECK-DAG: x_type = f64
            x_ui1=1,  # CHECK-DAG: x_ui1 = 1 : ui1
            x_ui16=2,  # CHECK-DAG: x_ui16 = 2 : ui16
            x_ui32=3,  # CHECK-DAG: x_ui32 = 3 : ui32
            x_ui64=4,  # CHECK-DAG: x_ui64 = 4 : ui64
            x_ui8=5,  # CHECK-DAG: x_ui8 = 5 : ui8
            x_unit=True,  # CHECK-DAG: x_unit
        )
        op.verify()
        op.print(use_local_scope=True)

        # Expected attribute class for each generated property. For every
        # entry, the setter's "value" hint and the getter's "return" hint must
        # both be exactly that class, and the value read back through the
        # property must have exactly the getter's annotated type.
        # (x_df16arr is annotated as DenseI16ArrayAttr; x_unit is not checked.)
        expected_attr_types = [
            ("x_affinemaparr", ArrayAttr),
            ("x_affinemap", AffineMapAttr),
            ("x_arr", ArrayAttr),
            ("x_boolarr", ArrayAttr),
            ("x_bool", BoolAttr),
            ("x_dboolarr", DenseBoolArrayAttr),
            ("x_df32arr", DenseF32ArrayAttr),
            ("x_df64arr", DenseF64ArrayAttr),
            ("x_df16arr", DenseI16ArrayAttr),
            ("x_di32arr", DenseI32ArrayAttr),
            ("x_di64arr", DenseI64ArrayAttr),
            ("x_di8arr", DenseI8ArrayAttr),
            ("x_dictarr", ArrayAttr),
            ("x_dict", DictAttr),
            ("x_f32arr", ArrayAttr),
            ("x_f32", FloatAttr),
            ("x_f64arr", ArrayAttr),
            ("x_f64", FloatAttr),
            ("x_f64elems", DenseFPElementsAttr),
            ("x_flatsymrefarr", ArrayAttr),
            ("x_flatsymref", FlatSymbolRefAttr),
            ("x_i16", IntegerAttr),
            ("x_i1", BoolAttr),
            ("x_i32arr", ArrayAttr),
            ("x_i32", IntegerAttr),
            ("x_i32elems", DenseIntElementsAttr),
            ("x_i64arr", ArrayAttr),
            ("x_i64", IntegerAttr),
            ("x_i64elems", DenseIntElementsAttr),
            ("x_i64svecarr", ArrayAttr),
            ("x_i8", IntegerAttr),
            ("x_idx", IntegerAttr),
            ("x_idxelems", DenseIntElementsAttr),
            ("x_idxlistarr", ArrayAttr),
            ("x_si16", IntegerAttr),
            ("x_si1", IntegerAttr),
            ("x_si32", IntegerAttr),
            ("x_si64", IntegerAttr),
            ("x_si8", IntegerAttr),
            ("x_strarr", ArrayAttr),
            ("x_str", StringAttr),
            ("x_sym", StringAttr),
            ("x_symrefarr", ArrayAttr),
            ("x_symref", SymbolRefAttr),
            ("x_typearr", ArrayAttr),
            ("x_type", TypeAttr),
            ("x_ui16", IntegerAttr),
            ("x_ui1", IntegerAttr),
            ("x_ui32", IntegerAttr),
            ("x_ui64", IntegerAttr),
            ("x_ui8", IntegerAttr),
        ]
        for name, expected in expected_attr_types:
            prop = getattr(test.AttributesOp, name)
            set_hint = typing.get_type_hints(prop.fset)["value"]
            get_hint = typing.get_type_hints(prop.fget)["return"]
            assert (
                set_hint is expected
            ), f"{name}: setter hint {set_hint!r}, expected {expected!r}"
            assert (
                get_hint is expected
            ), f"{name}: getter hint {get_hint!r}, expected {expected!r}"
            actual = type(getattr(op, name))
            assert (
                actual is get_hint
            ), f"{name}: value has type {actual!r}, expected {get_hint!r}"
# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
    """Check InferTypeOpInterface on an op instance and on an op class.

    Both the instance-bound interface and the static (class-bound) interface
    are queried for inferred return types. Only the instance-bound one exposes
    an opview. Ops that do not implement the interface are rejected with
    ValueError.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            op = test.InferResultsOp()
            dummy = test.DummyOp()
        # CHECK: [Type(i32), Type(i64)]
        iface = InferTypeOpInterface(op)
        print(iface.inferReturnTypes())
        # CHECK: [Type(i32), Type(i64)]
        # Query the static interface itself. Previously this printed
        # iface.inferReturnTypes() again, so the static path was never tested.
        iface_static = InferTypeOpInterface(test.InferResultsOp)
        print(iface_static.inferReturnTypes())
        assert isinstance(iface.opview, test.InferResultsOp)
        assert iface.opview == iface.operation.opview
        # A static interface is not bound to an operation, so it has no opview.
        try:
            iface_static.opview
        except TypeError:
            pass
        else:
            assert False, (
                "not expected to be able to obtain an opview from a static interface"
            )
        try:
            InferTypeOpInterface(dummy)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op to implement the interface"
        try:
            InferTypeOpInterface(test.DummyOp)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op class to implement the interface"
# CHECK-LABEL: TEST: resultTypesDefinedByTraits
@run
def resultTypesDefinedByTraits():
    """Check ops whose result types are derived from traits or attributes.

    Builders are called without explicit result types, and the printed result
    types show what was derived. The last two ops pass ``results=`` explicitly
    to confirm it overrides the derived types.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            inferred = test.InferResultsOp()
            # CHECK: i32 i64
            print(inferred.single.type, inferred.doubled.type)
            # Both results take the operand's type (i32 here).
            same = test.SameOperandAndResultTypeOp([inferred.results[0]])
            # CHECK-COUNT-2: i32
            print(same.one.type)
            print(same.two.type)
            assert (
                typing.get_type_hints(test.SameOperandAndResultTypeOp.one.fget)[
                    "return"
                ]
                is OpResult
            )
            assert type(same.one) is OpResult
            # Result types come from the TypeAttr argument (index), not from
            # the i64 operand.
            first_type_attr = test.FirstAttrDeriveTypeAttrOp(
                inferred.results[1], TypeAttr.get(IndexType.get())
            )
            # CHECK-COUNT-2: index
            print(first_type_attr.one.type)
            print(first_type_attr.two.type)
            # Result types come from the type of the attribute argument (f32).
            first_attr = test.FirstAttrDeriveAttrOp(FloatAttr.get(F32Type.get(), 3.14))
            # CHECK-COUNT-3: f32
            print(first_attr.one.type)
            print(first_attr.two.type)
            print(first_attr.three.type)
            implied = test.InferResultsImpliedOp()
            # CHECK: i32
            print(implied.integer.type)
            # CHECK: f64
            print(implied.flt.type)
            # CHECK: index
            print(implied.index.type)
            # provide the result types to avoid inferring them
            f64 = F64Type.get()
            no_imply = test.InferResultsImpliedOp(results=[f64, f64, f64])
            # CHECK-COUNT-3: f64
            print(no_imply.integer.type, no_imply.flt.type, no_imply.index.type)
            no_infer = test.InferResultsOp(results=[F32Type.get(), IndexType.get()])
            # CHECK: f32 index
            print(no_infer.single.type, no_infer.doubled.type)
# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
    """Check the accessor and type hints of an optional operand.

    The ``input`` accessor returns None when the operand is omitted and is
    annotated as Optional[Value]. The result accessor is annotated as
    OpResult[IntegerType] and returns an OpResult at runtime.
    """
    with Context() as ctx, Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            op1 = test.OptionalOperandOp()
            # CHECK: op1.input is None: True
            print(f"op1.input is None: {op1.input is None}")
            assert (
                typing.get_type_hints(test.OptionalOperandOp.input.fget)["return"]
                is Optional[Value]
            )
            assert (
                typing.get_type_hints(test.OptionalOperandOp.result.fget)["return"]
                == OpResult[IntegerType]
            )
            assert type(op1.result) is OpResult
            # The op view op1 is passed directly as the operand (presumably
            # converted to its single result by the builder -- not shown here).
            op2 = test.OptionalOperandOp(input=op1)
            # CHECK: op2.input is None: False
            print(f"op2.input is None: {op2.input is None}")
# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
    """Check the TestAttr custom attribute class and its type caster.

    Covers printing/repr, use as an op attribute, casting from a compatible
    Attribute, and that invalid casts and invalid constructor arguments raise
    Python exceptions instead of crashing the interpreter.
    """
    with Context() as ctx, Location.unknown():
        a = TestAttr.get()
        # CHECK: #python_test.test_attr
        print(a)
        # CHECK: python_test.custom_attributed_op {
        # CHECK: #python_test.test_attr
        # CHECK: }
        op2 = test.CustomAttributedOp(a)
        print(f"{op2}")
        # CHECK: #python_test.test_attr
        print(f"{op2.test_attr}")
        # CHECK: TestAttr(#python_test.test_attr)
        print(repr(op2.test_attr))
        # The following cast must not assert.
        b = TestAttr(a)
        unit = UnitAttr.get()
        try:
            TestAttr(unit)
        except ValueError as e:
            assert "Cannot cast attribute to TestAttr" in str(e)
        else:
            # A bare `raise` here would have no active exception and only
            # surface as "RuntimeError: No active exception to reraise".
            raise AssertionError("expected ValueError when casting UnitAttr to TestAttr")
        # The type caster must reject a plain int so that nanobind's overload
        # resolution fails with a TypeError rather than crashing.
        try:
            TestAttr(42)
        except TypeError as e:
            assert (
                "__init__(): incompatible function arguments. The following argument types are supported"
                in str(e)
            )
            assert (
                "__init__(self, cast_from_attr: mlir._mlir_libs._mlir.ir.Attribute) -> None"
                in str(e)
            )
            assert (
                "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestAttr, int"
                in str(e)
            )
        else:
            raise AssertionError("expected TypeError from TestAttr(42)")
        # The following must trigger a TypeError from nanobind (therefore, not
        # checking its message) and must not crash.
        try:
            TestAttr(42, 56)
        except TypeError:
            pass
        else:
            raise AssertionError("expected TypeError from TestAttr(42, 56)")
# CHECK-LABEL: TEST: testCustomType
@run
def testCustomType():
    """Check the TestType custom type class and its type caster.

    Covers printing, casting from a compatible Type, the instance typeid, and
    that invalid casts and invalid constructor arguments raise Python
    exceptions instead of crashing the interpreter.
    """
    with Context() as ctx:
        a = TestType.get()
        # CHECK: !python_test.test_type
        print(a)
        # The following cast must not assert.
        b = TestType(a)
        # Instance custom types should have typeids
        assert isinstance(b.typeid, TypeID)
        i8 = IntegerType.get_signless(8)
        try:
            TestType(i8)
        except ValueError as e:
            assert "Cannot cast type to TestType" in str(e)
        else:
            # A bare `raise` here would have no active exception and only
            # surface as "RuntimeError: No active exception to reraise".
            raise AssertionError("expected ValueError when casting i8 to TestType")
        # The type caster must reject a plain int so that nanobind's overload
        # resolution fails with a TypeError rather than crashing.
        try:
            TestType(42)
        except TypeError as e:
            assert (
                "__init__(): incompatible function arguments. The following argument types are supported"
                in str(e)
            )
            assert (
                "__init__(self, cast_from_type: mlir._mlir_libs._mlir.ir.Type) -> None"
                in str(e)
            )
            assert (
                "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestType, int"
                in str(e)
            )
        else:
            raise AssertionError("expected TypeError from TestType(42)")
        # The following must trigger a TypeError from nanobind (therefore, not
        # checking its message) and must not crash.
        try:
            TestType(42, 56)
        except TypeError:
            pass
        else:
            raise AssertionError("expected TypeError from TestType(42, 56)")
# CHECK-LABEL: TEST: testValue
@run
def testValue():
    """Value must be subscriptable at runtime, i.e. define __class_getitem__."""
    is_runtime_generic = hasattr(Value, "__class_getitem__")
    assert is_runtime_generic
@run
# CHECK-LABEL: TEST: testTensorValue
def testTensorValue():
    """Check TestTensorValue, a Value subclass provided by the test extension.

    A user-defined Python subclass of it can wrap an existing value and
    override __str__. The custom TestIntegerRankedTensorType shares its
    static_typeid with the built-in tensor type of a tensor.empty result.
    """
    with Context() as ctx, Location.unknown():
        i8 = IntegerType.get_signless(8)

        # Python-side subclass that only changes the printed prefix.
        class Tensor(TestTensorValue):
            def __str__(self):
                return super().__str__().replace("Value", "Tensor")

        module = Module.create()
        with InsertionPoint(module.body):
            t = tensor.EmptyOp([10, 10], i8).result
            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(Value(t))
            tt = Tensor(t)
            # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(tt)
            # CHECK: False
            print(tt.is_null())
            # Classes of custom types that inherit from concrete types should have
            # static_typeid
            assert isinstance(TestIntegerRankedTensorType.static_typeid, TypeID)
            # And it should be equal to the in-tree concrete type
            assert TestIntegerRankedTensorType.static_typeid == t.type.typeid
            d = tensor.EmptyOp([1, 2, 3], IntegerType.get_signless(5)).result
            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<1x2x3xi5>)
            print(d)
            # NOTE(review): repr shows d was wrapped as TestTensorValue without an
            # explicit cast; presumably the extension registers a value caster
            # for this tensor type -- confirm in the C++ test module.
            # CHECK: TestTensorValue
            print(repr(d))
# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
def inferReturnTypeComponents():
    """Check InferShapedTypeOpInterface for ranked and unranked operands.

    An InferShapedTypeComponentsOp is built inside a function whose two
    arguments are a ranked tensor<1x3x10x10xi32> and an unranked tensor<*xi32>.
    The inferred shaped-type components of each op are printed.
    """

    def print_components(shaped_op):
        # Ask the op's interface for its return type components, feeding it the
        # op's own operand, and print the fields of the first component.
        components = InferShapedTypeOpInterface(shaped_op).inferReturnTypeComponents(
            operands=[shaped_op.operand]
        )[0]
        print("has rank:", components.has_rank)
        print("rank:", components.rank)
        print("element type:", components.element_type)
        print("shape:", components.shape)

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):
            result_type = UnrankedTensorType.get(i32)
            ranked_arg_type = RankedTensorType.get([1, 3, 10, 10], i32)
            unranked_arg_type = UnrankedTensorType.get(i32)
            arg_types = [ranked_arg_type, unranked_arg_type]
            fn = func.FuncOp(
                "test_inferReturnTypeComponents", (arg_types, [result_type])
            )
            body = Block.create_at_start(fn.operation.regions[0], arg_types)
            with InsertionPoint(body):
                ranked_op = test.InferShapedTypeComponentsOp(
                    result_type, body.arguments[0]
                )
                unranked_op = test.InferShapedTypeComponentsOp(
                    result_type, body.arguments[1]
                )
        # CHECK: has rank: True
        # CHECK: rank: 4
        # CHECK: element type: i32
        # CHECK: shape: [1, 3, 10, 10]
        print_components(ranked_op)
        # CHECK: has rank: False
        # CHECK: rank: None
        # CHECK: element type: i32
        # CHECK: shape: None
        print_components(unranked_op)
# CHECK-LABEL: TEST: testCustomTypeTypeCaster
@run
def testCustomTypeTypeCaster():
    """Check registration and replacement of Python type casters.

    Parsed/constructed types come back as their custom Python classes.
    Registering a second caster for the same TypeID without ``replace=True``
    fails. With ``replace=True`` the new caster takes effect, and the original
    caster is then re-installed.
    """
    with Context() as ctx, Location.unknown():
        a = TestType.get()
        assert a.typeid is not None
        # A generically parsed type is still surfaced as TestType by its caster.
        b = Type.parse("!python_test.test_type")
        # CHECK: !python_test.test_type
        print(b)
        # CHECK: TestType(!python_test.test_type)
        print(repr(b))
        c = TestIntegerRankedTensorType.get([10, 10], 5)
        # CHECK: tensor<10x10xi5>
        print(c)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(c))
        # CHECK: Type caster is already registered
        # Registering without replace=True is expected to raise RuntimeError;
        # its message is printed so the directive above can match it.
        try:

            @register_type_caster(c.typeid)
            def type_caster(pytype):
                return TestIntegerRankedTensorType(pytype)

        except RuntimeError as e:
            print(e)
        # python_test dialect registers a caster for RankedTensorType in its extension (nanobind) module.
        # So this one replaces that one (successfully). And then just to be sure we restore the original caster below.
        @register_type_caster(c.typeid, replace=True)
        def type_caster(pytype):
            return RankedTensorType(pytype)

        d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(d.type)
        # CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>)
        print("ranked tensor type", repr(d.type))

        # Restore a caster that produces TestIntegerRankedTensorType again.
        @register_type_caster(c.typeid, replace=True)
        def type_caster(pytype):
            return TestIntegerRankedTensorType(pytype)

        d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(d.type)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(d.type))
# CHECK-LABEL: TEST: testInferTypeOpInterface
@run
def testInferTypeOpInterface():
    """Check result-type inference that depends on which variadic inputs are set.

    Passing only ``single`` yields an i32 result. Passing both ``single`` and
    ``doubled`` yields an f32 result. The snake_case builder function is
    annotated to return, and does return, an OpResult.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            i64 = IntegerType.get_signless(64)
            zero = arith.ConstantOp(i64, 0)
            one_operand = test.InferResultsVariadicInputsOp(single=zero, doubled=None)
            # CHECK: i32
            print(one_operand.result.type)
            two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
            # CHECK: f32
            print(two_operands.result.type)
            assert (
                typing.get_type_hints(test.infer_results_variadic_inputs_op)["return"]
                is OpResult
            )
            assert (
                type(test.infer_results_variadic_inputs_op(single=zero, doubled=zero))
                is OpResult
            )
# CHECK-LABEL: TEST: testVariadicOperandAccess
@run
def testVariadicOperandAccess():
    """Check accessors and type hints on an op with two same-sized variadic operand groups.

    The non-variadic accessor is annotated as Value but yields an OpResult.
    Variadic accessors are annotated as, and yield, OpOperandList. The
    snake_case builder returns the op view class.
    """

    def as_strings(operands):
        # Render each operand with str() so whole groups print as one list.
        return list(map(str, operands))

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            i32 = IntegerType.get_signless(32)
            zero, one, two, three, four = [arith.ConstantOp(i32, k) for k in range(5)]
            variadic_operands = test.SameVariadicOperandSizeOp(
                [zero, one], two, [three, four]
            )

            # CHECK: OpResult(%{{.*}} = arith.constant 2 : i32)
            print(variadic_operands.non_variadic)
            non_variadic_hint = typing.get_type_hints(
                test.SameVariadicOperandSizeOp.non_variadic.fget
            )["return"]
            assert non_variadic_hint is Value
            assert type(variadic_operands.non_variadic) is OpResult

            # CHECK: ['OpResult(%{{.*}} = arith.constant 0 : i32)', 'OpResult(%{{.*}} = arith.constant 1 : i32)']
            print(as_strings(variadic_operands.variadic1))
            variadic1_hint = typing.get_type_hints(
                test.SameVariadicOperandSizeOp.variadic1.fget
            )["return"]
            assert variadic1_hint is OpOperandList
            assert type(variadic_operands.variadic1) is OpOperandList

            # CHECK: ['OpResult(%{{.*}} = arith.constant 3 : i32)', 'OpResult(%{{.*}} = arith.constant 4 : i32)']
            print(as_strings(variadic_operands.variadic2))
            assert type(variadic_operands.variadic2) is OpOperandList

            builder_hint = typing.get_type_hints(test.same_variadic_operand)["return"]
            assert builder_hint is test.SameVariadicOperandSizeOp
            built = test.same_variadic_operand([zero, one], two, [three, four])
            assert type(built) is test.SameVariadicOperandSizeOp
# CHECK-LABEL: TEST: testVariadicResultAccess
@run
def testVariadicResultAccess():
    """Check result accessors of ops with same-sized variadic result groups.

    Each ``SameVariadicResultSizeOp*`` variant is built from explicit result
    types. The test prints the type of every fixed result and the list of
    types of every variadic group. Each result's width is unique, so the
    printed types show which position a result was taken from. Some variants
    also check the annotated and runtime return types of the generated builder
    functions and accessors.
    """

    def types(lst):
        """Return the ``.type`` of every value in ``lst`` as a list."""
        return [e.type for e in lst]

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            # i[k] is the signless integer type of width k (i0 .. i6). A distinct
            # width per position makes the printed result types self-identifying.
            i = [IntegerType.get_signless(k) for k in range(7)]

            # Test Variadic-Fixed-Variadic
            op = test.SameVariadicResultSizeOpVFV([i[0], i[1]], i[2], [i[3], i[4]])
            # CHECK: i2
            print(op.non_variadic.type)
            # CHECK: [IntegerType(i0), IntegerType(i1)]
            print(types(op.variadic1))
            # CHECK: [IntegerType(i3), IntegerType(i4)]
            print(types(op.variadic2))
            # The builder function is annotated with a Union of all possible
            # return kinds. For this multi-result op the call is expected to
            # return an OpResultList.
            assert (
                typing.get_type_hints(test.same_variadic_result_vfv)["return"]
                == Union[OpResult, OpResultList, test.SameVariadicResultSizeOpVFV]
            )
            assert (
                type(test.same_variadic_result_vfv([i[0], i[1]], i[2], [i[3], i[4]]))
                is OpResultList
            )

            # Test Variadic-Variadic-Variadic
            op = test.SameVariadicResultSizeOpVVV(
                [i[0], i[1]], [i[2], i[3]], [i[4], i[5]]
            )
            # CHECK: [IntegerType(i0), IntegerType(i1)]
            print(types(op.variadic1))
            # CHECK: [IntegerType(i2), IntegerType(i3)]
            print(types(op.variadic2))
            # CHECK: [IntegerType(i4), IntegerType(i5)]
            print(types(op.variadic3))

            # Test Fixed-Fixed-Variadic
            op = test.SameVariadicResultSizeOpFFV(i[0], i[1], [i[2], i[3], i[4]])
            # CHECK: i0
            print(op.non_variadic1.type)
            # CHECK: i1
            print(op.non_variadic2.type)
            # CHECK: [IntegerType(i2), IntegerType(i3), IntegerType(i4)]
            print(types(op.variadic))
            # A variadic result group is exposed (and annotated) as OpResultList.
            assert (
                typing.get_type_hints(test.SameVariadicResultSizeOpFFV.variadic.fget)[
                    "return"
                ]
                is OpResultList
            )
            assert type(op.variadic) is OpResultList

            # Test Variadic-Variadic-Fixed
            op = test.SameVariadicResultSizeOpVVF(
                [i[0], i[1], i[2]], [i[3], i[4], i[5]], i[6]
            )
            # CHECK: [IntegerType(i0), IntegerType(i1), IntegerType(i2)]
            print(types(op.variadic1))
            # CHECK: [IntegerType(i3), IntegerType(i4), IntegerType(i5)]
            print(types(op.variadic2))
            # CHECK: i6
            print(op.non_variadic.type)

            # Test Fixed-Variadic-Fixed-Variadic-Fixed
            op = test.SameVariadicResultSizeOpFVFVF(
                i[0], [i[1], i[2]], i[3], [i[4], i[5]], i[6]
            )
            # CHECK: i0
            print(op.non_variadic1.type)
            # CHECK: [IntegerType(i1), IntegerType(i2)]
            print(types(op.variadic1))
            # CHECK: i3
            print(op.non_variadic2.type)
            # CHECK: [IntegerType(i4), IntegerType(i5)]
            print(types(op.variadic2))
            # CHECK: i6
            print(op.non_variadic3.type)

            # Test Fixed-Variadic-Fixed-Variadic-Fixed - Variadic group size 0
            # (empty groups must not shift the positions of the fixed results).
            op = test.SameVariadicResultSizeOpFVFVF(i[0], [], i[1], [], i[2])
            # CHECK: i0
            print(op.non_variadic1.type)
            # CHECK: []
            print(types(op.variadic1))
            # CHECK: i1
            print(op.non_variadic2.type)
            # CHECK: []
            print(types(op.variadic2))
            # CHECK: i2
            print(op.non_variadic3.type)

            # Test Fixed-Variadic-Fixed-Variadic-Fixed - Variadic group size 1
            op = test.SameVariadicResultSizeOpFVFVF(i[0], [i[1]], i[2], [i[3]], i[4])
            # CHECK: i0
            print(op.non_variadic1.type)
            # CHECK: [IntegerType(i1)]
            print(types(op.variadic1))
            # CHECK: i2
            print(op.non_variadic2.type)
            # CHECK: [IntegerType(i3)]
            print(types(op.variadic2))
            # CHECK: i4
            print(op.non_variadic3.type)

            # With a single result type, the builder function of a
            # variadic-result op is expected to return a bare OpResult. The
            # op's `res` accessor still returns an OpResultList.
            assert (
                typing.get_type_hints(test.results_variadic)["return"]
                == Union[OpResult, OpResultList, test.ResultsVariadicOp]
            )
            assert type(test.results_variadic([i[0]])) is OpResult
            op_res_variadic = test.ResultsVariadicOp([i[0]])
            assert (
                typing.get_type_hints(test.ResultsVariadicOp.res.fget)["return"]
                is OpResultList
            )
            assert type(op_res_variadic.res) is OpResultList
# CHECK-LABEL: TEST: testVariadicAndNormalRegionOp
@run
def testVariadicAndNormalRegionOp():
    """Check region accessors of ``VariadicAndNormalRegionOp``.

    The single ``region`` accessor must be annotated as, and return, a
    ``Region``. The ``variadic`` accessor must be annotated as, and return, a
    ``RegionSequence``. The op and its underlying operation must both expose an
    ``OpView`` through ``.opview``.
    """
    with Context() as context, Location.unknown(context):
        mod = Module.create()
        with InsertionPoint(mod.body):
            # NOTE(review): the argument `2` presumably sets the number of
            # variadic regions -- confirm against the op's ODS definition.
            op = test.VariadicAndNormalRegionOp(2)
            op_cls = test.VariadicAndNormalRegionOp

            # (accessor name, class expected for both annotation and value)
            expectations = (("region", Region), ("variadic", RegionSequence))
            for accessor, expected in expectations:
                getter = getattr(op_cls, accessor).fget
                assert typing.get_type_hints(getter)["return"] is expected
                assert type(getattr(op, accessor)) is expected

            assert isinstance(op.opview, OpView)
            assert isinstance(op.operation.opview, OpView)
# Regression test for the dirty-error-state crash in the `from_python` type
# casters of `NanobindAdaptors.h` (#191764).
#
# !!! This only fails with a Python build that has assertions enabled (e.g. a
# !!! debug build); release builds silently skip the failing assertion.
#
# Uses an overloaded function: overload 1 takes `MlirOperation`, overload 2
# takes `MlirModule`. When called with an `ir.Module`:
#
# 1. `nanobind` tries overload 1 (`MlirOperation`). `from_python` gets the
#    `Module`'s `_CAPIPtr` capsule, then `mlirPythonCapsuleToOperation` calls
#    `PyCapsule_GetPointer` with `"mlir.ir.Operation._CAPIPtr"` — but the
#    capsule is named `"mlir.ir.Module._CAPIPtr"`. `PyCapsule_GetPointer`
#    returns `NULL` and sets the Python error indicator, so `PyErr_Occurred()`
#    is now true. `from_python` returns `false`.
#
# 2. `nanobind` tries overload 2 (`MlirModule`). `from_python` calls
#    `mlirApiObjectToCapsule` --> `nanobind::getattr(obj, "_CAPIPtr")` -->
#    `_PyType_LookupRef`.
#
# Without the fix:
#    `_PyType_LookupRef` asserts `!PyErr_Occurred()` --> `SIGABRT`.
#
# With the fix (`PyErr_Clear` in `from_python` after a failed capsule
# conversion):
#    Overload 2 succeeds and returns `"module"`.
# CHECK-LABEL: TEST: testOverloadWithWrongPythonCapsule
@run
def testOverloadWithWrongPythonCapsule():
    """Resolve an overload after an earlier overload's capsule conversion failed.

    The ``MlirOperation`` overload is tried first and rejects the ``Module``
    capsule. Resolution must then fall through cleanly to the ``MlirModule``
    overload instead of tripping over the stale Python error indicator.
    """
    with Context():
        module = Module.parse("module {}")
        # CHECK: result = module
        result = take_module_or_operation(module)
        print(f"result = {result}")