//===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the dataflow analysis class for integer range inference // which is used in transformations over the `arith` dialect such as // branch elimination or signed->unsigned rewriting // //===----------------------------------------------------------------------===// #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Analysis/DataFlow/SparseAnalysis.h" #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Support/DebugStringHelper.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/DebugLog.h" #include #include #include #define DEBUG_TYPE "int-range-analysis" using namespace mlir; using namespace mlir::dataflow; namespace mlir::dataflow { LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) { auto *result = solver.lookupState(v); if (!result || result->getValue().isUninitialized()) return failure(); const ConstantIntRanges &range = result->getValue().getValue(); return success(range.smin().isNonNegative()); } LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op) { auto nonNegativePred = [&solver](Value v) -> bool { return succeeded(staticallyNonNegative(solver, v)); }; return success(llvm::all_of(op->getOperands(), nonNegativePred) && llvm::all_of(op->getResults(), nonNegativePred)); } } // namespace mlir::dataflow LogicalResult IntegerRangeAnalysis::visitOperation( Operation *op, ArrayRef operands, ArrayRef results) { auto inferrable = dyn_cast(op); if (!inferrable) { setAllToEntryStates(results); return success(); } LDBG() << "Inferring ranges for " << OpWithFlags(op, OpPrintingFlags().skipRegions()); auto argRanges = llvm::map_to_vector( operands, [](const IntegerValueRangeLattice *lattice) { return lattice->getValue(); }); auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { auto result = dyn_cast(v); if (!result) return; assert(llvm::is_contained(op->getResults(), result)); LDBG() << "Inferred range " << attrs; IntegerValueRangeLattice *lattice = results[result.getResultNumber()]; IntegerValueRange oldRange = lattice->getValue(); ChangeResult changed = lattice->join(attrs); // Catch loop results with loop variant bounds and conservatively make // them [-inf, inf] so we don't circle around infinitely often (because // the dataflow analysis in MLIR doesn't attempt to work out trip counts // and often can't). bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) { return op->hasTrait(); }); if (isYieldedResult && !oldRange.isUninitialized() && !(lattice->getValue() == oldRange)) { LDBG() << "Loop variant loop result detected"; changed |= lattice->join(IntegerValueRange::getMaxRange(v)); } propagateIfChanged(lattice, changed); }; inferrable.inferResultRangesFromOptional(argRanges, joinCallback); return success(); } void IntegerRangeAnalysis::visitNonControlFlowArguments( Operation *op, const RegionSuccessor &successor, ValueRange nonSuccessorInputs, ArrayRef nonSuccessorInputLattices) { assert(nonSuccessorInputs.size() == nonSuccessorInputLattices.size() && "size mismatch"); if (auto inferrable = dyn_cast(op)) { LDBG() << "Inferring ranges for " << OpWithFlags(op, OpPrintingFlags().skipRegions()); auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) { return getLatticeElementFor(getProgramPointAfter(op), value)->getValue(); }); auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { auto arg = dyn_cast(v); if (!arg) return; if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg)) return; LDBG() << "Inferred range " << attrs; auto it = llvm::find(successor.getSuccessor()->getArguments(), arg); unsigned nonSuccessorInputIdx = std::distance(successor.getSuccessor()->getArguments().begin(), it); IntegerValueRangeLattice *lattice = nonSuccessorInputLattices[nonSuccessorInputIdx]; IntegerValueRange oldRange = lattice->getValue(); ChangeResult changed = lattice->join(attrs); // Catch loop results with loop variant bounds and conservatively make // them [-inf, inf] so we don't circle around infinitely often (because // the dataflow analysis in MLIR doesn't attempt to work out trip counts // and often can't). bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) { return op->hasTrait(); }); if (isYieldedValue && !oldRange.isUninitialized() && !(lattice->getValue() == oldRange)) { LDBG() << "Loop variant loop result detected"; changed |= lattice->join(IntegerValueRange::getMaxRange(v)); } propagateIfChanged(lattice, changed); }; inferrable.inferResultRangesFromOptional(argRanges, joinCallback); return; } /// Given a lower bound, upper bound, or step from a LoopLikeInterface return /// the lower/upper bound for that result if possible. auto getLoopBoundFromFold = [&](OpFoldResult loopBound, Type boundType, Block *block, bool getUpper) { unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType); if (auto attr = dyn_cast(loopBound)) { if (auto bound = dyn_cast(attr)) return bound.getValue(); } else if (auto value = llvm::dyn_cast(loopBound)) { const IntegerValueRangeLattice *lattice = getLatticeElementFor(getProgramPointBefore(block), value); if (lattice != nullptr && !lattice->getValue().isUninitialized()) return getUpper ? lattice->getValue().getValue().smax() : lattice->getValue().getValue().smin(); } // Given the results of getConstant{Lower,Upper}Bound() // or getConstantStep() on a LoopLikeInterface return the lower/upper // bound return getUpper ? APInt::getSignedMaxValue(width) : APInt::getSignedMinValue(width); }; // Infer bounds for loop arguments that have static bounds if (auto loop = dyn_cast(op)) { std::optional> maybeIvs = loop.getLoopInductionVars(); if (!maybeIvs) { return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments( op, successor, nonSuccessorInputs, nonSuccessorInputLattices); } // Some loop implementations may return nullopt for non-constant bounds // (e.g. affine.for with a dynamic upper bound), even when induction // variables exist. Fall back to the generic analysis in that case. std::optional> maybeLowerBounds = loop.getLoopLowerBounds(); std::optional> maybeUpperBounds = loop.getLoopUpperBounds(); std::optional> maybeSteps = loop.getLoopSteps(); if (!maybeLowerBounds || !maybeUpperBounds || !maybeSteps) { return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments( op, successor, nonSuccessorInputs, nonSuccessorInputLattices); } SmallVector lowerBounds = *maybeLowerBounds; SmallVector upperBounds = *maybeUpperBounds; SmallVector steps = *maybeSteps; for (auto [iv, lowerBound, upperBound, step] : llvm::zip_equal(*maybeIvs, lowerBounds, upperBounds, steps)) { Block *block = iv.getParentBlock(); APInt min = getLoopBoundFromFold(lowerBound, iv.getType(), block, /*getUpper=*/false); APInt max = getLoopBoundFromFold(upperBound, iv.getType(), block, /*getUpper=*/true); // Assume positivity for uniscoverable steps by way of getUpper = true. APInt stepVal = getLoopBoundFromFold(step, iv.getType(), block, /*getUpper=*/true); if (stepVal.isNegative()) { std::swap(min, max); } else { // Correct the upper bound by subtracting 1 so that it becomes a <= // bound, because loops do not generally include their upper bound. max -= 1; } // If we infer the lower bound to be larger than the upper bound, the // resulting range is meaningless and should not be used in further // inferences. if (max.sge(min)) { IntegerValueRangeLattice *ivEntry = getLatticeElement(iv); auto ivRange = ConstantIntRanges::fromSigned(min, max); propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange})); } } return; } return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments( op, successor, nonSuccessorInputs, nonSuccessorInputLattices); }