Revert "Reland "[mlir][reducer] Add eraseRedundantBlocksInRegion and getSuccessorForwardOperands API to BranchOpInterface"" (#190727)

This revert decouples the BranchOpInterface implementation from the
reduction-tree changes. Reverts llvm/llvm-project#189253.
This commit is contained in:
lonely eagle
2026-04-07 12:23:38 +08:00
committed by GitHub
parent 7349977415
commit 150783e254
8 changed files with 4 additions and 232 deletions

View File

@@ -65,8 +65,7 @@ def AssertOp : CF_Op<"assert",
//===----------------------------------------------------------------------===//
def BranchOp : CF_Op<"br", [
DeclareOpInterfaceMethods<BranchOpInterface,
["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
Pure, Terminator
]> {
let summary = "Branch operation";
@@ -115,8 +114,8 @@ def BranchOp : CF_Op<"br", [
def CondBranchOp
: CF_Op<"cond_br", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface,
["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
DeclareOpInterfaceMethods<
BranchOpInterface, ["getSuccessorForOperands"]>,
WeightedBranchOpInterface, Pure, Terminator]> {
let summary = "Conditional branch operation";
let description = [{
@@ -242,8 +241,7 @@ def CondBranchOp
def SwitchOp : CF_Op<"switch",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface,
["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
Pure, Terminator]> {
let summary = "Switch operation";
let description = [{

View File

@@ -98,15 +98,6 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
(ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
[{ return lhs == rhs; }]
>,
InterfaceMethod<[{
This method returns the operands of this operation that are forwarded to
the block arguments of the specified successor. If the successor is not
valid for this operation, or no operands are forwarded, an empty
ValueRange is returned.
}],
"ValueRange", "getSuccessorForwardOperands",
(ins "Block *":$successor), [{}],[{ return {};}]
>,
];
let verify = [{

View File

@@ -90,9 +90,6 @@ public:
/// corresponding region.
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion,
IRMapping &mapper);
private:
/// A custom BFS iterator. The difference between
/// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.

View File

@@ -296,12 +296,6 @@ Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
return getDest();
}
/// Returns the operands this branch forwards to `successor`: the destination
/// operands when `successor` is the branch destination, and an empty range for
/// any other block.
ValueRange BranchOp::getSuccessorForwardOperands(Block *successor) {
  // Guard first: a block that is not our destination receives nothing.
  const bool isDestination = successor == getDest();
  if (!isDestination)
    return {};
  return getDestOperands();
}
//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//
@@ -589,14 +583,6 @@ Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
return nullptr;
}
/// Returns the operands this conditional branch forwards to `successor`: the
/// true operands for the true destination, the false operands for the false
/// destination, and an empty range for any other block.
ValueRange CondBranchOp::getSuccessorForwardOperands(Block *successor) {
  // The true destination is tested first, so its operands are returned when
  // both destinations are the same block.
  const bool takesTrueEdge = successor == getTrueDest();
  if (takesTrueEdge)
    return getTrueOperands();

  const bool takesFalseEdge = successor == getFalseDest();
  if (!takesFalseEdge)
    return {};
  return getFalseOperands();
}
//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
@@ -1048,16 +1034,6 @@ void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
.add<SimplifyUniformBlockArguments>(context);
}
/// Returns the operands this switch forwards to `successor`: the default
/// operands when `successor` is the default destination, otherwise the
/// operands of the first case whose destination is `successor`. An empty range
/// is returned when `successor` is not a destination of this switch.
ValueRange SwitchOp::getSuccessorForwardOperands(Block *successor) {
  // The default destination takes precedence over any case destination.
  if (successor == getDefaultDestination())
    return getDefaultOperands();

  // Scan the case destinations in order and stop at the first match.
  SuccessorRange caseDests = getCaseDestinations();
  for (auto caseIt = caseDests.begin(), caseEnd = caseDests.end();
       caseIt != caseEnd; ++caseIt) {
    if (*caseIt != successor)
      continue;
    return getCaseOperands(std::distance(caseDests.begin(), caseIt));
  }
  return {};
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

View File

@@ -9,7 +9,6 @@ add_mlir_library(MLIRReduce
MLIRPass
MLIRRewrite
MLIRTransformUtils
MLIRControlFlowDialect
DEPENDS
MLIRReducerIncGen

View File

@@ -45,16 +45,6 @@ LogicalResult ReductionNode::initialize(ModuleOp parentModule,
return success();
}
/// Initializes this node with a clone of `parentModule` and records the
/// clone's counterpart of `targetRegion`. `mapper` is filled by the clone, so
/// the caller can use it afterwards to translate other IR entities of
/// `parentModule` into the clone. Always returns success.
LogicalResult ReductionNode::initialize(ModuleOp parentModule,
                                        Region &targetRegion,
                                        IRMapping &mapper) {
  Operation *clonedModuleOp = parentModule->clone(mapper);
  module = cast<ModuleOp>(clonedModuleOp);

  // The mapping is consulted for blocks rather than regions, so the cloned
  // region is found through the counterpart of `targetRegion`'s first block.
  Block *originalEntry = &*targetRegion.begin();
  Block *clonedEntry = mapper.lookup(originalEntry);
  region = clonedEntry->getParent();
  return success();
}
/// If we haven't explored any variants from this node, we will create N
/// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
/// max element in `ranges` and create 2 new variants for each call.

View File

@@ -14,10 +14,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
@@ -27,9 +24,6 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "reduction-tree"
namespace mlir {
#define GEN_PASS_DEF_REDUCTIONTREEPASS
@@ -190,112 +184,6 @@ static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region,
return failure();
}
/// Walks the blocks of `region` in order and returns the terminator of the
/// first block that has more than one successor (e.g. a cond_br or switch).
/// Returns a null pointer when no block in the region branches.
static Operation *getBranchTerminatorInRegion(Region &region) {
  for (Block &candidate : region.getBlocks()) {
    const bool hasSingleOrNoSuccessor = candidate.getNumSuccessors() <= 1;
    if (hasSingleOrNoSuccessor)
      continue;
    return candidate.getTerminator();
  }
  return nullptr;
}
/// Reduces the control flow in `region` by repeatedly rewriting a branching
/// terminator (one with more than one successor) into an unconditional `cf.br`
/// to a single successor.
///
/// For the first branching terminator found, every successor is tried on a
/// fresh clone of `module`: the cloned terminator is replaced by a `cf.br` to
/// that successor, the cf canonicalization patterns are applied, and the
/// result is handed to `test`. The smallest interesting candidate is then
/// committed to the original `region`, and the search restarts on the next
/// branching terminator.
///
/// The loop stops as soon as no successor of the current terminator gives an
/// interesting module. That terminator would otherwise be returned again
/// unchanged by `getBranchTerminatorInRegion`, so continuing could never make
/// progress.
///
/// \param module the module under reduction; it is cloned for each candidate.
/// \param region the region whose control flow is reduced in place.
/// \param test   the interestingness oracle.
/// \returns a failure with a warning when `module` is not interesting to begin
///          with, success otherwise (whether or not anything was reduced).
static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
                                                  Region &region,
                                                  const Tester &test) {
  std::pair<Tester::Interestingness, size_t> initStatus =
      test.isInteresting(module);
  // While exploring the reduction tree, we always branch from an interesting
  // node. Thus the root node must be interesting.
  if (initStatus.first != Tester::Interestingness::True)
    return module.emitWarning() << "uninterested module will not be reduced";

  // Owns every candidate node (and the module clone it holds) until return.
  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;

  // We set the simplification level to Aggressive to enable block merging.
  GreedyRewriteConfig config;
  config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Aggressive);
  config.setUseTopDownTraversal(true);

  // Populate canonicalization patterns for cf ops. When all targets of a
  // 'cf.cond_br' or 'cf.switch' point to the same block, they will be
  // canonicalized into a 'cf.br'.
  MLIRContext *context = region.getContext();
  RewritePatternSet patterns(context);
  cf::BranchOp::getCanonicalizationPatterns(patterns, context);
  cf::CondBranchOp::getCanonicalizationPatterns(patterns, context);
  cf::SwitchOp::getCanonicalizationPatterns(patterns, context);
  FrozenRewritePatternSet fPatterns = std::move(patterns);

  mlir::OpBuilder builder(context);

  // Replaces `terminator` with a `cf.br` to its `succIdx`-th successor.
  // The forwarded operands are looked up by successor index rather than by
  // block, so the right operand list is used even when the same block is
  // listed for more than one successor.
  auto forceSingleSuccessor = [&builder](Operation *terminator,
                                         unsigned succIdx) {
    auto branchOp = cast<BranchOpInterface>(terminator);
    Block *dest = terminator->getSuccessor(succIdx);
    ValueRange forwarded =
        branchOp.getSuccessorOperands(succIdx).getForwardedOperands();
    builder.setInsertionPointAfter(terminator);
    cf::BranchOp::create(builder, terminator->getLoc(), dest, forwarded);
    // `forwarded` refers to operands of `terminator`; erase only after the
    // replacement branch has been built.
    terminator->erase();
  };

  bool committedAny = false;
  while (Operation *branchTerminator = getBranchTerminatorInRegion(region)) {
    const unsigned numSuccessors = branchTerminator->getNumSuccessors();
    std::vector<ReductionNode::Range> ranges{
        {0, std::distance(region.op_begin(), region.op_end())}};

    // Best candidate for *this* terminator only. Comparing against a node
    // from an earlier iteration could reject an interesting candidate.
    ReductionNode *bestNode = nullptr;
    unsigned bestSuccIdx = 0;

    // Try forcing the terminator onto each of its successors in turn.
    for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
      // Allocated from `allocator` because the node may outlive this
      // iteration as `bestNode`.
      ReductionNode *candidate = allocator.Allocate();
      new (candidate) ReductionNode(nullptr, ranges, allocator);

      mlir::IRMapping mapper;
      if (failed(candidate->initialize(module, region, mapper)))
        llvm_unreachable("unexpected initialization failure");

      Operation *clonedTerminator = mapper.lookup(branchTerminator);
      forceSingleSuccessor(clonedTerminator, succIdx);

      // Apply canonicalization patterns to collapse the now-redundant
      // branches before measuring the candidate.
      (void)applyPatternsGreedily(candidate->getRegion().getParentOp(),
                                  fPatterns, config);
      candidate->update(test.isInteresting(candidate->getModule()));

      if (candidate->isInteresting() != Tester::Interestingness::True)
        continue;
      if (bestNode == nullptr || candidate->getSize() < bestNode->getSize()) {
        bestNode = candidate;
        bestSuccIdx = succIdx;
      }
    }

    // No successor keeps the module interesting. The IR is unchanged, so the
    // same terminator would be picked again; stop instead of spinning forever.
    if (bestNode == nullptr)
      break;

    // Commit the winning choice to the original region and clean up the same
    // way the candidate was cleaned up.
    forceSingleSuccessor(branchTerminator, bestSuccIdx);
    (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
    committedAny = true;
  }

  // If nothing was committed (no branching terminator, or no interesting
  // choice for the first one), there might still be opportunities for linear
  // block merging. Apply the patterns once as a final cleanup so the region
  // is still simplified.
  if (!committedAny)
    (void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
  return success();
}
template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region &region,
const FrozenRewritePatternSet &patterns,
@@ -308,8 +196,6 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
if (succeeded(eraseAllOpsInRegion(module, region, test)))
return success();
(void)eraseRedundantBlocksInRegion(module, region, test);
// In the second phase, we don't apply any patterns so that we only select the
// range of operations to keep to the module stay interesting.
if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,

View File

@@ -58,68 +58,3 @@ func.func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func.func @simple5() {
return
}
// -----
// A two-way `cf.cond_br` whose targets both end in a `cf.br` to ^bb3: ^bb1
// forwards the function's second argument, ^bb2 forwards a fresh
// `memref.alloc`. ^bb3 passes the forwarded value to "test.op_crash".
// The directives below expect the output function body to start with
// "test.op_crash" applied directly to the second and third function
// arguments, i.e. with no branch or alloc in front of it.
// CHECK-LABEL: func @br_reduction
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) {
func.func @br_reduction(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = memref.alloc() : memref<2xf32>
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
"test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
// -----
// Two branching terminators in one function: the entry `cf.cond_br` picks
// between ^bb1 and ^bb2 (both branch to ^bb3), and ^bb3 ends with a second
// `cf.cond_br` that either loops back to ^bb3 or exits through ^bb4.
// The directives below expect the output function body to start with
// "test.op_crash" applied directly to the second and third function
// arguments.
// NOTE(review): the inline "select ^bb2" remark on the first branch looks
// inconsistent with the expected operand, which is the value ^bb1 forwards.
// Confirm which successor is meant.
// CHECK-LABEL: func @br_reduction_loop
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) {
func.func @br_reduction_loop(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
// select ^bb2
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = memref.alloc() : memref<2xf32>
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
"test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
// select ^bb4
cf.cond_br %arg0, ^bb3(%1: memref<2xf32>), ^bb4
^bb4:
return
}
// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
// -----
// A three-way `cf.switch`: the default edge goes straight to ^bb3 and
// forwards the function's second argument, case 0 goes to ^bb1 (which
// forwards the same argument to ^bb3), and case 1 goes to ^bb2 (which
// forwards a fresh `memref.alloc`). ^bb3 passes the forwarded value to
// "test.op_crash".
// The directives below expect the output function body to start with
// "test.op_crash" applied directly to the second and third function
// arguments.
// CHECK-LABEL: func @switch_reduction
// CHECK-SAME: %[[ARG0:.*]]: i32,
// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) {
func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cf.switch %arg0 : i32, [
default: ^bb3(%arg1 : memref<2xf32>),
0: ^bb1,
1: ^bb2
]
^bb1:
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = memref.alloc() : memref<2xf32>
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
"test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])