330 lines
11 KiB
Python
330 lines
11 KiB
Python
# RUN: %PYTHON %s 2>&1 | FileCheck %s
|
|
|
|
import gc
|
|
from mlir.ir import *
|
|
from mlir.passmanager import *
|
|
from mlir.dialects.builtin import ModuleOp
|
|
from mlir.dialects import arith
|
|
from mlir.rewrite import *
|
|
|
|
|
|
def run(f):
|
|
print("\nTEST:", f.__name__)
|
|
f()
|
|
gc.collect()
|
|
return f
|
|
|
|
|
|
# CHECK-LABEL: TEST: testRewritePattern
|
|
@run
|
|
def testRewritePattern():
|
|
def to_muli(op, rewriter):
|
|
with rewriter.ip:
|
|
assert isinstance(op, arith.AddIOp)
|
|
new_op = arith.muli(op.lhs, op.rhs, loc=op.location)
|
|
rewriter.replace_op(op, new_op.owner)
|
|
|
|
def constant_1_to_2(op, rewriter):
|
|
c = op.value.value
|
|
if c != 1:
|
|
return True # failed to match
|
|
with rewriter.ip:
|
|
new_op = arith.constant(op.type, 2, loc=op.location)
|
|
rewriter.replace_op(op, [new_op])
|
|
|
|
with Context():
|
|
patterns = RewritePatternSet()
|
|
patterns.add(arith.AddIOp, to_muli)
|
|
patterns.add("arith.constant", constant_1_to_2)
|
|
frozen = patterns.freeze()
|
|
|
|
module = ModuleOp.parse(
|
|
r"""
|
|
module {
|
|
func.func @add(%a: i64, %b: i64) -> i64 {
|
|
%sum = arith.addi %a, %b : i64
|
|
return %sum : i64
|
|
}
|
|
}
|
|
"""
|
|
)
|
|
|
|
apply_patterns_and_fold_greedily(module, frozen)
|
|
# CHECK: %0 = arith.muli %arg0, %arg1 : i64
|
|
# CHECK: return %0 : i64
|
|
print(module)
|
|
|
|
module = ModuleOp.parse(
|
|
r"""
|
|
module {
|
|
func.func @const() -> (i64, i64) {
|
|
%0 = arith.constant 1 : i64
|
|
%1 = arith.constant 3 : i64
|
|
return %0, %1 : i64, i64
|
|
}
|
|
}
|
|
"""
|
|
)
|
|
|
|
apply_patterns_and_fold_greedily(module, frozen)
|
|
# CHECK: %c2_i64 = arith.constant 2 : i64
|
|
# CHECK: %c3_i64 = arith.constant 3 : i64
|
|
# CHECK: return %c2_i64, %c3_i64 : i64, i64
|
|
print(module)
|
|
|
|
module = ModuleOp.parse(
|
|
r"""
|
|
module {
|
|
func.func @add(%a: i64, %b: i64) -> i64 {
|
|
%sum = arith.addi %a, %b : i64
|
|
return %sum : i64
|
|
}
|
|
}
|
|
"""
|
|
)
|
|
|
|
walk_and_apply_patterns(module, frozen)
|
|
# CHECK: %0 = arith.muli %arg0, %arg1 : i64
|
|
# CHECK: return %0 : i64
|
|
print(module)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testGreedyRewriteConfigCreation
|
|
@run
|
|
def testGreedyRewriteConfigCreation():
|
|
# Test basic config creation and destruction
|
|
config = GreedyRewriteConfig()
|
|
# CHECK: Config created successfully
|
|
print("Config created successfully")
|
|
|
|
|
|
# CHECK-LABEL: TEST: testGreedyRewriteConfigGetters
|
|
@run
|
|
def testGreedyRewriteConfigGetters():
|
|
config = GreedyRewriteConfig()
|
|
|
|
# Set some values
|
|
config.max_iterations = 5
|
|
config.max_num_rewrites = 50
|
|
config.use_top_down_traversal = True
|
|
config.enable_folding = False
|
|
config.strictness = GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
|
|
config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
|
|
config.enable_constant_cse = True
|
|
|
|
# Test all getter methods and print results
|
|
# CHECK: max_iterations: 5
|
|
max_iterations = config.max_iterations
|
|
print(f"max_iterations: {max_iterations}")
|
|
# CHECK: max_rewrites: 50
|
|
max_rewrites = config.max_num_rewrites
|
|
print(f"max_rewrites: {max_rewrites}")
|
|
# CHECK: use_top_down: True
|
|
use_top_down = config.use_top_down_traversal
|
|
print(f"use_top_down: {use_top_down}")
|
|
# CHECK: folding_enabled: False
|
|
folding_enabled = config.enable_folding
|
|
print(f"folding_enabled: {folding_enabled}")
|
|
# CHECK: strictness: GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
|
|
strictness = config.strictness
|
|
print(f"strictness: {strictness}")
|
|
# CHECK: region_level: GreedySimplifyRegionLevel.AGGRESSIVE
|
|
region_level = config.region_simplification_level
|
|
print(f"region_level: {region_level}")
|
|
# CHECK: cse_enabled: True
|
|
cse_enabled = config.enable_constant_cse
|
|
print(f"cse_enabled: {cse_enabled}")
|
|
|
|
|
|
# CHECK-LABEL: TEST: testGreedyRewriteStrictnessEnum
|
|
@run
|
|
def testGreedyRewriteStrictnessEnum():
|
|
config = GreedyRewriteConfig()
|
|
|
|
# Test ANY_OP
|
|
# CHECK: strictness ANY_OP: GreedyRewriteStrictness.ANY_OP
|
|
config.strictness = GreedyRewriteStrictness.ANY_OP
|
|
strictness = config.strictness
|
|
print(f"strictness ANY_OP: {strictness}")
|
|
|
|
# Test EXISTING_AND_NEW_OPS
|
|
# CHECK: strictness EXISTING_AND_NEW_OPS: GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
|
|
config.strictness = GreedyRewriteStrictness.EXISTING_AND_NEW_OPS
|
|
strictness = config.strictness
|
|
print(f"strictness EXISTING_AND_NEW_OPS: {strictness}")
|
|
|
|
# Test EXISTING_OPS
|
|
# CHECK: strictness EXISTING_OPS: GreedyRewriteStrictness.EXISTING_OPS
|
|
config.strictness = GreedyRewriteStrictness.EXISTING_OPS
|
|
strictness = config.strictness
|
|
print(f"strictness EXISTING_OPS: {strictness}")
|
|
|
|
|
|
# CHECK-LABEL: TEST: testGreedySimplifyRegionLevelEnum
|
|
@run
|
|
def testGreedySimplifyRegionLevelEnum():
|
|
config = GreedyRewriteConfig()
|
|
|
|
# Test DISABLED
|
|
# CHECK: region_level DISABLED: GreedySimplifyRegionLevel.DISABLED
|
|
config.region_simplification_level = GreedySimplifyRegionLevel.DISABLED
|
|
level = config.region_simplification_level
|
|
print(f"region_level DISABLED: {level}")
|
|
|
|
# Test NORMAL
|
|
# CHECK: region_level NORMAL: GreedySimplifyRegionLevel.NORMAL
|
|
config.region_simplification_level = GreedySimplifyRegionLevel.NORMAL
|
|
level = config.region_simplification_level
|
|
print(f"region_level NORMAL: {level}")
|
|
|
|
# Test AGGRESSIVE
|
|
# CHECK: region_level AGGRESSIVE: GreedySimplifyRegionLevel.AGGRESSIVE
|
|
config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
|
|
level = config.region_simplification_level
|
|
print(f"region_level AGGRESSIVE: {level}")
|
|
|
|
|
|
# CHECK-LABEL: TEST: testRewriteWithGreedyRewriteConfig
|
|
@run
|
|
def testRewriteWithGreedyRewriteConfig():
|
|
def constant_1_to_2(op, rewriter):
|
|
c = op.value.value
|
|
if c != 1:
|
|
return True # failed to match
|
|
with rewriter.ip:
|
|
new_op = arith.constant(op.type, 2, loc=op.location)
|
|
rewriter.replace_op(op, [new_op])
|
|
|
|
with Context():
|
|
patterns = RewritePatternSet()
|
|
patterns.add(arith.ConstantOp, constant_1_to_2)
|
|
frozen = patterns.freeze()
|
|
|
|
module = ModuleOp.parse(
|
|
r"""
|
|
module {
|
|
func.func @const() -> (i64, i64) {
|
|
%0 = arith.constant 1 : i64
|
|
%1 = arith.constant 1 : i64
|
|
return %0, %1 : i64, i64
|
|
}
|
|
}
|
|
"""
|
|
)
|
|
|
|
config = GreedyRewriteConfig()
|
|
config.enable_constant_cse = False
|
|
apply_patterns_and_fold_greedily(module, frozen, config)
|
|
# CHECK: %c2_i64 = arith.constant 2 : i64
|
|
# CHECK: %c2_i64_0 = arith.constant 2 : i64
|
|
# CHECK: return %c2_i64, %c2_i64_0 : i64, i64
|
|
print(module)
|
|
|
|
config = GreedyRewriteConfig()
|
|
config.enable_constant_cse = True
|
|
apply_patterns_and_fold_greedily(module, frozen, config)
|
|
# CHECK: %c2_i64 = arith.constant 2 : i64
|
|
# CHECK: return %c2_i64, %c2_i64 : i64
|
|
print(module)
|
|
|
|
|
|
@run
|
|
def testConversionPattern():
|
|
from mlir.dialects import smt
|
|
|
|
def convert_int(t):
|
|
if isinstance(t, IntegerType):
|
|
return smt.IntType.get()
|
|
|
|
converter = TypeConverter()
|
|
converter.add_conversion(convert_int)
|
|
|
|
def convert_constant(op, adaptor, type_converter, rewriter):
|
|
assert isinstance(op, arith.ConstantOp)
|
|
assert isinstance(adaptor, arith.ConstantOpAdaptor)
|
|
with rewriter.ip:
|
|
new_op = smt.IntConstantOp(op.value, loc=op.location)
|
|
rewriter.replace_op(op, new_op)
|
|
|
|
def convert_addi(op, adaptor, type_converter, rewriter):
|
|
assert isinstance(op, arith.AddIOp)
|
|
assert isinstance(adaptor, arith.AddIOpAdaptor)
|
|
with rewriter.ip:
|
|
new_op = smt.IntAddOp([adaptor.lhs, adaptor.rhs], loc=op.location)
|
|
rewriter.replace_op(op, new_op)
|
|
|
|
def convert_muli(op, adaptor, type_converter, rewriter):
|
|
assert isinstance(op, arith.MulIOp)
|
|
assert isinstance(adaptor, arith.MulIOpAdaptor)
|
|
with rewriter.ip:
|
|
new_op = smt.IntMulOp([adaptor.lhs, adaptor.rhs], loc=op.location)
|
|
rewriter.replace_op(op, new_op)
|
|
|
|
with Context():
|
|
patterns = RewritePatternSet()
|
|
patterns.add_conversion(arith.ConstantOp, convert_constant, converter)
|
|
patterns.add_conversion(arith.AddIOp, convert_addi, converter)
|
|
patterns.add_conversion(arith.MulIOp, convert_muli, converter)
|
|
|
|
module = ModuleOp.parse(
|
|
r"""
|
|
module {
|
|
func.func @f(%0: i64) -> i64 {
|
|
%1 = arith.constant 3 : i64
|
|
%2 = arith.addi %0, %1 : i64
|
|
%3 = arith.muli %2, %1 : i64
|
|
return %3 : i64
|
|
}
|
|
}
|
|
"""
|
|
)
|
|
|
|
target = ConversionTarget()
|
|
target.add_legal_dialect(smt._Dialect)
|
|
target.add_illegal_op(arith.ConstantOp, arith.AddIOp, arith.MulIOp)
|
|
|
|
frozen = patterns.freeze()
|
|
config = ConversionConfig()
|
|
config.build_materializations = False
|
|
|
|
apply_partial_conversion(module, target, frozen, config)
|
|
assert module.operation.verify()
|
|
|
|
# CHECK: func.func @f(%arg0: i64) -> i64 {
|
|
# CHECK: %0 = builtin.unrealized_conversion_cast %arg0 : i64 to !smt.int
|
|
# CHECK: %c3 = smt.int.constant 3
|
|
# CHECK: %1 = smt.int.add %0, %c3
|
|
# CHECK: %2 = smt.int.mul %1, %c3
|
|
# CHECK: %3 = builtin.unrealized_conversion_cast %2 : !smt.int to i64
|
|
# CHECK: return %3 : i64
|
|
# CHECK: }
|
|
print(module)
|
|
|
|
module = ModuleOp.parse(
|
|
r"""
|
|
module {
|
|
func.func @f(%0: i64) -> i64 {
|
|
%1 = arith.constant 3 : i64
|
|
%2 = arith.addi %0, %1 : i64
|
|
%3 = arith.muli %2, %1 : i64
|
|
return %3 : i64
|
|
}
|
|
}
|
|
"""
|
|
)
|
|
try:
|
|
apply_partial_conversion(module, target, frozen)
|
|
except MLIRError as e:
|
|
# CHECK: caught exception: partial conversion failed
|
|
# CHECK: failed to legalize unresolved materialization
|
|
print("caught exception:", e)
|
|
|
|
t1 = converter.convert_type(IntegerType.get_signless(64))
|
|
# CHECK: IntType
|
|
print(type(t1))
|
|
# CHECK: !smt.int
|
|
print(str(t1))
|
|
t2 = converter.convert_type(F32Type.get())
|
|
# CHECK: None
|
|
print(t2)
|