diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir index 63166e16b27b..da3f1a28deaa 100644 --- a/mlir/test/Transforms/inlining.mlir +++ b/mlir/test/Transforms/inlining.mlir @@ -397,3 +397,26 @@ llvm.func @inline_test_return_arity_mismatch_callee(%arg0: f16, %arg1: f16) { %1 = "test.op_with_bitcast_type"(%arg1) : (f16) -> tensor<2xi32> "test.return"(%0, %1) : (tensor<4xf32>, tensor<2xi32>) -> () } + +// Check that a functional_region_op with a multi-block region is inlined +// correctly. Previously the test dialect's handleTerminator(op, Block*) +// was missing, causing an llvm_unreachable when the non-entry block's +// test.return terminator was processed. +// CHECK-LABEL: func @inline_functional_region_multiblock( +func.func @inline_functional_region_multiblock(%arg0: i32) -> i32 { + // CHECK-NOT: call_indirect + // CHECK: arith.addi + // CHECK: arith.addi + // CHECK: cf.br + // CHECK: arith.addi + %fn = "test.functional_region_op"() ({ + ^bb0(%a : i32): + %b = arith.addi %a, %a : i32 + cf.br ^bb1(%b: i32) + ^bb1(%c: i32): + %d = arith.addi %c, %c : i32 + "test.return"(%d) : (i32) -> () + }) : () -> ((i32) -> i32) + %0 = call_indirect %fn(%arg0) : (i32) -> i32 + return %0 : i32 +} diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp index 7ccbe49b2c62..1c9dbe164068 100644 --- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp @@ -8,6 +8,7 @@ #include "TestDialect.h" #include "TestOps.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Interfaces/FoldInterfaces.h" #include "mlir/Reducer/ReductionPatternInterface.h" #include "mlir/Transforms/InliningUtils.h" @@ -342,6 +343,19 @@ struct TestInlinerInterface : public DialectInlinerInterface { // Transformation Hooks //===--------------------------------------------------------------------===// + /// Handle the given inlined terminator by replacing it with a new operation + /// as necessary (multi-block inlining case: replace test.return with a + /// branch to the successor block that carries the inlined results). + void handleTerminator(Operation *op, Block *newDest) const final { + auto returnOp = dyn_cast(op); + if (!returnOp) + return; + OpBuilder builder(op); + cf::BranchOp::create(builder, op->getLoc(), newDest, + returnOp.getOperands()); + op->erase(); + } + /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {