Files
llvm-project/mlir/test/python/dialects/transform_op_interface.py
Twice 2874480930 [MLIR][Python] Migrate result(infer_type=True) to a new field specifier (#191849)
Currently, the signature of `result(..)` is:
```python
result(*, infer_type: bool = False, default_factory: Callable[[], Any] | None = None, kw_only: bool = False) -> Result
```

so when users write `result(infer_type=True)`, type checkers will
still see `kw_only=False` (from the signature), even though `kw_only`
should actually be `True` (it should follow the value of `infer_type`).
Users can write `result(infer_type=True, kw_only=True)`, but that is
unnecessarily verbose.

So this may introduce an incompatibility once we start using
`dataclass_transform`. It is fine for now only because we don't use
`dataclass_transform` yet; once we do, fixing it would require a breaking
change.

This PR migrates such use to a new field specifier named
`infer_result()`.
2026-04-14 10:45:28 +08:00

495 lines
19 KiB
Python

# RUN: env PYTHONUNBUFFERED=1 %PYTHON %s 2>&1 | FileCheck %s
from typing import Sequence
from contextlib import contextmanager
from mlir import ir
from mlir.dialects import index, transform, func, arith, ext
from mlir.dialects.transform import (
DiagnosedSilenceableFailure,
AnyOpType,
AnyValueType,
AnyParamType,
structured,
interpreter,
)
class MyTransform(ext.Dialect, name="my_transform"):
    """Python-defined ``my_transform`` dialect that hosts the test ops below."""
def run(emit_schedule):
    """Execute one test: build the payload and schedule, then interpret it.

    Used as a decorator on each test function. It prints the test's name,
    builds the shared payload in a fresh context, loads the ``my_transform``
    dialect, and attaches the interface models of the two shared helper ops.
    It then calls ``emit_schedule`` and applies the first op of the returned
    module's body (the named sequence built by ``schedule_boilerplate``) to
    the payload.

    Nothing is returned, so each decorated test name ends up bound to ``None``.
    """
    print(f"Test: {emit_schedule.__name__}")
    with ir.Context() as context, ir.Location.unknown():
        payload_module = emit_payload()
        MyTransform.load(reload=True)
        # NB: Other newly defined my_transform ops have their interfaces attached
        # in their respective test functions; only these two are shared.
        for shared_op in (GetNamedAttributeOp, PrintParamOp):
            shared_op.attach_interface_impls(context)
        schedule_module = emit_schedule()
        top_level_block = schedule_module.operation.regions[0].blocks[0]
        entry_sequence = top_level_block.operations[0]
        interpreter.apply_named_sequence(
            payload_module,
            entry_sequence,
            schedule_module,
        )
# Payload used by all tests
def emit_payload():
    """Build the payload module shared by every test.

    The module holds one function, ``@name_of_func``, taking two f32 values
    and returning one f32. Its body contains, in order:
      - an ``arith.addf`` of the two arguments,
      - two i32 ``arith.constant`` ops (42, then 24),
      - a ``func.return`` of the sum.
    """
    module = ir.Module.create()
    with ir.InsertionPoint(module.body):
        f32_type = ir.F32Type.get()
        i32_type = ir.IntegerType.get_signless(32)

        # The Python function name becomes the symbol name; with explicit
        # `results=`, the body must emit its own terminator.
        @func.FuncOp.from_py_func(f32_type, f32_type, results=[f32_type])
        def name_of_func(lhs, rhs):
            total = arith.addf(lhs, rhs)
            for constant_value in (42, 24):
                arith.constant(i32_type, constant_value)
            func.ReturnOp([total])

    return module
@contextmanager
def schedule_boilerplate():
    """Create a schedule module containing an empty ``__transform_main`` sequence.

    The module is tagged with ``transform.with_named_sequence``. The sequence
    takes one ``AnyOpType`` argument, marked ``transform.consumed``, and
    declares one ``AnyOpType`` result.

    Yields ``(schedule, named_sequence)`` with the insertion point set to the
    sequence body. The caller emits the body ops, including the terminator.
    """
    schedule_module = ir.Module.create()
    schedule_module.operation.attributes["transform.with_named_sequence"] = (
        ir.UnitAttr.get()
    )
    input_types = [AnyOpType.get()]
    result_types = [AnyOpType.get()]
    with ir.InsertionPoint(schedule_module.body):
        entry_sequence = transform.NamedSequenceOp(
            "__transform_main",
            input_types,
            result_types,
            arg_attrs=[{"transform.consumed": ir.UnitAttr.get()}],
        )
        with ir.InsertionPoint(entry_sequence.body):
            yield schedule_module, entry_sequence
# MemoryEffectsOpInterface implementation for TransformOpInterface-implementing ops.
# Used by most ops defined below.
class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
    """Read-only memory-effects model for transform ops.

    Operand handles are only read, result handles are produced, and the
    payload IR is declared as only read.
    """

    @staticmethod
    def get_effects(op: ir.Operation, effects):
        operand_handles = op.op_operands
        result_handles = op.results
        transform.only_reads_handle(operand_handles, effects)
        transform.produces_handle(result_handles, effects)
        transform.only_reads_payload(effects)
# Demonstration of a TransformOpInterface-implementing op that gets named attributes
# from target ops and produces them as param handles.
# Fields: operand handle `target`, string attribute `attr_name` (the attribute
# to look up), and the param result `attr_as_param`, whose type is inferred via
# `ext.infer_result()`.
class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
    target: ext.Operand[transform.AnyOpType]
    attr_name: ir.StringAttr
    attr_as_param: ext.Result[transform.AnyParamType[()]] = ext.infer_result()

    @classmethod
    def attach_interface_impls(cls, ctx=None):
        """Attach this op's TransformOpInterface and MemoryEffectsOpInterface models.

        ``ctx`` is forwarded as the ``context`` argument of each ``attach`` call.
        """
        interface_models = (
            cls.TransformOpInterfaceFallbackModel,
            MemoryEffectsOpInterfaceFallbackModel,
        )
        for interface_model in interface_models:
            interface_model.attach(cls.OPERATION_NAME, context=ctx)

    class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
        @staticmethod
        def apply(
            op: "GetNamedAttributeOp",
            _rewriter: transform.TransformRewriter,
            results: transform.TransformResults,
            state: transform.TransformState,
        ) -> DiagnosedSilenceableFailure:
            """Map each payload op of ``op.target`` to its ``attr_name`` attribute.

            The collected attributes are set as params on ``op.attr_as_param``.
            If any payload op lacks the attribute, the method returns
            ``RecoverableFailure`` before setting any result.
            """
            payload_ops = state.get_payload_ops(op.target)
            lookup_key = op.attr_name.value
            collected = []
            for payload_op in payload_ops:
                found = payload_op.attributes.get(lookup_key)
                if found is None:
                    return DiagnosedSilenceableFailure.RecoverableFailure
                collected.append(found)
            results.set_params(op.attr_as_param, collected)
            return DiagnosedSilenceableFailure.Success

        @staticmethod
        def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
            return False
# Demonstration op that prints the attributes held by a param handle. The
# output starts with a header line built from the op's `name` string attribute.
# The op has no results.
class PrintParamOp(MyTransform.Operation, name="print_param"):
    target: ext.Operand[transform.AnyParamType]
    name: ir.StringAttr

    @classmethod
    def attach_interface_impls(cls, ctx=None):
        """Attach this op's TransformOpInterface and MemoryEffectsOpInterface models.

        ``ctx`` is forwarded as the ``context`` argument of each ``attach`` call.
        """
        cls.TransformOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, context=ctx)
        MemoryEffectsOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, context=ctx)

    class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
        @staticmethod
        def apply(
            op: "PrintParamOp",
            rewriter: transform.TransformRewriter,
            results: transform.TransformResults,
            state: transform.TransformState,
        ) -> DiagnosedSilenceableFailure:
            """Print a header line, then each param attribute bound to ``op.target``.

            ``rewriter`` and ``results`` are unused. The op declares no
            results, and its attached memory-effects model only reads the
            payload.
            """
            target_attrs = state.get_params(op.target)
            print(f"[[[ IR printer: {op.name.value} ]]]")
            for attr in target_attrs:
                print(attr)
            return DiagnosedSilenceableFailure.Success

        @staticmethod
        def allow_repeated_handle_operands(_op: "PrintParamOp") -> bool:
            return False
