Fixes #137509. When `op->fold` returns a Value that is not one of `op`'s operands (e.g. `unrealized_conversion_cast`'s fold returns the inner cast's operand), `SparseConstantPropagation` read that value's lattice without subscribing to it, so the op was not revisited when that lattice widened and its stale fold result was never updated. Fix this by using `getLatticeElementFor(getProgramPointAfter(op), v)`, which registers the dependency; this is the same strategy already used in a few places in `SparseAnalysis.cpp`. I'd have liked to use something simpler than the `unrealized_conversion_cast` operation in the test, but this is what I got when minimizing the reproducer from the original issue (#137509), and I wasn't able to find any other operation that reproduces the problem. Assisted-By: Claude Code
123 lines
4.6 KiB
C++
123 lines
4.6 KiB
C++
//===- ConstantPropagationAnalysis.cpp - Constant propagation analysis ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
|
|
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/OpDefinition.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/Support/Casting.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/DebugLog.h"
|
|
#include <cassert>
|
|
|
|
#define DEBUG_TYPE "constant-propagation"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::dataflow;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ConstantValue
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Writes a human-readable form of this lattice value to `os`:
///   - "<UNINITIALIZED>" if no value has been recorded yet,
///   - "<UNKNOWN>" if the recorded constant attribute is null,
///   - otherwise the constant attribute itself, via Attribute::print.
void ConstantValue::print(raw_ostream &os) const {
  if (isUninitialized()) {
    os << "<UNINITIALIZED>";
  } else if (Attribute constant = getConstantValue()) {
    constant.print(os);
  } else {
    // A null constant attribute is the "unknown" state.
    os << "<UNKNOWN>";
  }
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SparseConstantPropagation
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Transfer function of the sparse constant propagation: simulates folding
/// `op` with the constants currently known for its operands and joins the
/// outcome into the per-result lattices in `results`.
///
/// - Ops with regions, failed folds and in-place folds set every result
///   lattice to the entry (unknown) state.
/// - While any operand lattice is still uninitialized nothing is updated; the
///   op is presumably revisited once that operand has been reached.
/// - An Attribute fold result becomes a constant for the matching result; a
///   Value fold result contributes that value's own lattice.
///
/// Always returns success().
LogicalResult SparseConstantPropagation::visitOperation(
    Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
    ArrayRef<Lattice<ConstantValue> *> results) {
  LDBG() << "SCP: Visiting operation: " << *op;

  // Region-holding ops are never folded here: folding them cannot be assumed
  // to happen out-of-place, and this analysis simulates execution rather than
  // performing general folding.
  if (op->getNumRegions() != 0) {
    setAllToEntryStates(results);
    return success();
  }

  // Do nothing until every operand has been reached at least once.
  auto isUnreached = [](const Lattice<ConstantValue> *operandLattice) {
    return operandLattice->getValue().isUninitialized();
  };
  if (llvm::any_of(operands, isUnreached))
    return success();

  // One (possibly null) attribute per operand; a null entry stands for an
  // operand whose value is not a known constant.
  SmallVector<Attribute, 8> operandConstants;
  operandConstants.reserve(op->getNumOperands());
  for (const Lattice<ConstantValue> *operandLattice : operands)
    operandConstants.push_back(operandLattice->getValue().getConstantValue());

  // Snapshot operands and attributes so that an in-place fold can be undone.
  // The constants handed to fold() are simulated and may differ from the real
  // runtime values, so no in-place rewrite may survive.
  SmallVector<Value, 8> savedOperands(op->getOperands());
  DictionaryAttr savedAttrs = op->getAttrDictionary();

  SmallVector<OpFoldResult, 8> folded;
  folded.reserve(op->getNumResults());
  if (failed(op->fold(operandConstants, folded))) {
    setAllToEntryStates(results);
    return success();
  }

  // A successful fold that produced no results was an in-place fold: restore
  // the op and treat every result as unknown.
  if (folded.empty()) {
    op->setOperands(savedOperands);
    op->setAttrs(savedAttrs);
    setAllToEntryStates(results);
    return success();
  }

  assert(folded.size() == op->getNumResults() && "invalid result size");
  for (auto [resultLattice, outcome] : llvm::zip(results, folded)) {
    // Folded to a constant attribute: record it together with op's dialect.
    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(outcome)) {
      LDBG() << "Folded to constant: " << attr;
      propagateIfChanged(
          resultLattice,
          resultLattice->join(ConstantValue(attr, op->getDialect())));
      continue;
    }

    // Folded to an existing Value. That value need not be an operand of `op`,
    // so its lattice is read through `getLatticeElementFor` rather than
    // `getLatticeElement`: this registers a dependency so that `op` is
    // revisited if that value's lattice widens later.
    Value foldValue = cast<Value>(outcome);
    LDBG() << "Folded to value: " << foldValue;
    AbstractSparseForwardDataFlowAnalysis::join(
        resultLattice,
        *getLatticeElementFor(getProgramPointAfter(op), foldValue));
  }
  return success();
}
|
|
|
|
/// Puts `lattice` into this analysis' entry state, the unknown constant
/// (`ConstantValue::getUnknownConstant()`). The state is joined into the
/// lattice and the resulting change flag is passed to `propagateIfChanged`.
void SparseConstantPropagation::setToEntryState(
    Lattice<ConstantValue> *lattice) {
  const auto joinOutcome = lattice->join(ConstantValue::getUnknownConstant());
  propagateIfChanged(lattice, joinOutcome);
}
|