Files
Rolf Morel f12fcf030c [MLIR][Transform][Python] transform.foreach wrapper and .owner OpViews (#172228)
Friendlier wrapper for transform.foreach.

To facilitate that friendliness, makes it so that OpResult.owner returns
the relevant OpView instead of Operation. For good measure, also changes
Value.owner to return OpView instead of Operation, thereby ensuring
consistency. That is, makes it so that all op-returning .owner
accessors return OpView (and thereby give access to all goodies
available on registered OpViews).

Reland of #171544, with a fixup for the integration test.
2025-12-14 22:10:31 +00:00

456 lines
18 KiB
Python

# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import pdl as transform_pdl
def run(f):
    """Run the test body ``f`` against a fresh module and print the result.

    Intended as a decorator: ``f`` is called immediately, at decoration time,
    inside a new ``Context`` with ``Location.unknown()`` as the default
    location. ``f`` receives an empty ``Module`` and executes with the
    insertion point at that module's body. A ``TEST: <name>`` banner is
    printed before ``f`` runs and the module is printed afterwards; this
    stdout is what the RUN line pipes into FileCheck. ``f`` is returned
    unchanged, so the decorated name still refers to the test function.
    """
    with Context(), Location.unknown():
        test_module = Module.create()
        with InsertionPoint(test_module.body):
            print("\nTEST:", f.__name__)
            f(test_module)
        # Printed after the insertion point is popped but while the context
        # is still active.
        print(test_module)
    return f
@run
def testTypes(module: Module):
    """Print the transform dialect's handle/param types and the payload
    exposed by the parameterized ones (operation name, element type)."""
    # CHECK-LABEL: TEST: testTypes
    # CHECK: !transform.any_op
    # CHECK: !transform.any_param
    # CHECK: !transform.any_value
    opaque_types = (
        transform.AnyOpType.get(),
        transform.AnyParamType.get(),
        transform.AnyValueType.get(),
    )
    for opaque_type in opaque_types:
        print(opaque_type)
    # CHECK: !transform.op<"foo.bar">
    # CHECK: foo.bar
    foo_bar_t = transform.OperationType.get("foo.bar")
    print(foo_bar_t)
    print(foo_bar_t.operation_name)
    # CHECK: !transform.param<i32>
    # CHECK: i32
    i32_param_t = transform.ParamType.get(IntegerType.get_signless(32))
    print(i32_param_t)
    print(i32_param_t.type)
@run
def testSequenceOp(module: Module):
    """Chain two casts inside a sequence and yield the last one.

    The first cast is built with the ``CastOp`` class, the second with the
    ``cast`` builder function and the ``any_op_t()`` type helper.
    """
    handle_t = transform.AnyOpType.get()
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [handle_t],
        handle_t,
    )
    with InsertionPoint(seq.body):
        first_cast = transform.CastOp(handle_t, seq.bodyTarget)
        second_cast = transform.cast(transform.any_op_t(), first_cast.result)
        transform.YieldOp([second_cast])
    # CHECK-LABEL: TEST: testSequenceOp
    # CHECK: transform.sequence
    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
    # CHECK: %[[RES:.+]] = cast %[[ARG0]] : !transform.any_op to !transform.any_op
    # CHECK: %[[RES2:.+]] = cast %[[RES]] : !transform.any_op to !transform.any_op
    # CHECK: yield %[[RES2]] : !transform.any_op
    # CHECK: }
@run
def testSequenceOp(module: Module):
    """Yield a sequence's block argument straight back as its result."""
    # NOTE: this reuses the name of an earlier test, so the module-level name
    # is rebound. Both bodies still execute, because ``run`` calls each
    # function at decoration time.
    handle_t = transform.AnyOpType.get()
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [handle_t], handle_t
    )
    with InsertionPoint(seq.body):
        transform.YieldOp([seq.bodyTarget])
    # CHECK-LABEL: TEST: testSequenceOp
    # CHECK: = transform.sequence -> !transform.any_op failures(propagate) {
    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
    # CHECK: yield %[[ARG0]] : !transform.any_op
    # CHECK: }
