The verifier insisted on `!transform.param<...>` too early and therefore crashed on `!transform.any_param`.
69 lines
2.8 KiB
Python
69 lines
2.8 KiB
Python
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
from mlir import ir
|
|
from mlir.dialects import transform, smt
|
|
from mlir.dialects.transform import smt as transform_smt
|
|
|
|
|
|
def run(f):
    """Wrap the IR built by ``f`` in a transform sequence and print it.

    Intended to be applied as a decorator, so it executes as soon as the
    decorated function is defined.  It prints a ``TEST: <name>`` banner that
    the label directives in this file anchor on.  It then creates a fresh
    module holding a ``transform.sequence`` op, built with failure
    propagation mode ``Propagate``, no results, and ``transform.AnyOpType``
    as its target.  Inside the sequence body it calls ``f`` with the
    sequence's ``bodyTarget``, terminates the body with ``transform.yield``,
    and prints the whole module.

    Returns ``f`` itself, so the decorated name remains bound to the function.
    """
    print("\nTEST:", f.__name__)
    with ir.Context():
        # The unknown location is created while the context above is active.
        with ir.Location.unknown():
            mod = ir.Module.create()
            with ir.InsertionPoint(mod.body):
                seq = transform.SequenceOp(
                    transform.FailurePropagationMode.Propagate,
                    [],
                    transform.AnyOpType.get(),
                )
            # Everything the payload builds goes into the sequence body.
            with ir.InsertionPoint(seq.body):
                f(seq.bodyTarget)
                # Close the sequence region after the payload ops.
                transform.YieldOp()
            print(mod)
    return f
|
|
|
|
|
|
# CHECK-LABEL: TEST: testConstrainParamsOp
@run
def testConstrainParamsOp(target):
    """Build two ``transform.smt.constrain_params`` ops over one parameter.

    The parameter is a ``transform.param.constant`` typed
    ``!transform.any_param`` rather than a concrete ``!transform.param<...>``.
    This exercises the untyped-param path that the verifier previously
    crashed on.  Each op maps that parameter to an ``!smt.int`` block
    argument ``x``:

    * the first op has no results; it asserts ``0 <= x <= 43`` and yields
      nothing;
    * the second op has one ``!transform.any_param`` result and yields
      ``x + x``.

    ``target`` (the enclosing sequence's block argument) is unused; it is
    only accepted because ``run`` passes it.
    """
    i32 = ir.IntegerType.get_signless(32)
    c42_attr = ir.IntegerAttr.get(i32, 42)
    # CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
    symbolic_value_as_param = transform.ParamConstantOp(
        transform.AnyParamType.get(), c42_attr
    )

    # Constraint-only form: no result types, and the region yields nothing.
    # CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
    constrain_params = transform_smt.ConstrainParamsOp(
        [], [symbolic_value_as_param], [smt.IntType.get()]
    )
    # CHECK-NEXT: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
    with ir.InsertionPoint(constrain_params.body):
        symbolic_value_as_smt_var = constrain_params.body.arguments[0]
        # CHECK: %[[C0:.*]] = smt.int.constant 0
        c0 = smt.IntConstantOp(ir.IntegerAttr.get(i32, 0))
        # CHECK: %[[C43:.*]] = smt.int.constant 43
        c43 = smt.IntConstantOp(ir.IntegerAttr.get(i32, 43))
        # CHECK: %[[LB:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
        lb = smt.IntCmpOp(smt.IntPredicate.le, c0, symbolic_value_as_smt_var)
        # CHECK: %[[UB:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]]
        ub = smt.IntCmpOp(smt.IntPredicate.le, symbolic_value_as_smt_var, c43)
        # CHECK: %[[BOUNDED:.*]] = smt.and %[[LB]], %[[UB]]
        bounded = smt.AndOp([lb, ub])
        # Reference the captured variable (no ':regex' suffix) so the match is
        # tied to the smt.and result above instead of re-capturing any operand.
        # CHECK: smt.assert %[[BOUNDED]]
        smt.AssertOp(bounded)
        smt.YieldOp([])

    # Computing form: one param result, produced by yielding an SMT value.
    # CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
    compute_with_params = transform_smt.ConstrainParamsOp(
        [transform.AnyParamType.get()],
        [symbolic_value_as_param],
        [smt.IntType.get()],
    )
    # CHECK-NEXT: ^bb{{.*}}(%[[SMT_SYMB:.*]]: !smt.int):
    with ir.InsertionPoint(compute_with_params.body):
        symbolic_value_as_smt_var = compute_with_params.body.arguments[0]
        # CHECK: %[[TWICE:.*]] = smt.int.add %[[SMT_SYMB]], %[[SMT_SYMB]]
        twice_symb = smt.IntAddOp(
            [symbolic_value_as_smt_var, symbolic_value_as_smt_var]
        )
        # CHECK: smt.yield %[[TWICE]]
        smt.YieldOp([twice_symb])
|