//===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include #include #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/Support/DebugLog.h" using namespace mlir; //===----------------------------------------------------------------------===// // ControlFlowInterfaces //===----------------------------------------------------------------------===// #include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc" SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands) : producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) { } SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount, MutableOperandRange forwardedOperands) : producedOperandCount(producedOperandCount), forwardedOperands(std::move(forwardedOperands)) {} //===----------------------------------------------------------------------===// // BranchOpInterface //===----------------------------------------------------------------------===// /// Returns the `BlockArgument` corresponding to operand `operandIndex` in some /// successor if 'operandIndex' is within the range of 'operands', or /// std::nullopt if `operandIndex` isn't a successor operand index. 
std::optional detail::getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor) { LDBG() << "Getting branch successor argument for operand index " << operandIndex << " in successor block"; OperandRange forwardedOperands = operands.getForwardedOperands(); // Check that the operands are valid. if (forwardedOperands.empty()) { LDBG() << "No forwarded operands, returning nullopt"; return std::nullopt; } // Check to ensure that this operand is within the range. unsigned operandsStart = forwardedOperands.getBeginOperandIndex(); if (operandIndex < operandsStart || operandIndex >= (operandsStart + forwardedOperands.size())) { LDBG() << "Operand index " << operandIndex << " out of range [" << operandsStart << ", " << (operandsStart + forwardedOperands.size()) << "), returning nullopt"; return std::nullopt; } // Index the successor. unsigned argIndex = operands.getProducedOperandCount() + operandIndex - operandsStart; LDBG() << "Computed argument index " << argIndex << " for successor block"; return successor->getArgument(argIndex); } /// Verify that the given operands match those of the given successor block. LogicalResult detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands) { LDBG() << "Verifying branch successor operands for successor #" << succNo << " in operation " << op->getName(); // Check the count. unsigned operandCount = operands.size(); Block *destBB = op->getSuccessor(succNo); LDBG() << "Branch has " << operandCount << " operands, target block has " << destBB->getNumArguments() << " arguments"; if (operandCount != destBB->getNumArguments()) return op->emitError() << "branch has " << operandCount << " operands for successor #" << succNo << ", but target block has " << destBB->getNumArguments(); // Check the types. 
LDBG() << "Checking type compatibility for " << (operandCount - operands.getProducedOperandCount()) << " forwarded operands"; for (unsigned i = operands.getProducedOperandCount(); i != operandCount; ++i) { Type operandType = operands[i].getType(); Type argType = destBB->getArgument(i).getType(); LDBG() << "Checking type compatibility: operand type " << operandType << " vs argument type " << argType; if (!cast(op).areTypesCompatible(operandType, argType)) return op->emitError() << "type mismatch for bb argument #" << i << " of successor #" << succNo; } LDBG() << "Branch successor operand verification successful"; return success(); } //===----------------------------------------------------------------------===// // WeightedBranchOpInterface //===----------------------------------------------------------------------===// static LogicalResult verifyWeights(Operation *op, llvm::ArrayRef weights, std::size_t expectedWeightsNum, llvm::StringRef weightAnchorName, llvm::StringRef weightRefName) { if (weights.empty()) return success(); if (weights.size() != expectedWeightsNum) return op->emitError() << "expects number of " << weightAnchorName << " weights to match number of " << weightRefName << ": " << weights.size() << " vs " << expectedWeightsNum; if (llvm::all_of(weights, [](int32_t value) { return value == 0; })) return op->emitError() << "branch weights cannot all be zero"; return success(); } LogicalResult detail::verifyBranchWeights(Operation *op) { llvm::ArrayRef weights = cast(op).getWeights(); return verifyWeights(op, weights, op->getNumSuccessors(), "branch", "successors"); } //===----------------------------------------------------------------------===// // WeightedRegionBranchOpInterface //===----------------------------------------------------------------------===// LogicalResult detail::verifyRegionBranchWeights(Operation *op) { llvm::ArrayRef weights = cast(op).getWeights(); return verifyWeights(op, weights, op->getNumRegions(), "region", "regions"); } 
//===----------------------------------------------------------------------===// // RegionBranchOpInterface //===----------------------------------------------------------------------===// /// Verify that types match along control flow edges described the given op. LogicalResult detail::verifyRegionBranchOpInterface(Operation *op) { auto regionInterface = cast(op); // Verify all control flow edges from region branch points to region // successors. SmallVector regionBranchPoints = regionInterface.getAllRegionBranchPoints(); for (const RegionBranchPoint &branchPoint : regionBranchPoints) { SmallVector successors; regionInterface.getSuccessorRegions(branchPoint, successors); for (const RegionSuccessor &successor : successors) { // Helper function that print the region branch point and the region // successor. auto emitRegionEdgeError = [&]() { InFlightDiagnostic diag = regionInterface->emitOpError("along control flow edge from "); if (branchPoint.isParent()) { diag << "parent"; diag.attachNote(op->getLoc()) << "region branch point"; } else { diag << "Operation " << branchPoint.getTerminatorPredecessorOrNull()->getName(); diag.attachNote( branchPoint.getTerminatorPredecessorOrNull()->getLoc()) << "region branch point"; } diag << " to "; if (Region *region = successor.getSuccessor()) { diag << "Region #" << region->getRegionNumber(); } else { diag << "parent"; } return diag; }; // Verify number of successor operands and successor inputs. OperandRange succOperands = regionInterface.getSuccessorOperands(branchPoint, successor); ValueRange succInputs = regionInterface.getSuccessorInputs(successor); if (succOperands.size() != succInputs.size()) { return emitRegionEdgeError() << ": region branch point has " << succOperands.size() << " operands, but region successor needs " << succInputs.size() << " inputs"; } // Verify that the types are compatible. 
TypeRange succInputTypes = succInputs.getTypes(); TypeRange succOperandTypes = succOperands.getTypes(); for (const auto &typesIdx : llvm::enumerate(llvm::zip(succOperandTypes, succInputTypes))) { Type succOperandType = std::get<0>(typesIdx.value()); Type succInputType = std::get<1>(typesIdx.value()); if (!regionInterface.areTypesCompatible(succOperandType, succInputType)) return emitRegionEdgeError() << ": successor operand type #" << typesIdx.index() << " " << succOperandType << " should match successor input type #" << typesIdx.index() << " " << succInputType; } } } return success(); } /// Stop condition for `traverseRegionGraph`. The traversal is interrupted if /// this function returns "true" for a successor region. The first parameter is /// the successor region. The second parameter indicates all already visited /// regions. using StopConditionFn = function_ref visited)>; /// Traverse the region graph starting at `begin`. The traversal is interrupted /// if `stopCondition` evaluates to "true" for a successor region. In that case, /// this function returns "true". Otherwise, if the traversal was not /// interrupted, this function returns "false". static bool traverseRegionGraph(Region *begin, StopConditionFn stopConditionFn) { auto op = cast(begin->getParentOp()); LDBG() << "Starting region graph traversal from region #" << begin->getRegionNumber() << " in operation " << op->getName(); SmallVector visited(op->getNumRegions(), false); visited[begin->getRegionNumber()] = true; LDBG() << "Initialized visited array with " << op->getNumRegions() << " regions"; // Retrieve all successors of the region and enqueue them in the worklist. 
SmallVector worklist; auto enqueueAllSuccessors = [&](Region *region) { LDBG() << "Enqueuing successors for region #" << region->getRegionNumber(); SmallVector operandAttributes(op->getNumOperands()); for (Block &block : *region) { if (block.empty()) continue; auto terminator = dyn_cast(block.back()); if (!terminator) continue; SmallVector successors; operandAttributes.resize(terminator->getNumOperands()); terminator.getSuccessorRegions(operandAttributes, successors); LDBG() << "Found " << successors.size() << " successors from terminator in block"; for (RegionSuccessor successor : successors) { if (!successor.isParent()) { worklist.push_back(successor.getSuccessor()); LDBG() << "Added region #" << successor.getSuccessor()->getRegionNumber() << " to worklist"; } else { LDBG() << "Skipping parent successor"; } } } }; enqueueAllSuccessors(begin); LDBG() << "Initial worklist size: " << worklist.size(); // Process all regions in the worklist via DFS. while (!worklist.empty()) { Region *nextRegion = worklist.pop_back_val(); LDBG() << "Processing region #" << nextRegion->getRegionNumber() << " from worklist (remaining: " << worklist.size() << ")"; if (stopConditionFn(nextRegion, visited)) { LDBG() << "Stop condition met for region #" << nextRegion->getRegionNumber() << ", returning true"; return true; } if (!nextRegion->getParentOp()) { llvm::errs() << "Region " << *nextRegion << " has no parent op\n"; return false; } if (visited[nextRegion->getRegionNumber()]) { LDBG() << "Region #" << nextRegion->getRegionNumber() << " already visited, skipping"; continue; } visited[nextRegion->getRegionNumber()] = true; LDBG() << "Marking region #" << nextRegion->getRegionNumber() << " as visited"; enqueueAllSuccessors(nextRegion); } LDBG() << "Traversal completed, returning false"; return false; } /// Return `true` if region `r` is reachable from region `begin` according to /// the RegionBranchOpInterface (by taking a branch). 
static bool isRegionReachable(Region *begin, Region *r) { assert(begin->getParentOp() == r->getParentOp() && "expected that both regions belong to the same op"); return traverseRegionGraph(begin, [&](Region *nextRegion, ArrayRef visited) { // Interrupt traversal if `r` was reached. return nextRegion == r; }); } /// Return `true` if `a` and `b` are in mutually exclusive regions. /// /// 1. Find the first common of `a` and `b` (ancestor) that implements /// RegionBranchOpInterface. /// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are /// contained. /// 3. Check if `regionA` and `regionB` are mutually exclusive. They are /// mutually exclusive if they are not reachable from each other as per /// RegionBranchOpInterface::getSuccessorRegions. bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { LDBG() << "Checking if operations are in mutually exclusive regions: " << a->getName() << " and " << b->getName(); assert(a && "expected non-empty operation"); assert(b && "expected non-empty operation"); auto branchOp = a->getParentOfType(); while (branchOp) { LDBG() << "Checking branch operation " << branchOp->getName(); // Check if b is inside branchOp. (We already know that a is.) if (!branchOp->isProperAncestor(b)) { LDBG() << "Operation b is not inside branchOp, checking next ancestor"; // Check next enclosing RegionBranchOpInterface. branchOp = branchOp->getParentOfType(); continue; } LDBG() << "Both operations are inside branchOp, finding their regions"; // b is contained in branchOp. Retrieve the regions in which `a` and `b` // are contained. 
Region *regionA = nullptr, *regionB = nullptr; for (Region &r : branchOp->getRegions()) { if (r.findAncestorOpInRegion(*a)) { assert(!regionA && "already found a region for a"); regionA = &r; LDBG() << "Found region #" << r.getRegionNumber() << " for operation a"; } if (r.findAncestorOpInRegion(*b)) { assert(!regionB && "already found a region for b"); regionB = &r; LDBG() << "Found region #" << r.getRegionNumber() << " for operation b"; } } assert(regionA && regionB && "could not find region of op"); LDBG() << "Region A: #" << regionA->getRegionNumber() << ", Region B: #" << regionB->getRegionNumber(); // `a` and `b` are in mutually exclusive regions if both regions are // distinct and neither region is reachable from the other region. bool regionsAreDistinct = (regionA != regionB); bool aNotReachableFromB = !isRegionReachable(regionA, regionB); bool bNotReachableFromA = !isRegionReachable(regionB, regionA); LDBG() << "Regions distinct: " << regionsAreDistinct << ", A not reachable from B: " << aNotReachableFromB << ", B not reachable from A: " << bNotReachableFromA; bool mutuallyExclusive = regionsAreDistinct && aNotReachableFromB && bNotReachableFromA; LDBG() << "Operations are mutually exclusive: " << mutuallyExclusive; return mutuallyExclusive; } // Could not find a common RegionBranchOpInterface among a's and b's // ancestors. 
LDBG() << "No common RegionBranchOpInterface found, operations are not " "mutually exclusive"; return false; } bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) { LDBG() << "Checking if region #" << index << " is repetitive in operation " << getOperation()->getName(); Region *region = &getOperation()->getRegion(index); bool isRepetitive = isRegionReachable(region, region); LDBG() << "Region #" << index << " is repetitive: " << isRepetitive; return isRepetitive; } bool RegionBranchOpInterface::hasLoop() { LDBG() << "Checking if operation " << getOperation()->getName() << " has loops"; SmallVector entryRegions; getSuccessorRegions(RegionBranchPoint::parent(), entryRegions); LDBG() << "Found " << entryRegions.size() << " entry regions"; for (RegionSuccessor successor : entryRegions) { if (!successor.isParent()) { LDBG() << "Checking entry region #" << successor.getSuccessor()->getRegionNumber() << " for loops"; bool hasLoop = traverseRegionGraph(successor.getSuccessor(), [](Region *nextRegion, ArrayRef visited) { // Interrupt traversal if the region was already // visited. return visited[nextRegion->getRegionNumber()]; }); if (hasLoop) { LDBG() << "Found loop in entry region #" << successor.getSuccessor()->getRegionNumber(); return true; } } else { LDBG() << "Skipping parent successor"; } } LDBG() << "No loops found in operation"; return false; } OperandRange RegionBranchOpInterface::getSuccessorOperands(RegionBranchPoint src, RegionSuccessor dest) { if (src.isParent()) return getEntrySuccessorOperands(dest); return src.getTerminatorPredecessorOrNull().getSuccessorOperands(dest); } SmallVector RegionBranchOpInterface::getNonSuccessorInputs(RegionSuccessor successor) { SmallVector results = llvm::to_vector( successor.isParent() ? 
ValueRange(getOperation()->getResults()) : ValueRange(successor.getSuccessor()->getArguments())); ValueRange successorInputs = getSuccessorInputs(successor); if (!successorInputs.empty()) { unsigned inputBegin = successor.isParent() ? cast(successorInputs.front()).getResultNumber() : cast(successorInputs.front()).getArgNumber(); results.erase(results.begin() + inputBegin, results.begin() + inputBegin + successorInputs.size()); } return results; } static MutableArrayRef operandsToOpOperands(OperandRange &operands) { return MutableArrayRef(operands.getBase(), operands.size()); } static void getSuccessorOperandInputMapping(RegionBranchOpInterface branchOp, RegionBranchSuccessorMapping &mapping, RegionBranchPoint src) { SmallVector successors; branchOp.getSuccessorRegions(src, successors); for (RegionSuccessor dst : successors) { OperandRange operands = branchOp.getSuccessorOperands(src, dst); assert(operands.size() == branchOp.getSuccessorInputs(dst).size() && "expected the same number of operands and inputs"); for (const auto &[operand, input] : llvm::zip_equal( operandsToOpOperands(operands), branchOp.getSuccessorInputs(dst))) mapping[&operand].push_back(input); } } void RegionBranchOpInterface::getSuccessorOperandInputMapping( RegionBranchSuccessorMapping &mapping, std::optional src) { if (src.has_value()) { ::getSuccessorOperandInputMapping(*this, mapping, src.value()); } else { // No region branch point specified: populate the mapping for all possible // region branch points. 
for (RegionBranchPoint branchPoint : getAllRegionBranchPoints()) ::getSuccessorOperandInputMapping(*this, mapping, branchPoint); } } static RegionBranchInverseSuccessorMapping invertRegionBranchSuccessorMapping( const RegionBranchSuccessorMapping &operandToInputs) { RegionBranchInverseSuccessorMapping inputToOperands; for (const auto &[operand, inputs] : operandToInputs) { for (Value input : inputs) inputToOperands[input].push_back(operand); } return inputToOperands; } void RegionBranchOpInterface::getSuccessorInputOperandMapping( RegionBranchInverseSuccessorMapping &mapping) { RegionBranchSuccessorMapping operandToInputs; getSuccessorOperandInputMapping(operandToInputs); mapping = invertRegionBranchSuccessorMapping(operandToInputs); } SmallVector RegionBranchOpInterface::getAllRegionBranchPoints() { SmallVector branchPoints; branchPoints.push_back(RegionBranchPoint::parent()); for (Region ®ion : getOperation()->getRegions()) { for (Block &block : region) { if (block.empty()) continue; if (auto terminator = dyn_cast(block.back())) branchPoints.push_back(RegionBranchPoint(terminator)); } } return branchPoints; } Region *mlir::getEnclosingRepetitiveRegion(Operation *op) { LDBG() << "Finding enclosing repetitive region for operation " << op->getName(); while (Region *region = op->getParentRegion()) { LDBG() << "Checking region #" << region->getRegionNumber() << " in operation " << region->getParentOp()->getName(); op = region->getParentOp(); if (auto branchOp = dyn_cast(op)) { LDBG() << "Found RegionBranchOpInterface, checking if region is repetitive"; if (branchOp.isRepetitiveRegion(region->getRegionNumber())) { LDBG() << "Found repetitive region #" << region->getRegionNumber(); return region; } } else { LDBG() << "Parent operation does not implement RegionBranchOpInterface"; } } LDBG() << "No enclosing repetitive region found"; return nullptr; } Region *mlir::getEnclosingRepetitiveRegion(Value value) { LDBG() << "Finding enclosing repetitive region for value"; 
Region *region = value.getParentRegion(); while (region) { LDBG() << "Checking region #" << region->getRegionNumber() << " in operation " << region->getParentOp()->getName(); Operation *op = region->getParentOp(); if (auto branchOp = dyn_cast(op)) { LDBG() << "Found RegionBranchOpInterface, checking if region is repetitive"; if (branchOp.isRepetitiveRegion(region->getRegionNumber())) { LDBG() << "Found repetitive region #" << region->getRegionNumber(); return region; } } else { LDBG() << "Parent operation does not implement RegionBranchOpInterface"; } region = op->getParentRegion(); } LDBG() << "No enclosing repetitive region found for value"; return nullptr; } /// Return "true" if `a` can be used in lieu of `b`, where `b` is a region /// successor input and `a` is a "reachable value" of `b`. Reachable values /// are successor operand values that are (maybe transitively) forwarded to /// `b`. static bool isDefinedBefore(Operation *regionBranchOp, Value a, Value b) { assert((b.getDefiningOp() == regionBranchOp || b.getParentRegion()->getParentOp() == regionBranchOp) && "b must be a region successor input"); // Case 1: `a` is defined inside of the region branch op. `a` must be // directly nested in the region branch op. Otherwise, it could not have // been among the reachable values for a region successor input. if (a.getParentRegion()->getParentOp() == regionBranchOp) { // Case 1.1: If `b` is a result of the region branch op, `a` is not in // scope for `b`. // Example: // %b = region_op({ // ^bb0(%a1: ...): // %a2 = ... // }) if (isa(b)) return false; // Case 1.2: `b` is an entry block argument of a region. `a` is in scope // for `b` only if it is also an entry block argument of the same region. // Example: // region_op({ // ^bb0(%b: ..., %a: ...): // ... // }) assert(isa(b) && "b must be a block argument"); return isa(a) && cast(a).getOwner() == cast(b).getOwner(); } // Case 2: `a` is defined outside of the region branch op. 
In that case, we // can safely assume that `a` was defined before `b`. Otherwise, it could not // be among the reachable values for a region successor input. // Example: // { <- %a1 parent region begins here. // ^bb0(%a1: ...): // %a2 = ... // %b1 = reigon_op({ // ^bb1(%b2: ...): // ... // }) // } return true; } /// Compute all non-successor-input values that a successor input could have /// based on the given successor input to successor operand mapping. /// /// Starting with the given value, trace back all predecessor values (i.e., /// preceding successor operands) and add them to the set of reachable values. /// If the successor operand is again a successor input, do not add it to the /// result set, but instead continue the traversal. /// /// If `maxReachableValues` is set, the traversal is aborted early and /// `failure` is returned as soon as the number of reachable values exceeds /// the limit. Otherwise, `success` is returned and the result set contains /// all reachable values. /// /// Example 1: /// %r = scf.if ... { /// scf.yield %a : ... /// } else { /// scf.yield %b : ... /// } /// reachableValues(%r) = {%a, %b} /// /// Example 2: /// %r = scf.for ... iter_args(%arg0 = %0) -> ... { /// scf.yield %arg0 : ... /// } /// reachableValues(%arg0) = {%0} /// reachableValues(%r) = {%0} /// /// Example 3: /// %r = scf.for ... iter_args(%arg0 = %0) -> ... { /// ... /// scf.yield %1 : ... 
/// } /// reachableValues(%arg0) = {%0, %1} /// reachableValues(%r) = {%0, %1} static LogicalResult computeReachableValuesFromSuccessorInput( llvm::SmallDenseSet &result, Value value, const RegionBranchInverseSuccessorMapping &inputToOperands, std::optional maxReachableValues = std::nullopt) { assert(inputToOperands.contains(value) && "value must be a successor input"); llvm::SmallDenseSet visited; SmallVector worklist; worklist.push_back(value); while (!worklist.empty()) { Value next = worklist.pop_back_val(); auto it = inputToOperands.find(next); if (it == inputToOperands.end()) { result.insert(next); if (maxReachableValues && result.size() > *maxReachableValues) return failure(); continue; } for (OpOperand *operand : it->second) if (visited.insert(operand->get()).second) worklist.push_back(operand->get()); } // Note: The result does not contain any successor inputs. (Therefore, // `value` is also guaranteed to be excluded.) return success(); } namespace { /// Try to make successor inputs dead by replacing their uses with values that /// are not successor inputs. This pattern enables additional canonicalization /// opportunities for RemoveDeadRegionBranchOpSuccessorInputs. /// /// Example: /// /// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... { /// scf.yield %arg1, %arg1 : ... /// } /// use(%r0, %r1) /// /// reachableValues(%r0) = {%0, %1} /// reachableValues(%r1) = {%1} ==> replace uses of %r1 with %1. /// reachableValues(%arg0) = {%0, %1} /// reachableValues(%arg1) = {%1} ==> replace uses of %arg1 with %1. /// /// IR after pattern application: /// /// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... { /// scf.yield %1, %1 : ... /// } /// use(%r0, %1) /// /// Note that %r1 and %arg1 are dead now. The IR can now be further /// canonicalized by RemoveDeadRegionBranchOpSuccessorInputs. 
struct MakeRegionBranchOpSuccessorInputsDead : public RewritePattern {
  MakeRegionBranchOpSuccessorInputsDead(MLIRContext *context, StringRef name,
                                        PatternBenefit benefit = 1)
      : RewritePattern(name, benefit, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
           "isolated-from-above ops are not supported");

    // Map every successor input to the successor operands forwarded to it.
    auto branchOp = cast<RegionBranchOpInterface>(op);
    RegionBranchInverseSuccessorMapping inputToOperands;
    branchOp.getSuccessorInputOperandMapping(inputToOperands);

    // Handle the successor inputs one at a time.
    bool changed = false;
    for (Value input : inputToOperands.keys()) {
      // Already dead: nothing to replace.
      if (input.use_empty())
        continue;

      // Only inputs with exactly one reachable value are candidates. The
      // traversal bails out as soon as a second value is found.
      llvm::SmallDenseSet<Value> reachable;
      bool hasSingleReachableValue =
          succeeded(computeReachableValuesFromSuccessorInput(
              reachable, input, inputToOperands, /*maxReachableValues=*/1)) &&
          !reachable.empty();
      if (!hasSingleReachableValue)
        continue;

      Value replacement = *reachable.begin();
      assert(replacement != input &&
             "successor inputs are supposed to be excluded");

      // The replacement must be in scope at every use of `input`. For
      // instance, in
      //   %r = scf.execute_region ... {
      //     %a = ...
      //     scf.yield %a : ...
      //   }
      //   use(%r)
      // %a is the only reachable value of %r, but it is not visible at use(%r).
      if (!isDefinedBefore(branchOp, replacement, input))
        continue;

      rewriter.replaceAllUsesWith(input, replacement);
      changed = true;
    }
    return success(changed);
  }
};

/// Lookup a bit vector in the given mapping (DenseMap). If the key was not
/// found, create a new bit vector with the given size and initialize it with
/// false.
template static BitVector &lookupOrCreateBitVector(MappingTy &mapping, KeyTy key, unsigned size) { return mapping.try_emplace(key, size, false).first->second; } /// Compute tied successor inputs. Tied successor inputs are successor inputs /// that come as a set. If you erase one value from a set, you must erase all /// values from the set. Otherwise, the op would become structurally invalid. /// Each successor input appears in exactly one set. /// /// Example: /// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... { /// ... /// } /// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}. static llvm::EquivalenceClasses computeTiedSuccessorInputs( const RegionBranchSuccessorMapping &operandToInputs) { llvm::EquivalenceClasses tiedSuccessorInputs; for (const auto &[operand, inputs] : operandToInputs) { assert(!inputs.empty() && "expected non-empty inputs"); Value firstInput = inputs.front(); tiedSuccessorInputs.insert(firstInput); for (Value nextInput : llvm::drop_begin(inputs)) { // As we explore more successor operand to successor input mappings, // existing sets may get merged. tiedSuccessorInputs.unionSets(firstInput, nextInput); } } return tiedSuccessorInputs; } /// Remove dead successor inputs from region branch ops. A successor input is /// dead if it has no uses. Successor inputs come in sets of tied values: if /// you remove one value from a set, you must remove all values from the set. /// Furthermore, successor operands must also be removed. (Op operands are not /// part of the set, but the set is built based on the successor operand to /// successor input mapping.) /// /// Example 1: /// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... { /// scf.yield %0, %arg1 : ... /// } /// use(%0, %1) /// /// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}. All values in the first /// set are dead, so %arg0 and %r0 can be removed, but not %r1 and %arg1. The /// resulting IR is as follows: /// /// %r1 = scf.for ... iter_args(%arg1 = %1) -> ... 
{ /// scf.yield %arg1 : ... /// } /// use(%0, %1) /// /// Example 2: /// %r0, %r1 = scf.while (%arg0 = %0) { /// scf.condition(...) %arg0, %arg0 : ... /// } do { /// ^bb0(%arg1: ..., %arg2: ...): /// scf.yield %arg1 : ... /// } /// There are three sets: {{%r0, %arg1}, {%r1, %arg2}, {%r0}}. /// /// Example 3: /// %r1, %r2 = scf.if ... { /// scf.yield %0, %1 : ... /// } else { /// scf.yield %2, %3 : ... /// } /// There are two sets: {{%r1}, {%r2}}. Each set has one value, so there each /// value can be removed independently of the other values. struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern { RemoveDeadRegionBranchOpSuccessorInputs(MLIRContext *context, StringRef name, PatternBenefit benefit = 1) : RewritePattern(name, benefit, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { assert(!op->hasTrait() && "isolated-from-above ops are not supported"); // Compute tied values: values that must come as a set. If you remove one, // you must remove all. If a successor op operand is forwarded to two // successor inputs %a and %b, both %a and %b are in the same set. auto regionBranchOp = cast(op); RegionBranchSuccessorMapping operandToInputs; regionBranchOp.getSuccessorOperandInputMapping(operandToInputs); llvm::EquivalenceClasses tiedSuccessorInputs = computeTiedSuccessorInputs(operandToInputs); // Determine which values to remove and group them by block and operation. SmallVector valuesToRemove; DenseMap blockArgsToRemove; BitVector resultsToRemove(regionBranchOp->getNumResults(), false); // Iterate over all sets of tied successor inputs. for (auto it = tiedSuccessorInputs.begin(), e = tiedSuccessorInputs.end(); it != e; ++it) { if (!(*it)->isLeader()) continue; // Value can be removed if it is dead and all other tied values are also // dead. 
bool allDead = true; for (auto memberIt = tiedSuccessorInputs.member_begin(**it); memberIt != tiedSuccessorInputs.member_end(); ++memberIt) { // Iterate over all values in the set and check their liveness. if (!memberIt->use_empty()) { allDead = false; break; } } if (!allDead) continue; // The entire set is dead. Group values by block and operation to // simplify removal. for (auto memberIt = tiedSuccessorInputs.member_begin(**it); memberIt != tiedSuccessorInputs.member_end(); ++memberIt) { if (auto arg = dyn_cast(*memberIt)) { // Set blockArgsToRemove[block][arg_number] = true. BitVector &vector = lookupOrCreateBitVector(blockArgsToRemove, arg.getOwner(), arg.getOwner()->getNumArguments()); vector.set(arg.getArgNumber()); } else { // Set resultsToRemove[result_number] = true. OpResult result = cast(*memberIt); assert(result.getDefiningOp() == regionBranchOp && "result must be a region branch op result"); resultsToRemove.set(result.getResultNumber()); } valuesToRemove.push_back(*memberIt); } } if (valuesToRemove.empty()) return rewriter.notifyMatchFailure(op, "no values to remove"); // Find operands that must be removed together with the values. RegionBranchInverseSuccessorMapping inputsToOperands = invertRegionBranchSuccessorMapping(operandToInputs); DenseMap operandsToRemove; for (Value value : valuesToRemove) { for (OpOperand *operand : inputsToOperands[value]) { // Set operandsToRemove[op][operand_number] = true. BitVector &vector = lookupOrCreateBitVector(operandsToRemove, operand->getOwner(), operand->getOwner()->getNumOperands()); vector.set(operand->getOperandNumber()); } } // Erase operands. for (auto &pair : operandsToRemove) { Operation *op = pair.first; BitVector &operands = pair.second; rewriter.modifyOpInPlace(op, [&]() { op->eraseOperands(operands); }); } // Erase block arguments. 
for (auto &pair : blockArgsToRemove) { Block *block = pair.first; BitVector &blockArg = pair.second; rewriter.modifyOpInPlace(block->getParentOp(), [&]() { block->eraseArguments(blockArg); }); } // Erase op results. if (resultsToRemove.any()) rewriter.eraseOpResults(regionBranchOp, resultsToRemove); return success(); } }; /// Return the "owner" of a value: the parent block for block arguments, the /// defining op for op results. static void *getOwnerOfValue(Value value) { if (auto arg = dyn_cast(value)) return arg.getOwner(); return value.getDefiningOp(); } /// Get the block argument or op result number of the given value. static unsigned getArgOrResultNumber(Value value) { if (auto opResult = llvm::dyn_cast(value)) return opResult.getResultNumber(); return llvm::cast(value).getArgNumber(); } /// Find duplicate successor inputs and make all dead except for one. Two /// successor inputs are "duplicate" if their corresponding successor operands /// have the same values. This pattern enables additional canonicalization /// opportunities for RemoveDeadRegionBranchOpSuccessorInputs. /// /// Example: /// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %0) -> ... { /// use(%arg0, %arg1) /// ... /// scf.yield %x, %x : ... /// } /// use(%r0, %r1) /// /// Operands of successor input %r0: [%0, %x] /// Operands of successor input %r1: [%0, %x] ==> DUPLICATE! /// Replace %r1 with %r0. /// /// Operands of successor input %arg0: [%0, %x] /// Operands of successor input %arg1: [%0, %x] ==> DUPLICATE! /// Replace %arg1 with %arg0. (We have to make sure that we make same decision /// as for the other tied successor inputs above. Otherwise, a set of tied /// successor inputs may not become entirely dead.) /// /// The resulting IR is as follows: /// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %0) -> ... { /// use(%arg0, %arg0) /// ... /// scf.yield %x, %x : ... 
/// } /// use(%r0, %r0) // Note: We don't want use(%r1, %r1), which is also correct, /// // but does not help with further canonicalizations. struct RemoveDuplicateSuccessorInputUses : public RewritePattern { RemoveDuplicateSuccessorInputUses(MLIRContext *context, StringRef name, PatternBenefit benefit = 1) : RewritePattern(name, benefit, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { assert(!op->hasTrait() && "isolated-from-above ops are not supported"); // Collect all successor inputs and sort them. When dropping the uses of a // successor input, we'd like to also drop the uses of the same tied // successor inputs. Otherwise, a set of tied successor inputs may not // become entirely dead, which is required for // RemoveDeadRegionBranchOpSuccessorInputs to be able to erase them. // (Sorting is not required for correctness.) auto regionBranchOp = cast(op); RegionBranchInverseSuccessorMapping inputsToOperands; regionBranchOp.getSuccessorInputOperandMapping(inputsToOperands); SmallVector inputs = llvm::to_vector(inputsToOperands.keys()); llvm::sort(inputs, [](Value a, Value b) { return getArgOrResultNumber(a) < getArgOrResultNumber(b); }); // Group inputs by their operand "signature" to find duplicates. Two // successor inputs are duplicates if each predecessor (region branch point) // forwards the same value for both. Let n = number of successor inputs and // k = number of predecessors per input. Instead of comparing every pair of // inputs (O(n² * k)), we build a signature for each input and group them // via a std::map. // // A signature is a sorted list of (predecessor, forwarded value) pairs. // Within each group, all but the first (canonical) input are replaced with // the canonical one. 
using SigEntry = std::pair; using Signature = SmallVector; auto sigEntryLess = [](const SigEntry &a, const SigEntry &b) { if (a.first != b.first) return a.first < b.first; return a.second.getAsOpaquePointer() < b.second.getAsOpaquePointer(); }; // The map key is (signature, owner). Two inputs are duplicates only if they // have the same signature AND the same owner (block or defining op). This // ensures we track one canonical per owner group. using MapKey = std::pair; auto mapKeyLess = [&](const MapKey &a, const MapKey &b) { if (a.second != b.second) return a.second < b.second; return std::lexicographical_compare(a.first.begin(), a.first.end(), b.first.begin(), b.first.end(), sigEntryLess); }; std::map signatureToCanonical( mapKeyLess); bool changed = false; // Total complexity: O(n * k * max(log k, log n)). For each input, sorting // the signature costs O(k log k) and the std::map lookup costs O(k log n). for (Value input : inputs) { // Gather the predecessor value for each predecessor (region branch // point) and sort them to form this input's signature. Signature sig; for (OpOperand *operand : inputsToOperands[input]) sig.emplace_back(operand->getOwner(), operand->get()); llvm::sort(sig, sigEntryLess); void *owner = getOwnerOfValue(input); auto [it, inserted] = signatureToCanonical.try_emplace( MapKey{std::move(sig), owner}, input); if (!inserted) { Value canonical = it->second; // Nothing to do if input is already dead. if (input.use_empty()) continue; rewriter.replaceAllUsesWith(input, canonical); changed = true; } } return success(changed); } }; /// Given a range of values, return a vector of attributes of the same size, /// where the i-th attribute is the constant value of the i-th value. If a /// value is not constant, the corresponding attribute is null. 
static SmallVector<Attribute> extractConstants(ValueRange values) {
  return llvm::map_to_vector(values, [](Value value) {
    Attribute attr;
    matchPattern(value, m_Constant(&attr));
    return attr;
  });
}

/// Return all successor regions when branching from the given region branch
/// point. This helper function extracts the constant values of the branching
/// op's operands (null for non-constant operands) and passes them to the
/// interface that computes the successors. For "parent" that is
/// `RegionBranchOpInterface::getEntrySuccessorRegions`. Otherwise it is the
/// terminator's `getSuccessorRegions`.
static SmallVector<RegionSuccessor>
getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
                             RegionBranchPoint point) {
  SmallVector<RegionSuccessor> successors;
  if (point.isParent()) {
    op.getEntrySuccessorRegions(extractConstants(op->getOperands()),
                                successors);
    return successors;
  }
  RegionBranchTerminatorOpInterface terminator =
      point.getTerminatorPredecessorOrNull();
  assert(terminator && "expected a terminator for a non-parent branch point");
  terminator.getSuccessorRegions(extractConstants(terminator->getOperands()),
                                 successors);
  return successors;
}

/// Find the single acyclic path through the given region branch op. Return an
/// empty vector if no such path or multiple such paths exist.
///
/// Example: "scf.if %true" has a single path: parent => then_region => parent
///
/// Example: "scf.if ???" has multiple paths:
///   (1) parent => then_region => parent
///   (2) parent => else_region => parent
///
/// Example: "scf.while with scf.condition(%false)" has a single path:
///   parent => before_region => parent
///
/// Example: "scf.for with 0 iterations" has a single path: parent => parent
///
/// Note: Each path starts and ends with "parent". The "parent" at the beginning
/// of the path is omitted from the result.
///
/// Note: This function also returns an "empty" path in these cases:
/// - a region on the path does not have exactly one block;
/// - a region would be entered twice (a cycle);
/// - a region's terminator does not implement
///   `RegionBranchTerminatorOpInterface`.
static SmallVector<RegionSuccessor>
computeSingleAcyclicRegionBranchPath(RegionBranchOpInterface op) {
  llvm::SmallDenseSet<Region *> visited;
  SmallVector<RegionSuccessor> path;

  // Path starts with "parent".
  RegionBranchPoint next = RegionBranchPoint::parent();
  do {
    SmallVector<RegionSuccessor> successors =
        getSuccessorRegionsWithAttrs(op, next);
    if (successors.size() != 1) {
      // Zero or multiple region successors: there is no single path through
      // the region branch op.
      return {};
    }
    path.push_back(successors.front());
    if (successors.front().isParent()) {
      // Found path that ends with "parent".
      return path;
    }

    Region *region = successors.front().getSuccessor();
    if (!region->hasOneBlock()) {
      // Entering a region with multiple blocks. Such regions are not supported
      // at the moment.
      return {};
    }
    if (!visited.insert(region).second) {
      // We have already visited this region. I.e., we have found a cycle.
      return {};
    }
    auto terminator =
        dyn_cast<RegionBranchTerminatorOpInterface>(&region->front().back());
    if (!terminator) {
      // Region has no RegionBranchTerminatorOpInterface terminator. E.g., the
      // terminator could be a "ub.unreachable" op. Such IR is not supported.
      return {};
    }
    next = RegionBranchPoint(terminator);
  } while (true);

  llvm_unreachable("expected to return from loop");
}

/// Inline the body of the matched region branch op into the enclosing block if
/// there is exactly one acyclic path through the region branch op, starting
/// from "parent", and if that path ends with "parent".
///
/// Example: This pattern can inline "scf.for" operations that are guaranteed to
/// have a single iteration, as indicated by the region branch path "parent =>
/// region => parent". "scf.for" operations have a non-successor-input: the loop
/// induction variable. Non-successor-input values have op-specific semantics
/// and cannot be reasoned about through the `RegionBranchOpInterface`. A
/// replacement value for non-successor-inputs is injected by the user-specified
/// lambda: in the case of the loop induction variable of an "scf.for", the
/// lower bound of the loop is used as a replacement value.
///
/// Before pattern application:
///   %r = scf.for %iv = %c5 to %c6 step %c1 iter_args(%arg0 = %0) {
///     %1 = "producer"(%arg0, %iv)
///     scf.yield %1
///   }
///   "user"(%r)
///
/// After pattern application:
///   %1 = "producer"(%0, %c5)
///   "user"(%1)
///
/// This pattern is limited to the following cases:
/// - Only regions with a single block are supported. This could be generalized.
/// - Region branch ops with side effects are not supported. (Recursive side
///   effects are fine.)
///
/// Note: This pattern queries the region dataflow from the
/// `RegionBranchOpInterface`. Replacement values for block arguments / op
/// results are determined based on region dataflow. In case of
/// non-successor-inputs (whose values are not modeled by the
/// `RegionBranchOpInterface`), a user-specified lambda is queried.
struct InlineRegionBranchOp : public RewritePattern {
  InlineRegionBranchOp(MLIRContext *context, StringRef name,
                       NonSuccessorInputReplacementBuilderFn replBuilderFn,
                       PatternMatcherFn matcherFn, PatternBenefit benefit = 1)
      : RewritePattern(name, benefit, context), replBuilderFn(replBuilderFn),
        matcherFn(matcherFn) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // Check if the pattern is applicable to the given operation.
    if (failed(matcherFn(op)))
      return rewriter.notifyMatchFailure(op, "pattern not applicable");

    // Ops without recursive memory effects could have side effects of their
    // own, so it is not safe to fold such ops away.
    if (!op->hasTrait())
      return rewriter.notifyMatchFailure(
          op, "pattern not applicable to ops without recursive memory effects");

    // Find the single acyclic path through the region branch op.
    auto regionBranchOp = cast(op);
    SmallVector path = computeSingleAcyclicRegionBranchPath(regionBranchOp);
    if (path.empty())
      return rewriter.notifyMatchFailure(
          op, "failed to find acyclic region branch path");

    // Inline all regions on the path into the enclosing block. Replacement
    // values and inlined region bodies are all placed before `op`, in path
    // order.
    rewriter.setInsertionPoint(op);
    ArrayRef remainingPath = path;
    // Values forwarded along the current edge of the path. Initially these are
    // the operands that "parent" forwards to the first successor.
    SmallVector successorOperands = llvm::to_vector(
        regionBranchOp.getEntrySuccessorOperands(remainingPath.front()));
    while (!remainingPath.empty()) {
      RegionSuccessor nextSuccessor = remainingPath.consume_front();
      ValueRange successorInputs =
          regionBranchOp.getSuccessorInputs(nextSuccessor);
      assert(successorInputs.size() == successorOperands.size() &&
             "size mismatch");

      // Find the index of the first block argument / op result that is a
      // successor input.
      unsigned firstSuccessorInputIdx = 0;
      if (!successorInputs.empty())
        firstSuccessorInputIdx =
            nextSuccessor.isParent()
                ? cast(successorInputs.front()).getResultNumber()
                : cast(successorInputs.front()).getArgNumber();

      // Query the total number of block arguments / op results.
      unsigned numValues =
          nextSuccessor.isParent()
              ? op->getNumResults()
              : nextSuccessor.getSuccessor()->getNumArguments();

      // Compute replacement values for all block arguments / op results.
      // The code below assumes the successor inputs form one contiguous range
      // starting at `firstSuccessorInputIdx`. Values before and after that
      // range are non-successor-inputs and are built by `replBuilderFn`.
      SmallVector replacements;
      // Helper function to get the i-th block argument / op result.
      auto getValue = [&](unsigned idx) {
        return nextSuccessor.isParent()
                   ? Value(op->getResult(idx))
                   : Value(nextSuccessor.getSuccessor()->getArgument(idx));
      };
      // Compute replacement values for all non-successor-input values that
      // precede the first successor input.
      for (unsigned i = 0; i < firstSuccessorInputIdx; ++i)
        replacements.push_back(
            replBuilderFn(rewriter, op->getLoc(), getValue(i)));
      // Use the successor operands of the predecessor as replacement values for
      // the successor inputs.
      llvm::append_range(replacements, successorOperands);
      // Compute replacement values for all block arguments / op results that
      // succeed the first successor input.
      for (unsigned i = replacements.size(); i < numValues; ++i)
        replacements.push_back(
            replBuilderFn(rewriter, op->getLoc(), getValue(i)));

      if (nextSuccessor.isParent()) {
        // The path ends with "parent". Replace the region branch op with the
        // computed replacement values.
        assert(remainingPath.empty() && "expected that the path ended");
        rewriter.replaceOp(op, replacements);
        return success();
      }

      // We are inside of a region: query the successor operands from the
      // terminator, inline the region into the enclosing block, and erase the
      // terminator.
      auto terminator = cast(
          &nextSuccessor.getSuccessor()->front().back());
      rewriter.inlineBlockBefore(&nextSuccessor.getSuccessor()->front(),
                                 op->getBlock(), op->getIterator(),
                                 replacements);
      // `remainingPath` is non-empty here: every path returned by
      // computeSingleAcyclicRegionBranchPath ends with "parent", which is
      // handled above.
      successorOperands = llvm::to_vector(
          terminator.getSuccessorOperands(remainingPath.front()));
      rewriter.eraseOp(terminator);
    }
    llvm_unreachable("expected that paths ends with parent");
  }

  /// User-supplied builder for replacement values of non-successor-inputs
  /// (e.g. the induction variable of a loop).
  NonSuccessorInputReplacementBuilderFn replBuilderFn;
  /// User-supplied applicability check. The pattern bails out with a match
  /// failure when it returns failure.
  PatternMatcherFn matcherFn;
};
} // namespace

/// Add the region-branch-op canonicalization patterns defined in this file,
/// rooted at operations named `opName`.
void mlir::populateRegionBranchOpInterfaceCanonicalizationPatterns(
    RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit) {
  patterns.add(patterns.getContext(), opName, benefit);
}

/// Add the region-inlining pattern rooted at operations named `opName`.
/// `replBuilderFn` supplies replacement values for non-successor-inputs and
/// `matcherFn` further restricts when the pattern applies.
void mlir::populateRegionBranchOpInterfaceInliningPattern(
    RewritePatternSet &patterns, StringRef opName,
    NonSuccessorInputReplacementBuilderFn replBuilderFn,
    PatternMatcherFn matcherFn, PatternBenefit benefit) {
  patterns.add(patterns.getContext(), opName, replBuilderFn, matcherFn,
               benefit);
}