@run
def testNestedSequenceOp(module: Module):
    """Nest three sequences, each rooted at its parent's block argument;
    only the innermost one produces a result."""
    propagate = transform.FailurePropagationMode.Propagate
    outer = transform.SequenceOp(propagate, [], transform.AnyOpType.get())
    with InsertionPoint(outer.body):
        middle = transform.SequenceOp(propagate, [], outer.bodyTarget)
        with InsertionPoint(middle.body):
            inner = transform.SequenceOp(
                propagate,
                [transform.AnyOpType.get()],
                middle.bodyTarget,
            )
            with InsertionPoint(inner.body):
                transform.YieldOp([inner.bodyTarget])
            transform.YieldOp()
        transform.YieldOp()
    # CHECK-LABEL: TEST: testNestedSequenceOp
    # CHECK: transform.sequence failures(propagate) {
    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
    # CHECK: sequence %[[ARG0]] : !transform.any_op failures(propagate) {
    # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
    # CHECK: = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
    # CHECK: ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
    # CHECK: yield %[[ARG2]] : !transform.any_op
    # CHECK: }
    # CHECK: }
    # CHECK: }
@run
def testSequenceOpWithExtras(module: Module):
    """Give a sequence extra typed block arguments, first through the
    ``SequenceOp`` class and then through the ``sequence`` builder."""
    propagate = transform.FailurePropagationMode.Propagate

    from_class = transform.SequenceOp(
        propagate,
        [],
        transform.AnyOpType.get(),
        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
    )
    with InsertionPoint(from_class.body):
        transform.YieldOp()
    # CHECK-LABEL: TEST: testSequenceOpWithExtras
    # CHECK: transform.sequence failures(propagate)
    # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):

    from_builder = transform.sequence(
        propagate,
        [],
        transform.AnyOpType.get(),
        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
    )
    with InsertionPoint(from_builder.body):
        transform.yield_()
    # CHECK: transform.sequence failures(propagate)
    # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
@run
def testNestedSequenceOpWithExtras(module: Module):
    """Pass an outer sequence's root and extra block arguments on to a
    nested sequence as its operands."""
    extra_types = [
        transform.AnyOpType.get(),
        transform.OperationType.get("foo.bar"),
    ]
    outer = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.AnyOpType.get(),
        extra_types,
    )
    with InsertionPoint(outer.body):
        inner = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate,
            [],
            outer.bodyTarget,
            outer.bodyExtraArgs,
        )
        with InsertionPoint(inner.body):
            transform.YieldOp()
        transform.YieldOp()
    # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
    # CHECK: transform.sequence failures(propagate)
    # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
    # CHECK: sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
@run
def testTransformPDLOps(module: Module):
    """Inside ``with_pdl_patterns``, run a sequence that applies the
    ``pdl_matcher`` pattern to its root and yields the matched handle."""
    handle_t = transform.AnyOpType.get()
    pdl_scope = transform_pdl.WithPDLPatternsOp(handle_t)
    with InsertionPoint(pdl_scope.body):
        seq = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate,
            [handle_t],
            pdl_scope.bodyTarget,
        )
        with InsertionPoint(seq.body):
            matched = transform_pdl.PDLMatchOp(
                handle_t, seq.bodyTarget, "pdl_matcher"
            )
            transform.YieldOp(matched)
    # CHECK-LABEL: TEST: testTransformPDLOps
    # CHECK: transform.with_pdl_patterns {
    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
    # CHECK: = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
    # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
    # CHECK: %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
    # CHECK: yield %[[RES]] : !transform.any_op
    # CHECK: }
    # CHECK: }