# Syntax for an op with one op handle operand and one op handle result.
# The result type is inferred (`ext.infer_result()`), so the op is built from
# its operand alone, e.g. `OneOpInOneOpOut(func_handle)`. No interface models
# are attached here: each test that uses this op attaches its own
# TransformOpInterface (and, where needed, MemoryEffectsOpInterface) model.
class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
    target: ext.Operand[transform.AnyOpType]
    res: ext.Result[transform.AnyOpType[()]] = ext.infer_result()
# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
@run
def OneOpInOneOpOutTransformOpInterface():
    """Tests a simple passthrough interface implementation.

    Checks that the target ops are correctly identified and passed as results:
    the op handle produced by OneOpInOneOpOut maps to the same payload ops
    (here, the single `func.func`) as its operand.
    """

    # Define a simple passthrough implementation of the TransformOpInterface for OneOpInOneOpOut.
    class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
        @staticmethod
        def apply(
            op: OneOpInOneOpOut,
            _rewriter: transform.TransformRewriter,
            results: transform.TransformResults,
            state: transform.TransformState,
        ) -> DiagnosedSilenceableFailure:
            # Print the `name` attribute of each payload op (the symbol name of
            # the matched function), then forward the payload ops unchanged.
            target_ops = state.get_payload_ops(op.target)
            target_names = [t.name.value for t in target_ops]
            print(f"OneOpInOneOpOutTransformOpInterface: target_names={target_names}")
            results.set_ops(op.res, target_ops)
            return DiagnosedSilenceableFailure.Success

        @staticmethod
        def allow_repeated_handle_operands(_op: OneOpInOneOpOut) -> bool:
            return False

    # Attach the interface implementation to the op.
    TransformOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
    # TransformOpInterface-implementing ops are also required to implement
    # MemoryEffectsOpInterface. The module-level read-only fallback model works
    # for this op, since it does not modify the payload.
    MemoryEffectsOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)

    with schedule_boilerplate() as (schedule, named_seq):
        func_handle = structured.MatchOp.match_op_names(
            named_seq.bodyTarget, ["func.func"]
        ).result

        # The interface's `apply` runs (and prints) when the schedule is interpreted.
        # CHECK: OneOpInOneOpOutTransformOpInterface: target_names=['name_of_func']
        out = OneOpInOneOpOut(func_handle).result

        # CHECK: Output handle from OneOpInOneOpOut
        # CHECK-NEXT: func.func @name_of_func
        transform.PrintOp(target=out, name="Output handle from OneOpInOneOpOut")

        transform.YieldOp([out])
    return schedule
# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterfaceRewriterImpl
@run
def OneOpInOneOpOutTransformOpInterfaceRewriterImpl():
    """Tests an interface implementation using the rewriter to modify the IR.

    Checks that `arith.constant` ops are replaced by `index.constant` ops and
    that the results are correctly updated.
    """

    class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
        @staticmethod
        def apply(
            op: OneOpInOneOpOut,
            rewriter: transform.TransformRewriter,
            results: transform.TransformResults,
            state: transform.TransformState,
        ) -> DiagnosedSilenceableFailure:
            # For each payload op (expected to carry a `value` attribute, like the
            # matched `arith.constant` ops), emit an `index.constant` with the same
            # integer value at the payload op's position. Then replace the original
            # op through the rewriter, and return the new ops on the result handle.
            result_ops = []
            for target_op in state.get_payload_ops(op.target):
                with ir.InsertionPoint(target_op):
                    index_version = index.constant(target_op.value.value)
                    result_ops.append(index_version.owner)
                rewriter.replace_op(target_op, [index_version])
            results.set_ops(op.res, result_ops)
            return DiagnosedSilenceableFailure.Success

        @staticmethod
        def allow_repeated_handle_operands(_op: OneOpInOneOpOut) -> bool:
            return False

    # Attach the interface implementation to the op.
    TransformOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)

    # TransformOpInterface-implementing ops are also required to implement
    # MemoryEffectsOpInterface. The module-level read-only fallback model does
    # NOT fit here, because this implementation rewrites the payload. This local
    # model (which shadows the module-level class of the same name inside this
    # function) consumes the operand handle and declares payload modification.
    class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
        @staticmethod
        def get_effects(op: ir.Operation, effects):
            transform.consumes_handle(op.op_operands, effects)
            transform.produces_handle(op.results, effects)
            transform.modifies_payload(effects)

    MemoryEffectsOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)

    with schedule_boilerplate() as (schedule, named_seq):
        func_handle = structured.MatchOp.match_op_names(
            named_seq.bodyTarget, ["func.func"]
        ).result
        csts_handle = structured.MatchOp.match_op_names(
            named_seq.bodyTarget, ["arith.constant"]
        ).result

        # CHECK: Before replacement:
        # CHECK-NOT: index.constant
        # CHECK-DAG: arith.constant 42 : i32
        # CHECK-DAG: arith.constant 24 : i32
        transform.PrintOp(target=func_handle, name="Before replacement:")

        # Consumes `csts_handle` (see the local memory-effects model above).
        out = OneOpInOneOpOut(csts_handle).result

        # CHECK: After replacement:
        # CHECK-NOT: arith.constant
        # CHECK-DAG: index.constant 42
        # CHECK-DAG: index.constant 24
        transform.PrintOp(target=func_handle, name="After replacement:")

        # CHECK: Output handle from OneOpInOneOpOut:
        # CHECK-NEXT: index.constant 42
        # CHECK-NEXT: index.constant 24
        transform.PrintOp(target=out, name="Output handle from OneOpInOneOpOut:")

        transform.YieldOp([out])
    return schedule
# Op covering all three transform handle kinds. It takes an op handle, a
# value handle and a param handle, and yields a param, an op and a value
# handle. All result types are inferred via `ext.infer_result()`, so the op is
# constructed from its three operands only.
class OpValParamInParamOpValOut(
    MyTransform.Operation, name="op_val_param_in_param_op_val_out"
):
    # operands
    op_arg: ext.Operand[transform.AnyOpType]
    val_arg: ext.Operand[transform.AnyValueType]
    param_arg: ext.Operand[transform.AnyParamType]
    # results
    param_res: ext.Result[transform.AnyParamType[()]] = ext.infer_result()
    op_res: ext.Result[transform.AnyOpType[()]] = ext.infer_result()
    value_res: ext.Result[transform.AnyValueType[()]] = ext.infer_result()
# CHECK-LABEL: Test: OpValParamInParamOpValOutTransformOpInterface
@run
def OpValParamInParamOpValOutTransformOpInterface():
    """Tests an interface implementation involving Op, Value, and Param types.

    Checks that payload ops, values, and parameters are correctly permuted and
    propagated and accessible from the (permuted) result handles.
    """

    class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
        @staticmethod
        def apply(
            op: OpValParamInParamOpValOut,
            _rewriter: transform.TransformRewriter,
            results: transform.TransformResults,
            state: transform.TransformState,
        ) -> DiagnosedSilenceableFailure:
            # Read what each operand handle maps to, report the counts, and
            # forward each list unchanged to the result handle of the same kind.
            ops = state.get_payload_ops(op.op_arg)
            values = state.get_payload_values(op.val_arg)
            params = state.get_params(op.param_arg)
            print(
                f"OpValParamInParamOpValOutTransformOpInterface: ops={len(ops)}, values={len(values)}, params={len(params)}"
            )
            results.set_params(op.param_res, params)
            results.set_ops(op.op_res, ops)
            results.set_values(op.value_res, values)
            return DiagnosedSilenceableFailure.Success

        @staticmethod
        def allow_repeated_handle_operands(_op: OpValParamInParamOpValOut) -> bool:
            return False

    TransformOpInterfaceFallbackModel.attach(OpValParamInParamOpValOut.OPERATION_NAME)
    # TransformOpInterface-implementing ops are also required to implement
    # MemoryEffectsOpInterface. The module-level read-only fallback model works
    # for this op, since it only forwards handles.
    MemoryEffectsOpInterfaceFallbackModel.attach(
        OpValParamInParamOpValOut.OPERATION_NAME
    )

    with schedule_boilerplate() as (schedule, named_seq):
        func_handle = structured.MatchOp.match_op_names(
            named_seq.bodyTarget, ["func.func"]
        ).result
        addf_handle = structured.MatchOp.match_op_names(
            named_seq.bodyTarget, ["arith.addf"]
        ).result
        # One op handle covering both the function and the addf (2 payload ops).
        func_and_addf = transform.MergeHandlesOp([func_handle, addf_handle])
        # Value handle to result #0 of the addf.
        value_handle = transform.GetResultOp(
            AnyValueType.get(), addf_handle, [0]
        ).result
        # Param handle holding a single i32 attribute with value 42.
        param_handle = transform.ParamConstantOp(
            AnyParamType.get(), ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
        ).param

        # CHECK: OpValParamInParamOpValOutTransformOpInterface: ops=2, values=1, params=1
        op_val_param_op = OpValParamInParamOpValOut(
            func_and_addf, value_handle, param_handle
        )

        # CHECK: Ops passed through OpValParamInParamOpValOut:
        # CHECK-NEXT: func.func
        # CHECK: arith.addf
        transform.PrintOp(
            target=op_val_param_op.op_res,
            name="Ops passed through OpValParamInParamOpValOut:",
        )

        # CHECK: Ops defining values passed through OpValParamInParamOpValOut:
        # CHECK-NEXT: arith.addf
        addf_as_res = transform.GetDefiningOp(
            transform.AnyOpType.get(), op_val_param_op.value_res
        ).result
        transform.PrintOp(
            target=addf_as_res,
            name="Ops defining values passed through OpValParamInParamOpValOut:",
        )

        # CHECK: Parameter passed through OpValParamInParamOpValOut:
        # CHECK-NEXT: 42 : i32
        PrintParamOp(
            op_val_param_op.param_res,
            name=ir.StringAttr.get(
                "Parameter passed through OpValParamInParamOpValOut:"
            ),
        )

        transform.YieldOp([op_val_param_op.op_res])
        named_seq.verify()
    return schedule
