diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td index ddea3a7eae59..a441fd82546e 100644 --- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td +++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td @@ -65,8 +65,7 @@ def AssertOp : CF_Op<"assert", //===----------------------------------------------------------------------===// def BranchOp : CF_Op<"br", [ - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, Pure, Terminator ]> { let summary = "Branch operation"; @@ -115,8 +114,8 @@ def BranchOp : CF_Op<"br", [ def CondBranchOp : CF_Op<"cond_br", [AttrSizedOperandSegments, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods< + BranchOpInterface, ["getSuccessorForOperands"]>, WeightedBranchOpInterface, Pure, Terminator]> { let summary = "Conditional branch operation"; let description = [{ @@ -242,8 +241,7 @@ def CondBranchOp def SwitchOp : CF_Op<"switch", [AttrSizedOperandSegments, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, Pure, Terminator]> { let summary = "Switch operation"; let description = [{ diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td index d32be0c63acc..06fa724e05fa 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -98,15 +98,6 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> { (ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}], [{ return lhs == rhs; }] >, - InterfaceMethod<[{ - This method is called to returns the operands of this operation that - are passed to the specified successor's block arguments. If the successor - is not valid for this operation, or no operands are forwarded, an empty - ValueRange is returned. - }], - "ValueRange", "getSuccessorForwardOperands", - (ins "Block *":$successor), [{}],[{ return {};}] - >, ]; let verify = [{ diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h index 125a7c6f6f5e..6ca4e13d159a 100644 --- a/mlir/include/mlir/Reducer/ReductionNode.h +++ b/mlir/include/mlir/Reducer/ReductionNode.h @@ -90,9 +90,6 @@ public: /// corresponding region. LogicalResult initialize(ModuleOp parentModule, Region &parentRegion); - LogicalResult initialize(ModuleOp parentModule, Region &parentRegion, - IRMapping &mapper); - private: /// A custom BFS iterator. The difference between /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic. diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp index f6eb0f05911b..435c37bc95aa 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -296,12 +296,6 @@ Block *BranchOp::getSuccessorForOperands(ArrayRef) { return getDest(); } -ValueRange BranchOp::getSuccessorForwardOperands(Block *successor) { - if (successor == getDest()) - return getDestOperands(); - return {}; -} - //===----------------------------------------------------------------------===// // CondBranchOp //===----------------------------------------------------------------------===// @@ -589,14 +583,6 @@ Block *CondBranchOp::getSuccessorForOperands(ArrayRef operands) { return nullptr; } -ValueRange CondBranchOp::getSuccessorForwardOperands(Block *successor) { - if (successor == getTrueDest()) - return getTrueOperands(); - else if (successor == getFalseDest()) - return getFalseOperands(); - return {}; -} - //===----------------------------------------------------------------------===// // SwitchOp //===----------------------------------------------------------------------===// @@ -1048,16 +1034,6 @@ void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, .add(context); } -ValueRange SwitchOp::getSuccessorForwardOperands(Block *successor) { - if (successor == getDefaultDestination()) - return getDefaultOperands(); - SuccessorRange caseDests = getCaseDestinations(); - auto it = llvm::find(caseDests, successor); - if (it == caseDests.end()) - return {}; - return getCaseOperands(std::distance(caseDests.begin(), it)); -} - //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt index b18a4bca04fc..68864e373c99 100644 --- a/mlir/lib/Reducer/CMakeLists.txt +++ b/mlir/lib/Reducer/CMakeLists.txt @@ -9,7 +9,6 @@ add_mlir_library(MLIRReduce MLIRPass MLIRRewrite MLIRTransformUtils - MLIRControlFlowDialect DEPENDS MLIRReducerIncGen diff --git a/mlir/lib/Reducer/ReductionNode.cpp b/mlir/lib/Reducer/ReductionNode.cpp index 897aae0becf3..11aeaf77b464 100644 --- a/mlir/lib/Reducer/ReductionNode.cpp +++ b/mlir/lib/Reducer/ReductionNode.cpp @@ -45,16 +45,6 @@ LogicalResult ReductionNode::initialize(ModuleOp parentModule, return success(); } -LogicalResult ReductionNode::initialize(ModuleOp parentModule, - Region &targetRegion, - IRMapping &mapper) { - module = cast(parentModule->clone(mapper)); - // Use the first block of targetRegion to locate the cloned region. - Block *block = mapper.lookup(&*targetRegion.begin()); - region = block->getParent(); - return success(); -} - /// If we haven't explored any variants from this node, we will create N /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the /// max element in `ranges` and create 2 new variants for each call. diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp index 12358f7d7168..83497143d966 100644 --- a/mlir/lib/Reducer/ReductionTreePass.cpp +++ b/mlir/lib/Reducer/ReductionTreePass.cpp @@ -14,10 +14,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/IR/DialectInterface.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Reducer/Passes.h" #include "mlir/Reducer/ReductionNode.h" #include "mlir/Reducer/ReductionPatternInterface.h" @@ -27,9 +24,6 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/Support/Allocator.h" -#include "llvm/Support/DebugLog.h" - -#define DEBUG_TYPE "reduction-tree" namespace mlir { #define GEN_PASS_DEF_REDUCTIONTREEPASS @@ -190,112 +184,6 @@ static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region ®ion, return failure(); } -// Returns the first branching terminator (cond_br, switch, etc.) found in the -// region. -static Operation *getBranchTerminatorInRegion(Region ®ion) { - for (Block &block : region.getBlocks()) { - if (block.getNumSuccessors() > 1) - return block.getTerminator(); - } - return {}; -} - -/// Reduces the control flow in a region by iteratively forcing branching -/// terminators to point to a single successor. It evaluates each potential -/// branch path and commits the reduction that results in the smallest -/// "interesting" module. -static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module, - Region ®ion, - const Tester &test) { - std::pair initStatus = - test.isInteresting(module); - - // While exploring the reduction tree, we always branch from an interesting - // node. Thus the root node must be interesting. - if (initStatus.first != Tester::Interestingness::True) - return module.emitWarning() << "uninterested module will not be reduced"; - llvm::SpecificBumpPtrAllocator allocator; - - // We set the simplification level to Aggressive to enable block merging. - GreedyRewriteConfig config; - config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Aggressive); - config.setUseTopDownTraversal(true); - - // Populate canonicalization patterns for cf ops. When all targets of a - // 'cf.cond_br' or 'cf.switch' point to the same block, they will be - // canonicalized into a 'cf.br'. - auto context = region.getContext(); - RewritePatternSet patterns(context); - cf::BranchOp::getCanonicalizationPatterns(patterns, context); - cf::CondBranchOp::getCanonicalizationPatterns(patterns, context); - cf::SwitchOp::getCanonicalizationPatterns(patterns, context); - FrozenRewritePatternSet fPatterns = std::move(patterns); - - ReductionNode *smallestNode = nullptr; - mlir::OpBuilder b(context); - while (Operation *branchTerminator = getBranchTerminatorInRegion(region)) { - size_t numSuccessor = branchTerminator->getNumSuccessors(); - std::vector ranges{ - {0, std::distance(region.op_begin(), region.op_end())}}; - // Iterate through each successor of the branching terminator to try - // reducing the control flow to a single-path execution. - int branchIdx = -1; - for (int i = 0, e = numSuccessor; i < e; ++i) { - // We allocate memory on the heap because the object will be assigned to - // 'smallestNode'. - ReductionNode *root = allocator.Allocate(); - new (root) ReductionNode(nullptr, ranges, allocator); - mlir::IRMapping mapper; - if (failed(root->initialize(module, region, mapper))) - llvm_unreachable("unexpected initialization failure"); - Operation *tergetTerminator = mapper.lookup(branchTerminator); - Block *selectedBlock = tergetTerminator->getSuccessor(i); - auto branchOp = cast(tergetTerminator); - ValueRange selectedBlockOperands = - branchOp.getSuccessorForwardOperands(selectedBlock); - b.setInsertionPointAfter(tergetTerminator); - cf::BranchOp::create(b, tergetTerminator->getLoc(), selectedBlock, - selectedBlockOperands); - tergetTerminator->erase(); - - // Apply canonicalization patterns to collapse the now-redundant branches - (void)applyPatternsGreedily(root->getRegion().getParentOp(), fPatterns, - config); - root->update(test.isInteresting(root->getModule())); - - // Track the smallest "interesting" version of the IR found so far. - if (root->isInteresting() == Tester::Interestingness::True && - (smallestNode == nullptr || - root->getSize() < smallestNode->getSize())) { - smallestNode = root; - branchIdx = i; - } - } - - // If an interesting reduced branch was found, commit the change to the - // original region and re-apply patterns for a final cleanup. - if (branchIdx != -1) { - Block *selectedBlock = branchTerminator->getSuccessor(branchIdx); - auto branchOp = cast(branchTerminator); - ValueRange selectedBlockOperands = - branchOp.getSuccessorForwardOperands(selectedBlock); - b.setInsertionPointAfter(branchTerminator); - cf::BranchOp::create(b, branchTerminator->getLoc(), selectedBlock, - selectedBlockOperands); - branchTerminator->erase(); - (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config); - } - } - - // If no branching terminators were found (skipping the while loop), - // there might still be opportunities for linear block merging or - // We apply patterns here as a final cleanup to ensure the region is fully - // simplified. - if (smallestNode == nullptr) - (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config); - return success(); -} - template static LogicalResult findOptimal(ModuleOp module, Region ®ion, const FrozenRewritePatternSet &patterns, @@ -308,8 +196,6 @@ static LogicalResult findOptimal(ModuleOp module, Region ®ion, if (succeeded(eraseAllOpsInRegion(module, region, test))) return success(); - (void)eraseRedundantBlocksInRegion(module, region, test); - // In the second phase, we don't apply any patterns so that we only select the // range of operations to keep to the module stay interesting. if (failed(findOptimal(module, region, /*patterns=*/{}, test, diff --git a/mlir/test/mlir-reduce/reduction-tree.mlir b/mlir/test/mlir-reduce/reduction-tree.mlir index b053a111e9a1..2aee89741b42 100644 --- a/mlir/test/mlir-reduce/reduction-tree.mlir +++ b/mlir/test/mlir-reduce/reduction-tree.mlir @@ -58,68 +58,3 @@ func.func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func.func @simple5() { return } - -// ----- - -// CHECK-LABEL: func @br_reduction -// CHECK-SAME: %[[ARG0:.*]]: i1, -// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>, -// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) { -func.func @br_reduction(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { - cf.cond_br %arg0, ^bb1, ^bb2 -^bb1: - cf.br ^bb3(%arg1 : memref<2xf32>) -^bb2: - %0 = memref.alloc() : memref<2xf32> - cf.br ^bb3(%0 : memref<2xf32>) -^bb3(%1: memref<2xf32>): - "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () - return -} -// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]]) - -// ----- - -// CHECK-LABEL: func @br_reduction_loop -// CHECK-SAME: %[[ARG0:.*]]: i1, -// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>, -// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) { -func.func @br_reduction_loop(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { - // select ^bb2 - cf.cond_br %arg0, ^bb1, ^bb2 -^bb1: - cf.br ^bb3(%arg1 : memref<2xf32>) -^bb2: - %0 = memref.alloc() : memref<2xf32> - cf.br ^bb3(%0 : memref<2xf32>) -^bb3(%1: memref<2xf32>): - "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () - // select ^bb4 - cf.cond_br %arg0, ^bb3(%1: memref<2xf32>), ^bb4 -^bb4: - return -} -// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]]) - -// ----- - -// CHECK-LABEL: func @switch_reduction -// CHECK-SAME: %[[ARG0:.*]]: i32, -// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>, -// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) { -func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { - cf.switch %arg0 : i32, [ - default: ^bb3(%arg1 : memref<2xf32>), - 0: ^bb1, - 1: ^bb2 - ] -^bb1: - cf.br ^bb3(%arg1 : memref<2xf32>) -^bb2: - %0 = memref.alloc() : memref<2xf32> - cf.br ^bb3(%0 : memref<2xf32>) -^bb3(%1: memref<2xf32>): - "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () - return -} -// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])