Revert "Reland "[mlir][reducer] Add eraseRedundantBlocksInRegion and getSuccessorForwardOperands API to BranchOpInterface"" (#190727)
This decouples the BranchOpInterface implementation from the reduction-tree changes. Reverts llvm/llvm-project#189253.
This commit is contained in:
@@ -65,8 +65,7 @@ def AssertOp : CF_Op<"assert",
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def BranchOp : CF_Op<"br", [
|
||||
DeclareOpInterfaceMethods<BranchOpInterface,
|
||||
["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
|
||||
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
|
||||
Pure, Terminator
|
||||
]> {
|
||||
let summary = "Branch operation";
|
||||
@@ -115,8 +114,8 @@ def BranchOp : CF_Op<"br", [
|
||||
|
||||
def CondBranchOp
|
||||
: CF_Op<"cond_br", [AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<BranchOpInterface,
|
||||
["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
|
||||
DeclareOpInterfaceMethods<
|
||||
BranchOpInterface, ["getSuccessorForOperands"]>,
|
||||
WeightedBranchOpInterface, Pure, Terminator]> {
|
||||
let summary = "Conditional branch operation";
|
||||
let description = [{
|
||||
@@ -242,8 +241,7 @@ def CondBranchOp
|
||||
|
||||
def SwitchOp : CF_Op<"switch",
|
||||
[AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<BranchOpInterface,
|
||||
["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
|
||||
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
|
||||
Pure, Terminator]> {
|
||||
let summary = "Switch operation";
|
||||
let description = [{
|
||||
|
||||
@@ -98,15 +98,6 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
|
||||
(ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
|
||||
[{ return lhs == rhs; }]
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
This method is called to return the operands of this operation that
|
||||
are passed to the specified successor's block arguments. If the successor
|
||||
is not valid for this operation, or no operands are forwarded, an empty
|
||||
ValueRange is returned.
|
||||
}],
|
||||
"ValueRange", "getSuccessorForwardOperands",
|
||||
(ins "Block *":$successor), [{}],[{ return {};}]
|
||||
>,
|
||||
];
|
||||
|
||||
let verify = [{
|
||||
|
||||
@@ -90,9 +90,6 @@ public:
|
||||
/// corresponding region.
|
||||
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
|
||||
|
||||
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion,
|
||||
IRMapping &mapper);
|
||||
|
||||
private:
|
||||
/// A custom BFS iterator. Unlike llvm/ADT/BreadthFirstIterator.h, the graph
|
||||
/// we're exploring is dynamic.
|
||||
|
||||
@@ -296,12 +296,6 @@ Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
|
||||
return getDest();
|
||||
}
|
||||
|
||||
/// Returns the operands this branch forwards to `successor`'s block arguments.
/// Yields an empty range when `successor` is not this op's destination.
ValueRange BranchOp::getSuccessorForwardOperands(Block *successor) {
  // Anything other than the single destination receives no operands.
  if (successor != getDest())
    return {};
  return getDestOperands();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CondBranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -589,14 +583,6 @@ Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Returns the operands this conditional branch forwards to `successor`.
/// The true destination is checked first, so when both destinations are the
/// same block the true-branch operands are returned. Yields an empty range
/// when `successor` is neither destination.
ValueRange CondBranchOp::getSuccessorForwardOperands(Block *successor) {
  if (successor == getTrueDest())
    return getTrueOperands();
  if (successor == getFalseDest())
    return getFalseOperands();
  return {};
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SwitchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -1048,16 +1034,6 @@ void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
.add<SimplifyUniformBlockArguments>(context);
|
||||
}
|
||||
|
||||
/// Returns the operands this switch forwards to `successor`. The default
/// destination takes precedence; otherwise the operands of the first case
/// whose destination is `successor` are returned. Yields an empty range when
/// `successor` is not a destination of this op.
ValueRange SwitchOp::getSuccessorForwardOperands(Block *successor) {
  if (successor == getDefaultDestination())
    return getDefaultOperands();

  // Scan the case destinations in order and stop at the first match.
  SuccessorRange caseDests = getCaseDestinations();
  for (unsigned caseIdx = 0, numCases = caseDests.size(); caseIdx != numCases;
       ++caseIdx) {
    if (caseDests[caseIdx] == successor)
      return getCaseOperands(caseIdx);
  }
  return {};
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -9,7 +9,6 @@ add_mlir_library(MLIRReduce
|
||||
MLIRPass
|
||||
MLIRRewrite
|
||||
MLIRTransformUtils
|
||||
MLIRControlFlowDialect
|
||||
|
||||
DEPENDS
|
||||
MLIRReducerIncGen
|
||||
|
||||
@@ -45,16 +45,6 @@ LogicalResult ReductionNode::initialize(ModuleOp parentModule,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Clones `parentModule` into this node, recording the original-to-clone
/// correspondence in `mapper`, and points `region` at the clone of
/// `targetRegion`. Always returns success.
///
/// NOTE(review): dereferences the first block of `targetRegion`, so this
/// assumes the region is non-empty -- confirm at call sites.
LogicalResult ReductionNode::initialize(ModuleOp parentModule,
                                        Region &targetRegion,
                                        IRMapping &mapper) {
  module = cast<ModuleOp>(parentModule->clone(mapper));

  // The cloned region is found through its entry block: map the original
  // entry block to its clone and take that block's parent region.
  Block *clonedEntry = mapper.lookup(&targetRegion.front());
  region = clonedEntry->getParent();
  return success();
}
|
||||
|
||||
/// If we haven't explored any variants from this node, we will create N
|
||||
/// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
|
||||
/// max element in `ranges` and create 2 new variants for each call.
|
||||
|
||||
@@ -14,10 +14,7 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/IR/DialectInterface.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Reducer/Passes.h"
|
||||
#include "mlir/Reducer/ReductionNode.h"
|
||||
#include "mlir/Reducer/ReductionPatternInterface.h"
|
||||
@@ -27,9 +24,6 @@
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/Support/Allocator.h"
|
||||
#include "llvm/Support/DebugLog.h"
|
||||
|
||||
#define DEBUG_TYPE "reduction-tree"
|
||||
|
||||
namespace mlir {
|
||||
#define GEN_PASS_DEF_REDUCTIONTREEPASS
|
||||
@@ -190,112 +184,6 @@ static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region,
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Returns the first branching terminator (cond_br, switch, etc.) found in the
|
||||
// region.
|
||||
static Operation *getBranchTerminatorInRegion(Region ®ion) {
|
||||
for (Block &block : region.getBlocks()) {
|
||||
if (block.getNumSuccessors() > 1)
|
||||
return block.getTerminator();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
/// Reduces the control flow in a region by iteratively forcing branching
|
||||
/// terminators to point to a single successor. It evaluates each potential
|
||||
/// branch path and commits the reduction that results in the smallest
|
||||
/// "interesting" module.
|
||||
static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
|
||||
Region ®ion,
|
||||
const Tester &test) {
|
||||
std::pair<Tester::Interestingness, size_t> initStatus =
|
||||
test.isInteresting(module);
|
||||
|
||||
// While exploring the reduction tree, we always branch from an interesting
|
||||
// node. Thus the root node must be interesting.
|
||||
if (initStatus.first != Tester::Interestingness::True)
|
||||
return module.emitWarning() << "uninterested module will not be reduced";
|
||||
llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
|
||||
|
||||
// We set the simplification level to Aggressive to enable block merging.
|
||||
GreedyRewriteConfig config;
|
||||
config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Aggressive);
|
||||
config.setUseTopDownTraversal(true);
|
||||
|
||||
// Populate canonicalization patterns for cf ops. When all targets of a
|
||||
// 'cf.cond_br' or 'cf.switch' point to the same block, they will be
|
||||
// canonicalized into a 'cf.br'.
|
||||
auto context = region.getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
cf::BranchOp::getCanonicalizationPatterns(patterns, context);
|
||||
cf::CondBranchOp::getCanonicalizationPatterns(patterns, context);
|
||||
cf::SwitchOp::getCanonicalizationPatterns(patterns, context);
|
||||
FrozenRewritePatternSet fPatterns = std::move(patterns);
|
||||
|
||||
ReductionNode *smallestNode = nullptr;
|
||||
mlir::OpBuilder b(context);
|
||||
while (Operation *branchTerminator = getBranchTerminatorInRegion(region)) {
|
||||
size_t numSuccessor = branchTerminator->getNumSuccessors();
|
||||
std::vector<ReductionNode::Range> ranges{
|
||||
{0, std::distance(region.op_begin(), region.op_end())}};
|
||||
// Iterate through each successor of the branching terminator to try
|
||||
// reducing the control flow to a single-path execution.
|
||||
int branchIdx = -1;
|
||||
for (int i = 0, e = numSuccessor; i < e; ++i) {
|
||||
// We allocate memory on the heap because the object will be assigned to
|
||||
// 'smallestNode'.
|
||||
ReductionNode *root = allocator.Allocate();
|
||||
new (root) ReductionNode(nullptr, ranges, allocator);
|
||||
mlir::IRMapping mapper;
|
||||
if (failed(root->initialize(module, region, mapper)))
|
||||
llvm_unreachable("unexpected initialization failure");
|
||||
Operation *tergetTerminator = mapper.lookup(branchTerminator);
|
||||
Block *selectedBlock = tergetTerminator->getSuccessor(i);
|
||||
auto branchOp = cast<BranchOpInterface>(tergetTerminator);
|
||||
ValueRange selectedBlockOperands =
|
||||
branchOp.getSuccessorForwardOperands(selectedBlock);
|
||||
b.setInsertionPointAfter(tergetTerminator);
|
||||
cf::BranchOp::create(b, tergetTerminator->getLoc(), selectedBlock,
|
||||
selectedBlockOperands);
|
||||
tergetTerminator->erase();
|
||||
|
||||
// Apply canonicalization patterns to collapse the now-redundant branches
|
||||
(void)applyPatternsGreedily(root->getRegion().getParentOp(), fPatterns,
|
||||
config);
|
||||
root->update(test.isInteresting(root->getModule()));
|
||||
|
||||
// Track the smallest "interesting" version of the IR found so far.
|
||||
if (root->isInteresting() == Tester::Interestingness::True &&
|
||||
(smallestNode == nullptr ||
|
||||
root->getSize() < smallestNode->getSize())) {
|
||||
smallestNode = root;
|
||||
branchIdx = i;
|
||||
}
|
||||
}
|
||||
|
||||
// If an interesting reduced branch was found, commit the change to the
|
||||
// original region and re-apply patterns for a final cleanup.
|
||||
if (branchIdx != -1) {
|
||||
Block *selectedBlock = branchTerminator->getSuccessor(branchIdx);
|
||||
auto branchOp = cast<BranchOpInterface>(branchTerminator);
|
||||
ValueRange selectedBlockOperands =
|
||||
branchOp.getSuccessorForwardOperands(selectedBlock);
|
||||
b.setInsertionPointAfter(branchTerminator);
|
||||
cf::BranchOp::create(b, branchTerminator->getLoc(), selectedBlock,
|
||||
selectedBlockOperands);
|
||||
branchTerminator->erase();
|
||||
(void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
|
||||
}
|
||||
}
|
||||
|
||||
// If no branching terminators were found (skipping the while loop),
|
||||
// there might still be opportunities for linear block merging or
|
||||
// We apply patterns here as a final cleanup to ensure the region is fully
|
||||
// simplified.
|
||||
if (smallestNode == nullptr)
|
||||
(void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename IteratorType>
|
||||
static LogicalResult findOptimal(ModuleOp module, Region &region,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
@@ -308,8 +196,6 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
|
||||
if (succeeded(eraseAllOpsInRegion(module, region, test)))
|
||||
return success();
|
||||
|
||||
(void)eraseRedundantBlocksInRegion(module, region, test);
|
||||
|
||||
// In the second phase, we don't apply any patterns so that we only select the
|
||||
// range of operations to keep such that the module stays interesting.
|
||||
if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
|
||||
|
||||
@@ -58,68 +58,3 @@ func.func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
func.func @simple5() {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @br_reduction
|
||||
// CHECK-SAME: %[[ARG0:.*]]: i1,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) {
|
||||
func.func @br_reduction(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||
^bb2:
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
cf.br ^bb3(%0 : memref<2xf32>)
|
||||
^bb3(%1: memref<2xf32>):
|
||||
"test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @br_reduction_loop
|
||||
// CHECK-SAME: %[[ARG0:.*]]: i1,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) {
|
||||
func.func @br_reduction_loop(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
// select ^bb2
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||
^bb2:
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
cf.br ^bb3(%0 : memref<2xf32>)
|
||||
^bb3(%1: memref<2xf32>):
|
||||
"test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
// select ^bb4
|
||||
cf.cond_br %arg0, ^bb3(%1: memref<2xf32>), ^bb4
|
||||
^bb4:
|
||||
return
|
||||
}
|
||||
// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @switch_reduction
|
||||
// CHECK-SAME: %[[ARG0:.*]]: i32,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) {
|
||||
func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
|
||||
cf.switch %arg0 : i32, [
|
||||
default: ^bb3(%arg1 : memref<2xf32>),
|
||||
0: ^bb1,
|
||||
1: ^bb2
|
||||
]
|
||||
^bb1:
|
||||
cf.br ^bb3(%arg1 : memref<2xf32>)
|
||||
^bb2:
|
||||
%0 = memref.alloc() : memref<2xf32>
|
||||
cf.br ^bb3(%0 : memref<2xf32>)
|
||||
^bb3(%1: memref<2xf32>):
|
||||
"test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
|
||||
|
||||
Reference in New Issue
Block a user