# RUN: %PYTHON %s 2>&1 | FileCheck %s
"""FileCheck-driven tests for the MLIR Python rewrite bindings.

Every test is executed at import time by the `run` decorator; the text each
test prints is compared against the FileCheck directives written as comments
next to the corresponding `print` calls.
"""

import gc
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import arith
from mlir.rewrite import *


def run(f):
    """Print a banner naming `f`, call it, then force a garbage collection.

    Intended to be used as a decorator: the decorated test runs immediately
    when the module is loaded. The banner line is what the FileCheck label
    directives above each test anchor on.

    Returns:
        The function `f`, unchanged.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    return f


# CHECK-LABEL: TEST: testRewritePattern
@run
def testRewritePattern():
    """Register two Python rewrite patterns and apply them to small modules.

    `to_muli` replaces `arith.addi` with `arith.muli`; `constant_1_to_2`
    replaces a constant 1 with a constant 2. The frozen pattern set is applied
    through both `apply_patterns_and_fold_greedily` and
    `walk_and_apply_patterns`.
    """

    def to_muli(op, rewriter):
        with rewriter.ip:
            assert isinstance(op, arith.AddIOp)
            new_op = arith.muli(op.lhs, op.rhs, loc=op.location)
        rewriter.replace_op(op, new_op.owner)

    def constant_1_to_2(op, rewriter):
        c = op.value.value
        if c != 1:
            return True  # failed to match
        with rewriter.ip:
            new_op = arith.constant(op.type, 2, loc=op.location)
        rewriter.replace_op(op, [new_op])

    with Context():
        patterns = RewritePatternSet()
        # Patterns can be keyed by op class or by operation name.
        patterns.add(arith.AddIOp, to_muli)
        patterns.add("arith.constant", constant_1_to_2)
        frozen = patterns.freeze()

        module = ModuleOp.parse(
            r"""
            module {
              func.func @add(%a: i64, %b: i64) -> i64 {
                %sum = arith.addi %a, %b : i64
                return %sum : i64
              }
            }
            """
        )

        apply_patterns_and_fold_greedily(module, frozen)
        # CHECK: %0 = arith.muli %arg0, %arg1 : i64
        # CHECK: return %0 : i64
        print(module)

        module = ModuleOp.parse(
            r"""
            module {
              func.func @const() -> (i64, i64) {
                %0 = arith.constant 1 : i64
                %1 = arith.constant 3 : i64
                return %0, %1 : i64, i64
              }
            }
            """
        )

        apply_patterns_and_fold_greedily(module, frozen)
        # CHECK: %c2_i64 = arith.constant 2 : i64
        # CHECK: %c3_i64 = arith.constant 3 : i64
        # CHECK: return %c2_i64, %c3_i64 : i64, i64
        print(module)

        module = ModuleOp.parse(
            r"""
            module {
              func.func @add(%a: i64, %b: i64) -> i64 {
                %sum = arith.addi %a, %b : i64
                return %sum : i64
              }
            }
            """
        )

        walk_and_apply_patterns(module, frozen)
        # CHECK: %0 = arith.muli %arg0, %arg1 : i64
        # CHECK: return %0 : i64
        print(module)


# CHECK-LABEL: TEST: testGreedyRewriteConfigCreation
@run
def testGreedyRewriteConfigCreation():
    """Construct a `GreedyRewriteConfig` and report success."""
    # Test basic config creation and destruction
    config = GreedyRewriteConfig()
    # CHECK: Config created successfully
    print("Config created successfully")


# CHECK-LABEL: TEST: testGreedyRewriteConfigGetters
@run
def testGreedyRewriteConfigGetters():
    """Set every `GreedyRewriteConfig` property, then read each one back."""
    config = GreedyRewriteConfig()

    # Set some values
    config.max_iterations = 5
    config.max_num_rewrites = 50
    config.use_top_down_traversal = True
    config.enable_folding = False
    config.strictness = GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
    config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
    config.enable_constant_cse = True

    # Test all getter methods and print results
    # CHECK: max_iterations: 5
    max_iterations = config.max_iterations
    print(f"max_iterations: {max_iterations}")
    # CHECK: max_rewrites: 50
    max_rewrites = config.max_num_rewrites
    print(f"max_rewrites: {max_rewrites}")
    # CHECK: use_top_down: True
    use_top_down = config.use_top_down_traversal
    print(f"use_top_down: {use_top_down}")
    # CHECK: folding_enabled: False
    folding_enabled = config.enable_folding
    print(f"folding_enabled: {folding_enabled}")
    # CHECK: strictness: GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
    strictness = config.strictness
    print(f"strictness: {strictness}")
    # CHECK: region_level: GreedySimplifyRegionLevel.AGGRESSIVE
    region_level = config.region_simplification_level
    print(f"region_level: {region_level}")
    # CHECK: cse_enabled: True
    cse_enabled = config.enable_constant_cse
    print(f"cse_enabled: {cse_enabled}")


# CHECK-LABEL: TEST: testGreedyRewriteStrictnessEnum
@run
def testGreedyRewriteStrictnessEnum():
    """Round-trip each `GreedyRewriteStrictness` value through the config."""
    config = GreedyRewriteConfig()

    # Test ANY_OP
    # CHECK: strictness ANY_OP: GreedyRewriteStrictness.ANY_OP
    config.strictness = GreedyRewriteStrictness.ANY_OP
    strictness = config.strictness
    print(f"strictness ANY_OP: {strictness}")

    # Test EXISTING_AND_NEW_OPS
    # CHECK: strictness EXISTING_AND_NEW_OPS: GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
    config.strictness = GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
    strictness = config.strictness
    print(f"strictness EXISTING_AND_NEW_OPS: {strictness}")

    # Test EXISTING_OPS
    # CHECK: strictness EXISTING_OPS: GreedyRewriteStrictness.EXISTING_OPS
    config.strictness = GreedyRewriteStrictness.EXISTING_OPS
    strictness = config.strictness
    print(f"strictness EXISTING_OPS: {strictness}")


# CHECK-LABEL: TEST: testGreedySimplifyRegionLevelEnum
@run
def testGreedySimplifyRegionLevelEnum():
    """Round-trip each `GreedySimplifyRegionLevel` value through the config."""
    config = GreedyRewriteConfig()

    # Test DISABLED
    # CHECK: region_level DISABLED: GreedySimplifyRegionLevel.DISABLED
    config.region_simplification_level = GreedySimplifyRegionLevel.DISABLED
    level = config.region_simplification_level
    print(f"region_level DISABLED: {level}")

    # Test NORMAL
    # CHECK: region_level NORMAL: GreedySimplifyRegionLevel.NORMAL
    config.region_simplification_level = GreedySimplifyRegionLevel.NORMAL
    level = config.region_simplification_level
    print(f"region_level NORMAL: {level}")

    # Test AGGRESSIVE
    # CHECK: region_level AGGRESSIVE: GreedySimplifyRegionLevel.AGGRESSIVE
    config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
    level = config.region_simplification_level
    print(f"region_level AGGRESSIVE: {level}")


# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteConfig
@run
def testRewriteWithGreedyRewriteConfig():
    """Apply a pattern greedily with `enable_constant_cse` off, then on.

    The same module is rewritten twice: the first pass leaves two separate
    constants, the second pass is expected to leave a single shared one.
    """

    def constant_1_to_2(op, rewriter):
        c = op.value.value
        if c != 1:
            return True  # failed to match
        with rewriter.ip:
            new_op = arith.constant(op.type, 2, loc=op.location)
        rewriter.replace_op(op, [new_op])

    with Context():
        patterns = RewritePatternSet()
        patterns.add(arith.ConstantOp, constant_1_to_2)
        frozen = patterns.freeze()

        module = ModuleOp.parse(
            r"""
            module {
              func.func @const() -> (i64, i64) {
                %0 = arith.constant 1 : i64
                %1 = arith.constant 1 : i64
                return %0, %1 : i64, i64
              }
            }
            """
        )

        config = GreedyRewriteConfig()
        config.enable_constant_cse = False
        apply_patterns_and_fold_greedily(module, frozen, config)
        # CHECK: %c2_i64 = arith.constant 2 : i64
        # CHECK: %c2_i64_0 = arith.constant 2 : i64
        # CHECK: return %c2_i64, %c2_i64_0 : i64, i64
        print(module)

        config = GreedyRewriteConfig()
        config.enable_constant_cse = True
        apply_patterns_and_fold_greedily(module, frozen, config)
        # The function still returns two values, so both result types are
        # spelled out in the expectation below.
        # CHECK: %c2_i64 = arith.constant 2 : i64
        # CHECK: return %c2_i64, %c2_i64 : i64, i64
        print(module)


# CHECK-LABEL: TEST: testConversionPattern
@run
def testConversionPattern():
    """Convert `arith` ops to `smt` ops with Python conversion patterns.

    A `TypeConverter` maps integer types to `smt.IntType`. Partial conversion
    is run once with `build_materializations` disabled (expected to succeed
    and leave unrealized casts in the output) and once with the default
    config (the expectations require an `MLIRError` to be raised). Finally
    the type converter is queried directly.
    """
    from mlir.dialects import smt

    def convert_int(t):
        if isinstance(t, IntegerType):
            return smt.IntType.get()
        # No conversion for any other type.
        return None

    converter = TypeConverter()
    converter.add_conversion(convert_int)

    def convert_constant(op, adaptor, type_converter, rewriter):
        assert isinstance(op, arith.ConstantOp)
        assert isinstance(adaptor, arith.ConstantOpAdaptor)
        with rewriter.ip:
            new_op = smt.IntConstantOp(op.value, loc=op.location)
        rewriter.replace_op(op, new_op)

    def convert_addi(op, adaptor, type_converter, rewriter):
        assert isinstance(op, arith.AddIOp)
        assert isinstance(adaptor, arith.AddIOpAdaptor)
        with rewriter.ip:
            new_op = smt.IntAddOp([adaptor.lhs, adaptor.rhs], loc=op.location)
        rewriter.replace_op(op, new_op)

    def convert_muli(op, adaptor, type_converter, rewriter):
        assert isinstance(op, arith.MulIOp)
        assert isinstance(adaptor, arith.MulIOpAdaptor)
        with rewriter.ip:
            new_op = smt.IntMulOp([adaptor.lhs, adaptor.rhs], loc=op.location)
        rewriter.replace_op(op, new_op)

    with Context():
        patterns = RewritePatternSet()
        patterns.add_conversion(arith.ConstantOp, convert_constant, converter)
        patterns.add_conversion(arith.AddIOp, convert_addi, converter)
        patterns.add_conversion(arith.MulIOp, convert_muli, converter)

        module = ModuleOp.parse(
            r"""
            module {
              func.func @f(%0: i64) -> i64 {
                %1 = arith.constant 3 : i64
                %2 = arith.addi %0, %1 : i64
                %3 = arith.muli %2, %1 : i64
                return %3 : i64
              }
            }
            """
        )

        target = ConversionTarget()
        target.add_legal_dialect(smt._Dialect)
        target.add_illegal_op(arith.ConstantOp, arith.AddIOp, arith.MulIOp)

        frozen = patterns.freeze()

        config = ConversionConfig()
        config.build_materializations = False
        apply_partial_conversion(module, target, frozen, config)
        assert module.operation.verify()
        # CHECK: func.func @f(%arg0: i64) -> i64 {
        # CHECK:   %0 = builtin.unrealized_conversion_cast %arg0 : i64 to !smt.int
        # CHECK:   %c3 = smt.int.constant 3
        # CHECK:   %1 = smt.int.add %0, %c3
        # CHECK:   %2 = smt.int.mul %1, %c3
        # CHECK:   %3 = builtin.unrealized_conversion_cast %2 : !smt.int to i64
        # CHECK:   return %3 : i64
        # CHECK: }
        print(module)

        module = ModuleOp.parse(
            r"""
            module {
              func.func @f(%0: i64) -> i64 {
                %1 = arith.constant 3 : i64
                %2 = arith.addi %0, %1 : i64
                %3 = arith.muli %2, %1 : i64
                return %3 : i64
              }
            }
            """
        )

        # Same conversion, but with the default config this time.
        try:
            apply_partial_conversion(module, target, frozen)
        except MLIRError as e:
            # CHECK: caught exception: partial conversion failed
            # CHECK: failed to legalize unresolved materialization
            print("caught exception:", e)

        t1 = converter.convert_type(IntegerType.get_signless(64))
        # CHECK: IntType
        print(type(t1))
        # CHECK: !smt.int
        print(str(t1))

        t2 = converter.convert_type(F32Type.get())
        # CHECK: None
        print(t2)