@run
def testNamedSequenceOp(module: Module):
    """Build two named sequences whose single argument is marked
    ``transform.consumed`` and returned as-is: one through the
    ``NamedSequenceOp`` class, one through the ``named_sequence`` builder."""
    module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()

    main_seq = transform.NamedSequenceOp(
        "__transform_main",
        [transform.AnyOpType.get()],
        [transform.AnyOpType.get()],
        arg_attrs=[{"transform.consumed": UnitAttr.get()}],
    )
    with InsertionPoint(main_seq.body):
        transform.YieldOp([main_seq.bodyTarget])
    # CHECK-LABEL: TEST: testNamedSequenceOp
    # CHECK: module attributes {transform.with_named_sequence} {
    # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op {
    # CHECK: yield %[[ARG0]] : !transform.any_op

    other_seq = transform.named_sequence(
        "other_seq",
        [transform.AnyOpType.get()],
        [transform.AnyOpType.get()],
        arg_attrs=[{"transform.consumed": UnitAttr.get()}],
    )
    with InsertionPoint(other_seq.body):
        transform.yield_([other_seq.bodyTarget])
    # CHECK: transform.named_sequence @other_seq(%[[ARG1:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op {
    # CHECK: yield %[[ARG1]] : !transform.any_op
@run
def testGetParentOp(module: Module):
    """Emit ``get_parent_op`` twice on the sequence root: with only the
    isolation/nth-parent options (op class), then with every option set
    (builder function)."""
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(seq.body):
        root = seq.bodyTarget
        transform.GetParentOp(
            transform.AnyOpType.get(),
            root,
            isolated_from_above=True,
            nth_parent=2,
        )
        transform.get_parent_op(
            transform.AnyOpType.get(),
            root,
            isolated_from_above=True,
            nth_parent=2,
            allow_empty_results=True,
            op_name="func.func",
            deduplicate=True,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testGetParentOp
    # CHECK: transform.sequence
    # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
    # CHECK: = get_parent_op %[[ARG1]] {isolated_from_above, nth_parent = 2 : i64}
    # CHECK: = get_parent_op %[[ARG1]] {allow_empty_results, deduplicate, isolated_from_above, nth_parent = 2 : i64, op_name = "func.func"}
@run
def testMergeHandlesOp(module: Module):
    """Merge the sequence root through the op class, then merge that result
    again with ``deduplicate`` through the builder function."""
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(seq.body):
        merged = transform.MergeHandlesOp([seq.bodyTarget]).result
        transform.merge_handles([merged], deduplicate=True)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMergeHandlesOp
    # CHECK: transform.sequence
    # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
    # CHECK: %[[RES1:.+]] = merge_handles %[[ARG1]] : !transform.any_op
    # CHECK: = merge_handles deduplicate %[[RES1]] : !transform.any_op
@run
def testApplyPatternsOpCompact(module: Module):
    """Populate two ``apply_patterns`` regions with canonicalization
    patterns: one with default settings (op class), one with CSE and
    iteration/rewrite limits (builder function)."""
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(seq.body):
        default_apply = transform.ApplyPatternsOp(seq.bodyTarget)
        with InsertionPoint(default_apply.patterns):
            transform.ApplyCanonicalizationPatternsOp()
        tuned_apply = transform.apply_patterns(
            seq.bodyTarget,
            apply_cse=True,
            max_iterations=3,
            max_num_rewrites=5,
        )
        with InsertionPoint(tuned_apply.patterns):
            transform.ApplyCanonicalizationPatternsOp()
        transform.YieldOp()
    # CHECK-LABEL: TEST: testApplyPatternsOpCompact
    # CHECK: apply_patterns to
    # CHECK: transform.apply_patterns.canonicalization
    # CHECK: } : !transform.any_op
    # CHECK: apply_patterns to
    # CHECK: transform.apply_patterns.canonicalization
    # CHECK: } {apply_cse, max_iterations = 3 : i64, max_num_rewrites = 5 : i64} : !transform.any_op
@run
def testApplyPatternsOpWithType(module: Module):
    """Apply canonicalization patterns to a root handle whose type is the
    concrete ``!transform.op<"test.dummy">`` rather than ``any_op``."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("test.dummy"),
    )
    with InsertionPoint(sequence.body):
        with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
            transform.ApplyCanonicalizationPatternsOp()
        transform.YieldOp()
    # The label spells out the full test name, as every other test here does.
    # The shorter "testApplyPatternsOp" is also a prefix of the banner printed
    # for testApplyPatternsOpCompact, so it did not identify this test.
    # CHECK-LABEL: TEST: testApplyPatternsOpWithType
    # CHECK: apply_patterns to
    # CHECK: transform.apply_patterns.canonicalization
    # CHECK: !transform.op<"test.dummy">
@run
def testReplicateOp(module: Module):
    """Match two handles with PDL, then emit ``replicate`` with the first as
    the ``num`` operand and the second as the replicated handle (op class,
    then builder function)."""
    handle_t = transform.AnyOpType.get()
    pdl_scope = transform_pdl.WithPDLPatternsOp(handle_t)
    with InsertionPoint(pdl_scope.body):
        seq = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate, [], pdl_scope.bodyTarget
        )
        with InsertionPoint(seq.body):
            first_match = transform_pdl.PDLMatchOp(handle_t, seq.bodyTarget, "first")
            second_match = transform_pdl.PDLMatchOp(
                handle_t, seq.bodyTarget, "second"
            )
            transform.ReplicateOp(first_match, [second_match])
            transform.replicate(first_match, [second_match])
            transform.YieldOp()
    # CHECK-LABEL: TEST: testReplicateOp
    # CHECK: %[[FIRST:.+]] = pdl_match
    # CHECK: %[[SECOND:.+]] = pdl_match
    # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
    # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
# CHECK-LABEL: TEST: testApplyRegisteredPassOp
@run
def testApplyRegisteredPassOp(module: Module):
    """Exercise how ``apply_registered_pass`` encodes pass options.

    Each call feeds the previous call's result back in as the target. The
    cases covered are:
      * no options (op class);
      * an attribute-valued option (op class);
      * a mix of attribute, plain Python bool, and param-handle option values
        (builder function). Param handles become extra operands of the op;
      * list-valued options given as Python tuples, first of plain strings
        and then of param handles.
    """
    # CHECK: transform.sequence
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(sequence.body):
        # CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
        mod = transform.ApplyRegisteredPassOp(
            transform.AnyOpType.get(), sequence.bodyTarget, "canonicalize"
        )
        # `mod` is an op-class instance here, so its result is taken via
        # `.result`. Later calls pass the builder function's return value
        # directly.
        # CHECK: %{{.*}} = apply_registered_pass "canonicalize"
        # CHECK-SAME: with options = {"top-down" = false}
        # CHECK-SAME: to {{.*}} : (!transform.any_op) -> !transform.any_op
        mod = transform.ApplyRegisteredPassOp(
            transform.AnyOpType.get(),
            mod.result,
            "canonicalize",
            options={"top-down": BoolAttr.get(False)},
        )
        # Param handles used below as option values; they show up as extra
        # `!transform.any_param` operands in the printed op type.
        # CHECK: %[[MAX_ITER:.+]] = transform.param.constant
        max_iter = transform.param_constant(
            transform.AnyParamType.get(),
            IntegerAttr.get(IntegerType.get_signless(64), 10),
        )
        # CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant
        max_rewrites = transform.param_constant(
            transform.AnyParamType.get(),
            IntegerAttr.get(IntegerType.get_signless(64), 1),
        )
        # Mixed option values: an attribute, a plain Python bool (printed as
        # `true`), and two param handles.
        # CHECK: %{{.*}} = apply_registered_pass "canonicalize"
        # NB: MLIR has sorted the dict lexicographically by key:
        # CHECK-SAME: with options = {"max-iterations" = %[[MAX_ITER]],
        # CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]],
        # CHECK-SAME: "test-convergence" = true,
        # CHECK-SAME: "top-down" = false}
        # CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
        mod = transform.apply_registered_pass(
            transform.AnyOpType.get(),
            mod,
            "canonicalize",
            options={
                "top-down": BoolAttr.get(False),
                "max-iterations": max_iter,
                "test-convergence": True,
                "max-rewrites": max_rewrites,
            },
        )
        # A Python tuple of strings is printed as an array option value.
        # CHECK: %{{.*}} = apply_registered_pass "symbol-privatize"
        # CHECK-SAME: with options = {"exclude" = ["a", "b"]}
        # CHECK-SAME: to %{{.*}} : (!transform.any_op) -> !transform.any_op
        mod = transform.apply_registered_pass(
            transform.AnyOpType.get(),
            mod,
            "symbol-privatize",
            options={"exclude": ("a", "b")},
        )
        # Same array option, but with elements supplied as param handles.
        # CHECK: %[[SYMBOL_A:.+]] = transform.param.constant
        symbol_a = transform.param_constant(
            transform.AnyParamType.get(), StringAttr.get("a")
        )
        # CHECK: %[[SYMBOL_B:.+]] = transform.param.constant
        symbol_b = transform.param_constant(
            transform.AnyParamType.get(), StringAttr.get("b")
        )
        # CHECK: %{{.*}} = apply_registered_pass "symbol-privatize"
        # CHECK-SAME: with options = {"exclude" = [%[[SYMBOL_A]], %[[SYMBOL_B]]]}
        # CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
        mod = transform.apply_registered_pass(
            transform.AnyOpType.get(),
            mod,
            "symbol-privatize",
            options={"exclude": (symbol_a, symbol_b)},
        )
        transform.YieldOp()
# CHECK-LABEL: TEST: testForeachOp
@run
def testForeachOp(module: Module):
    """Build ``foreach`` ops through the op class and the builder function.

    The three cases are:
      * a single op-handle target, built with the op class;
      * mixed op/value/param targets with two results, built with the
        builder function;
      * ``with_zip_shortest=True``, built with the builder function.
    The builder function returns the op's results (indexable when there are
    several, a single value otherwise) rather than the op itself. Its body
    region is therefore reached through ``.owner``.
    """
    # CHECK: transform.sequence
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [transform.AnyOpType.get()],
        transform.AnyOpType.get(),
    )
    with InsertionPoint(sequence.body):
        # CHECK: {{.*}} = foreach %{{.*}} : !transform.any_op -> !transform.any_op
        foreach1 = transform.ForeachOp(
            (transform.AnyOpType.get(),), (sequence.bodyTarget,)
        )
        with InsertionPoint(foreach1.body):
            # The body's block arguments are yielded back unchanged.
            # CHECK: transform.yield {{.*}} : !transform.any_op
            transform.yield_(foreach1.bodyTargets)
        # Back at sequence level: derive a value handle and a param handle to
        # use as extra foreach targets.
        a_val = transform.get_operand(
            transform.AnyValueType.get(), foreach1.result, [0]
        )
        a_param = transform.param_constant(
            transform.AnyParamType.get(), StringAttr.get("a_param")
        )
        # CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} : !transform.any_op, !transform.any_value, !transform.any_param -> !transform.any_value, !transform.any_param
        foreach2 = transform.foreach(
            (transform.AnyValueType.get(), transform.AnyParamType.get()),
            (sequence.bodyTarget, a_val, a_param),
        )
        with InsertionPoint(foreach2.owner.body):
            # Block arguments mirror the targets (op, value, param). Slice
            # [1:3] yields the value and param ones, matching the declared
            # result types.
            # CHECK: transform.yield {{.*}} : !transform.any_value, !transform.any_param
            transform.yield_(foreach2.owner.bodyTargets[1:3])
        another_param = transform.param_constant(
            transform.AnyParamType.get(), StringAttr.get("another_param")
        )
        params = transform.merge_handles([a_param, another_param])
        # Targets: foreach1's op result, foreach2's second (param) result,
        # and the merged params.
        # CHECK: {{.*}} = foreach %{{.*}}, %{{.*}}, %{{.*}} with_zip_shortest : !transform.any_op, !transform.any_param, !transform.any_param -> !transform.any_op
        foreach3 = transform.foreach(
            (transform.AnyOpType.get(),),
            (foreach1.result, foreach2[1], params),
            with_zip_shortest=True,
        )
        with InsertionPoint(foreach3.owner.body):
            # CHECK: transform.yield {{.*}} : !transform.any_op
            transform.yield_((foreach3.owner.bodyTargets[0],))
        # The sequence's terminator: return foreach3's single op-handle result.
        transform.yield_((foreach3,))