# Op with variadic operands and results (the `Sequence[...]` fields). Results
# are declared before operands, and the constructor takes arguments in that
# order. For example, OpsParamsInValuesParamOutTransformOpInterface passes:
#   1. the list of value result types,
#   2. the param result type (not inferred: there is no `ext.infer_result()`),
#   3. the list of op-handle operands,
#   4. the list of param-handle operands.
class OpsParamsInValuesParamOut(
    MyTransform.Operation, name="ops_params_in_values_param_out"
):
    # results
    values: Sequence[ext.Result[transform.AnyValueType]]
    param: ext.Result[transform.AnyParamType]
    # operands
    ops: Sequence[ext.Operand[transform.AnyOpType]]
    params: Sequence[ext.Operand[transform.AnyParamType]]
# CHECK-LABEL: Test: OpsParamsInValuesParamOutTransformOpInterface
@run
def OpsParamsInValuesParamOutTransformOpInterface():
    """Tests an interface with variadic Op and Param operands and variadic Value results.

    Checks correct handling of multiple handles, parameter aggregation, and
    result generation:
      - for each op-handle operand, the results of its payload ops are returned
        on the value-result handle at the same position;
      - the integer params of all param-handle operands are summed into a
        single i32 param result.
    """

    class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
        @staticmethod
        def apply(
            op: OpsParamsInValuesParamOut,
            _rewriter: transform.TransformRewriter,
            results: transform.TransformResults,
            state: transform.TransformState,
        ) -> DiagnosedSilenceableFailure:
            # Count payload ops, and collect the flattened results of each
            # operand's payload ops. Payload-side names deliberately differ from
            # `op` so they never shadow the transform op being applied.
            ops_count = 0
            values_per_handle = []
            for op_handle in op.ops:
                payload_ops = state.get_payload_ops(op_handle)
                ops_count += len(payload_ops)
                values_per_handle.append(
                    [res for payload_op in payload_ops for res in payload_op.results]
                )

            # Count params across all param-handle operands and sum their values.
            param_count = 0
            param_sum = 0
            for param_handle in op.params:
                params = state.get_params(param_handle)
                param_count += len(params)
                param_sum += sum(p.value for p in params)
            print(
                f"OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel: op_count={ops_count}, param_count={param_count}"
            )

            # One value result per op-handle operand; zip() would silently
            # truncate on a mismatch, so check the pairing first.
            assert len(op.values) == len(op.ops)
            for value_res_handle, value_vector in zip(op.values, values_per_handle):
                results.set_values(value_res_handle, value_vector)
            results.set_params(
                op.param,
                [ir.IntegerAttr.get(ir.IntegerType.get_signless(32), param_sum)],
            )
            return DiagnosedSilenceableFailure.Success

        @staticmethod
        def allow_repeated_handle_operands(_op: OpsParamsInValuesParamOut) -> bool:
            return False

    TransformOpInterfaceFallbackModel.attach(OpsParamsInValuesParamOut.OPERATION_NAME)
    # TransformOpInterface-implementing ops must also implement
    # MemoryEffectsOpInterface; the module-level read-only model suffices here
    # because this op does not modify the payload.
    MemoryEffectsOpInterfaceFallbackModel.attach(
        OpsParamsInValuesParamOut.OPERATION_NAME
    )

    with schedule_boilerplate() as (schedule, named_seq):
        func_handle = structured.MatchOp.match_op_names(
            named_seq.bodyTarget, ["func.func"]
        ).result
        csts_handle = structured.MatchOp.match_op_names(
            named_seq.bodyTarget, ["arith.constant"]
        ).result
        # Params holding the `value` attribute of each constant (42 and 24).
        csts_as_param = GetNamedAttributeOp(
            csts_handle, attr_name=ir.StringAttr.get("value")
        ).attr_as_param
        param_handle = transform.ParamConstantOp(
            AnyParamType.get(), ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 123)
        ).param

        # CHECK: OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel: op_count=3, param_count=3
        op = OpsParamsInValuesParamOut(
            [transform.AnyValueType.get()] * 2,
            transform.AnyParamType.get(),
            [func_handle, csts_handle],
            [csts_as_param, param_handle],
        )

        empty_handle = transform.GetDefiningOp(transform.AnyOpType.get(), op.values[0])
        # CHECK: Defining op of value result 0
        transform.PrintOp(
            target=empty_handle.result, name="Defining op of value result 0"
        )
        # NB: no result on the func.func, so output is expected to be empty

        cst1_res, cst2_res = transform.SplitHandleOp(
            [transform.AnyValueType.get()] * 2, op.values[1]
        ).results
        cst1_again = transform.GetDefiningOp(transform.AnyOpType.get(), cst1_res)
        # CHECK-NEXT: Defining op of first constant
        # CHECK-NEXT: arith.constant 42 : i32
        transform.PrintOp(
            target=cst1_again.result, name="Defining op of first constant"
        )
        cst2_again = transform.GetDefiningOp(transform.AnyOpType.get(), cst2_res)
        # CHECK-NEXT: Defining op of second constant
        # CHECK-NEXT: arith.constant 24 : i32
        transform.PrintOp(
            target=cst2_again.result, name="Defining op of second constant"
        )

        # 42 + 24 + 123
        # CHECK: Sum of params:
        # CHECK-NEXT: 189 : i32
        PrintParamOp(op.param, name=ir.StringAttr.get("Sum of params:"))

        transform.YieldOp([func_handle])
        named_seq.verify()
    return schedule