# RUN: %PYTHON %s 2>&1 | FileCheck %s

import gc
import sys
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import pdl
from mlir.rewrite import *


def log(*args):
    """Print ``args`` to stderr and flush the stream right away."""
    print(*args, file=sys.stderr, flush=True)


def run(f):
    """Execute the test function ``f`` and verify that no Context is leaked.

    The test name is logged first so the expected-output labels below can
    anchor on it. After the test body returns, a garbage collection is forced
    and the live Context count is asserted to be zero.

    Nothing is returned, so when used as a decorator the decorated name ends
    up bound to ``None``.
    """
    log("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0


def make_pdl_module():
    """Build a module holding one PDL pattern: ``arith.addi`` -> ``arith.muli``.

    The pattern matches an ``arith.addi`` whose two operands and single result
    are signless 64-bit integers, and replaces it with an ``arith.muli`` on
    the same operands.

    Returns:
        The module containing the ``addi_to_mul`` pattern.
    """
    with Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            # Rewrite every i64 arith.addi into arith.muli.
            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
            def pat():
                # Matcher: arith.addi taking two i64 operands, yielding i64.
                i64 = pdl.TypeOp(IntegerType.get_signless(64))
                lhs = pdl.OperandOp(i64)
                rhs = pdl.OperandOp(i64)
                matched = pdl.OperationOp(
                    name="arith.addi", args=[lhs, rhs], types=[i64]
                )

                # Rewriter: swap the matched op for arith.muli.
                @pdl.rewrite()
                def rew():
                    replacement = pdl.OperationOp(
                        name="arith.muli", args=[lhs, rhs], types=[i64]
                    )
                    pdl.ReplaceOp(matched, with_op=replacement)

        return module


# CHECK-LABEL: TEST: testCustomPass
@run
def testCustomPass():
    """Run Python callables as passes and exercise pass-failure signalling."""
    with Context():
        frozen = PDLModule(make_pdl_module()).freeze()

        module = ModuleOp.parse(
            r"""
            module {
              func.func @add(%a: i64, %b: i64) -> i64 {
                %sum = arith.addi %a, %b : i64
                return %sum : i64
              }
            }
            """
        )

        # A plain function used as a pass; it is added without an explicit
        # name.
        def custom_pass_1(op, pass_):
            print("hello from pass 1!!!", file=sys.stderr)

        # A callable object used as a pass; it applies the frozen PDL
        # patterns to the op it is handed.
        class CustomPass2:
            def __call__(self, op, pass_):
                apply_patterns_and_fold_greedily(op, frozen)

        custom_pass_2 = CustomPass2()

        pm = PassManager("any")
        pm.enable_ir_printing()

        # CHECK: hello from pass 1!!!
        # CHECK-LABEL: Dump After custom_pass_1
        pm.add(custom_pass_1)
        # CHECK-LABEL: Dump After CustomPass2
        # CHECK: arith.muli
        pm.add(custom_pass_2, "CustomPass2")
        # CHECK-LABEL: Dump After ArithToLLVMConversionPass
        # CHECK: llvm.mul
        pm.add("convert-arith-to-llvm")
        pm.run(module)

        # test signal_pass_failure
        def custom_pass_that_fails(op, pass_):
            print("hello from pass that fails")
            pass_.signal_pass_failure()

        pm = PassManager("any")
        pm.add(custom_pass_that_fails, "CustomPassThatFails")
        # CHECK: hello from pass that fails
        # CHECK: caught exception: Failure while executing pass pipeline
        try:
            pm.run(module)
        except Exception as e:
            print(f"caught exception: {e}")