[mlir] Add analysis filter in dataflow solver (#192998)
Adds an optional filtering control function to the dataflow solver's
initializeAndRun callback, which controls which analyses will be
initialized when running the solver. This makes it possible to reuse
existing dataflow solver instances that have already run to a fixpoint
without re-initializing all of the analyses that have already converged.
A new analysis and test pass is also added, which illustrates how the
filtering can be useful to run a staged analysis, which would not have
been possible before. The example analysis, called `BarAnalysis`,
depends on the converged state of the `FooAnalysis`. The Bar analysis is
a forward analysis that tracks, for each program point, whether any of
the preceding program points hold a `foo_state` that is divisible by 4.
In the example test, the control flow graph looks like the following:
```
entry-block
/ \
bb0 bb2
\ /
bb1
```
The `foo_state` of `bb1` depends on the `foo_state` of `bb0` and `bb2`.
If the solver goes through `bb0->bb1` before `bb2->bb1`, then there is
an intermediate stage in the analyses where the state of `bb1` could be
divisible by 4, even though the final state of `bb1` will not be
divisible by 4 in the converged state. If the `BarAnalysis` runs on
`bb1` in this intermediate state, then it will get stuck with the
"divisible by 4" state, and the analysis will not yield the desired
results.
This PR ensures that the `BarAnalysis` will see the correct state
`foo_state`, because the `FooAnalysis` will fully run to a fixpoint
before the `BarAnalysis` is loaded, initialized, and run.
The Foo and Bar analyses are just trivial examples, but this pattern is
useful when there are analyses that can be made more effective by using
complementary analyses like integer range/divisibility analyses.
**Note for integration:**
DataFlowSolver::load now stores the concrete analysis TypeID, exposed
via DataFlowAnalysis::getTypeID(). Downstream DataFlowAnalysis
subclasses defined in anonymous namespaces must add
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ClassName) to their class
body.
Assisted-by: Codex (gpt-5.4)
This commit is contained in:
@@ -156,6 +156,8 @@ public:
|
||||
class AllocationAnalysis
|
||||
: public mlir::dataflow::DenseForwardDataFlowAnalysis<LatticePoint> {
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AllocationAnalysis)
|
||||
|
||||
using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;
|
||||
|
||||
mlir::LogicalResult visitOperation(mlir::Operation *op,
|
||||
@@ -535,8 +537,7 @@ StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
|
||||
candidateOps.insert({allocmem, insertionPoint});
|
||||
}
|
||||
|
||||
LLVM_DEBUG(for (auto [allocMemOp, _]
|
||||
: candidateOps) {
|
||||
LLVM_DEBUG(for (auto [allocMemOp, _] : candidateOps) {
|
||||
llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
|
||||
});
|
||||
return mlir::success();
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "mlir/Support/StorageUniquer.h"
|
||||
#include "llvm/ADT/EquivalenceClasses.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/Compiler.h"
|
||||
#include "llvm/Support/TypeName.h"
|
||||
@@ -331,9 +332,16 @@ public:
|
||||
template <typename AnalysisT, typename... Args>
|
||||
AnalysisT *load(Args &&...args);
|
||||
|
||||
/// Initialize the children analyses starting from the provided top-level
|
||||
/// operation and run the analysis until fixpoint.
|
||||
LogicalResult initializeAndRun(Operation *top);
|
||||
/// Initialize analyses starting from the provided top-level operation and
|
||||
/// run the analysis until fixpoint.
|
||||
///
|
||||
/// An optional \p analysisFilter predicate restricts which analyses are
|
||||
/// initialized. When no filter is given every loaded analysis is
|
||||
/// (re-)initialized. The fixpoint loop always processes all enqueued work
|
||||
/// items regardless of the filter.
|
||||
LogicalResult initializeAndRun(
|
||||
Operation *top,
|
||||
llvm::function_ref<bool(DataFlowAnalysis &)> analysisFilter = nullptr);
|
||||
|
||||
/// Lookup an analysis state for the given lattice anchor. Returns null if one
|
||||
/// does not exist.
|
||||
@@ -574,6 +582,12 @@ void DataFlowSolver::eraseState(AnchorT anchor) {
|
||||
/// an initial dependency graph (and optionally provide an initial state) when
|
||||
/// initialized and define transfer functions when visiting program points.
|
||||
///
|
||||
/// Subclasses defined in anonymous namespaces must provide an explicit TypeID
|
||||
/// via `MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID` in their class body.
|
||||
/// This is required because `DataFlowSolver::load` resolves the analysis
|
||||
/// TypeID at load time, and the implicit TypeID fallback is not supported for
|
||||
/// classes in anonymous namespaces.
|
||||
///
|
||||
/// In classical data-flow analysis, the dependency graph is fixed and analyses
|
||||
/// define explicit transfer functions between input states and output states.
|
||||
/// In this framework, however, the dependency graph can change during the
|
||||
@@ -630,6 +644,12 @@ public:
|
||||
/// necessarily identical under the corrensponding lattice type.
|
||||
virtual void initializeEquivalentLatticeAnchor(Operation *top) {}
|
||||
|
||||
/// Return the TypeID of the concrete analysis class. Valid only after
|
||||
/// `DataFlowSolver::load<AnalysisT>` has returned; must not be called from
|
||||
/// the analysis constructor body because the TypeID is set by `load` after
|
||||
/// construction.
|
||||
TypeID getTypeID() const { return analysisTypeID; }
|
||||
|
||||
protected:
|
||||
/// Create a dependency between the given analysis state and lattice anchor
|
||||
/// on this analysis.
|
||||
@@ -705,6 +725,11 @@ private:
|
||||
/// The parent data-flow solver.
|
||||
DataFlowSolver &solver;
|
||||
|
||||
/// The TypeID of the concrete analysis class. Set by
|
||||
/// `DataFlowSolver::load` after construction; not available during the
|
||||
/// analysis constructor.
|
||||
TypeID analysisTypeID;
|
||||
|
||||
/// Allow the data-flow solver to access the internals of this class.
|
||||
friend class DataFlowSolver;
|
||||
};
|
||||
@@ -712,6 +737,7 @@ private:
|
||||
template <typename AnalysisT, typename... Args>
|
||||
AnalysisT *DataFlowSolver::load(Args &&...args) {
|
||||
childAnalyses.emplace_back(new AnalysisT(*this, std::forward<Args>(args)...));
|
||||
childAnalyses.back()->analysisTypeID = TypeID::get<AnalysisT>();
|
||||
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
|
||||
childAnalyses.back()->debugName = llvm::getTypeName<AnalysisT>();
|
||||
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
|
||||
|
||||
@@ -109,7 +109,9 @@ Location LatticeAnchor::getLoc() const {
|
||||
// DataFlowSolver
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
|
||||
LogicalResult DataFlowSolver::initializeAndRun(
|
||||
Operation *top,
|
||||
llvm::function_ref<bool(DataFlowAnalysis &)> analysisFilter) {
|
||||
// Enable enqueue to the worklist.
|
||||
isRunning = true;
|
||||
llvm::scope_exit guard([&]() { isRunning = false; });
|
||||
@@ -120,13 +122,21 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
|
||||
if (isInterprocedural && !top->hasTrait<OpTrait::SymbolTable>())
|
||||
config.setInterprocedural(false);
|
||||
|
||||
auto shouldInitialize = [&](DataFlowAnalysis &analysis) {
|
||||
return !analysisFilter || analysisFilter(analysis);
|
||||
};
|
||||
|
||||
// Initialize equivalent lattice anchors.
|
||||
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
|
||||
if (!shouldInitialize(analysis))
|
||||
continue;
|
||||
analysis.initializeEquivalentLatticeAnchor(top);
|
||||
}
|
||||
|
||||
// Initialize the analyses.
|
||||
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
|
||||
if (!shouldInitialize(analysis))
|
||||
continue;
|
||||
DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName);
|
||||
if (failed(analysis.initialize(top)))
|
||||
return failure();
|
||||
|
||||
@@ -310,6 +310,9 @@ static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
|
||||
/// (other) consumers.
|
||||
class LayoutInfoPropagation
|
||||
: public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoPropagation)
|
||||
|
||||
private:
|
||||
xegpu::LayoutKind layoutKind;
|
||||
unsigned indexBitWidth;
|
||||
|
||||
50
mlir/test/Analysis/DataFlow/test-staged-analyses.mlir
Normal file
50
mlir/test/Analysis/DataFlow/test-staged-analyses.mlir
Normal file
@@ -0,0 +1,50 @@
|
||||
// RUN: mlir-opt -pass-pipeline='builtin.module(func.func(test-staged-analyses))' %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @linear()
|
||||
func.func @linear() {
|
||||
// CHECK: "test.foo"() {bar_state = true, foo = 1 : ui64, foo_state = 1 : i64, tag = "annotate"} : () -> ()
|
||||
"test.foo"() {tag = "annotate", foo = 1 : ui64} : () -> ()
|
||||
// CHECK: "test.foo"() {bar_state = true, foo = 2 : ui64, foo_state = 3 : i64, tag = "annotate"} : () -> ()
|
||||
"test.foo"() {tag = "annotate", foo = 2 : ui64} : () -> ()
|
||||
// CHECK: "test.foo"() {bar_state = true, foo_state = 3 : i64, tag = "annotate"} : () -> ()
|
||||
"test.foo"() {tag = "annotate"} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// This demonstrates why `BarAnalysis` should be run only after `FooAnalysis`
|
||||
// converges.
|
||||
//
|
||||
// Under the current `FooAnalysis` implementation:
|
||||
// - entry op after-state is 0 xor 7 = 7
|
||||
// - bb0 terminator after-state is 7 xor 1 = 6
|
||||
// - when the join block is first visited, only bb0 has contributed, so the
|
||||
// join op transiently sees 6 xor 2 = 4
|
||||
// - once the other predecessor arrives, revisiting the join updates the
|
||||
// final staged `foo_state` to 7 for the first op in the join block and it
|
||||
// stays 7 for the following op
|
||||
//
|
||||
// But if a non-staged `BarAnalysis` observed bb2 after only bb0 had reached
|
||||
// it, bb2's first tagged op would transiently see 6 xor 2 = 4 and latch
|
||||
// `bar_state = false`, poisoning later points. The staged run below must use
|
||||
// only the converged `FooState`, so `bar_state` stays true.
|
||||
//
|
||||
// CHECK-LABEL: func.func @requires_staged_bar()
|
||||
func.func @requires_staged_bar() {
|
||||
// CHECK: "test.branch"()[^bb{{[0-9]+}}, ^bb{{[0-9]+}}] {bar_state = true, foo = 7 : ui64, foo_state = 7 : i64, tag = "annotate"} : () -> ()
|
||||
"test.branch"() [^bb0, ^bb2] {tag = "annotate", foo = 7 : ui64} : () -> ()
|
||||
|
||||
^bb0:
|
||||
// CHECK: "test.branch"()[^bb{{[0-9]+}}] {bar_state = true, foo = 1 : ui64, foo_state = 6 : i64, tag = "annotate"} : () -> ()
|
||||
"test.branch"() [^bb1] {tag = "annotate", foo = 1 : ui64} : () -> ()
|
||||
|
||||
^bb1:
|
||||
// CHECK: "test.foo"() {bar_state = true, foo = 2 : ui64, foo_state = 7 : i64, tag = "annotate"} : () -> ()
|
||||
"test.foo"() {tag = "annotate", foo = 2 : ui64} : () -> ()
|
||||
// CHECK: "test.foo"() {bar_state = true, foo_state = 7 : i64, tag = "annotate"} : () -> ()
|
||||
"test.foo"() {tag = "annotate"} : () -> ()
|
||||
return
|
||||
|
||||
^bb2:
|
||||
// CHECK: "test.branch"()[^bb{{[0-9]+}}] {bar_state = true, foo = 2 : ui64, foo_state = 5 : i64, tag = "annotate"} : () -> ()
|
||||
"test.branch"() [^bb1] {tag = "annotate", foo = 2 : ui64} : () -> ()
|
||||
}
|
||||
@@ -67,6 +67,8 @@ namespace {
|
||||
/// This is a simple analysis that implements a transfer function for constant
|
||||
/// operations.
|
||||
struct ConstantAnalysis : public DataFlowAnalysis {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantAnalysis)
|
||||
|
||||
using DataFlowAnalysis::DataFlowAnalysis;
|
||||
|
||||
LogicalResult initialize(Operation *top) override {
|
||||
|
||||
@@ -51,6 +51,8 @@ public:
|
||||
|
||||
class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccessAnalysis)
|
||||
|
||||
NextAccessAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
|
||||
bool assumeFuncReads = false)
|
||||
: DenseBackwardDataFlowAnalysis(solver, symbolTable),
|
||||
|
||||
@@ -49,6 +49,8 @@ public:
|
||||
class LastModifiedAnalysis
|
||||
: public DenseForwardDataFlowAnalysis<LastModification> {
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModifiedAnalysis)
|
||||
|
||||
explicit LastModifiedAnalysis(DataFlowSolver &solver, bool assumeFuncWrites)
|
||||
: DenseForwardDataFlowAnalysis(solver),
|
||||
assumeFuncWrites(assumeFuncWrites) {}
|
||||
|
||||
@@ -70,6 +70,8 @@ struct WrittenTo : public Lattice<WrittenToLatticeValue> {
|
||||
/// is eventually written to.
|
||||
class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenToAnalysis)
|
||||
|
||||
WrittenToAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
|
||||
bool assumeFuncWrites)
|
||||
: SparseBackwardDataFlowAnalysis(solver, symbolTable),
|
||||
|
||||
@@ -8,12 +8,18 @@
|
||||
|
||||
#include "mlir/Analysis/DataFlowFramework.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include <optional>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
constexpr char kTagAttrName[] = "tag";
|
||||
constexpr char kFooAttrName[] = "foo";
|
||||
constexpr char kFooStateAttrName[] = "foo_state";
|
||||
constexpr char kBarStateAttrName[] = "bar_state";
|
||||
|
||||
/// This analysis state represents an integer that is XOR'd with other states.
|
||||
class FooState : public AnalysisState {
|
||||
public:
|
||||
@@ -74,6 +80,78 @@ public:
|
||||
|
||||
using DataFlowAnalysis::DataFlowAnalysis;
|
||||
|
||||
static bool classof(const DataFlowAnalysis *a) {
|
||||
return a->getTypeID() == TypeID::get<FooAnalysis>();
|
||||
}
|
||||
|
||||
LogicalResult initialize(Operation *top) override;
|
||||
LogicalResult visit(ProgramPoint *point) override;
|
||||
|
||||
private:
|
||||
void visitBlock(Block *block);
|
||||
void visitOperation(Operation *op);
|
||||
};
|
||||
|
||||
/// This analysis state stores whether all previously observed `FooState`
|
||||
/// values at tagged program points along the CFG leading to the current point
|
||||
/// have been non-multiples of 4. Once the state becomes false at some point,
|
||||
/// all later points reachable from it also remain false.
|
||||
class BarState : public AnalysisState {
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BarState)
|
||||
|
||||
using AnalysisState::AnalysisState;
|
||||
|
||||
bool isUninitialized() const { return !state; }
|
||||
|
||||
void print(raw_ostream &os) const override {
|
||||
if (!state) {
|
||||
os << "none";
|
||||
return;
|
||||
}
|
||||
os << (*state ? "true" : "false");
|
||||
}
|
||||
|
||||
ChangeResult join(const BarState &rhs) {
|
||||
if (rhs.isUninitialized())
|
||||
return ChangeResult::NoChange;
|
||||
return join(rhs.getValue());
|
||||
}
|
||||
|
||||
ChangeResult join(bool value) {
|
||||
if (isUninitialized()) {
|
||||
state = value;
|
||||
return ChangeResult::Change;
|
||||
}
|
||||
bool newValue = *state && value;
|
||||
if (newValue == *state)
|
||||
return ChangeResult::NoChange;
|
||||
state = newValue;
|
||||
return ChangeResult::Change;
|
||||
}
|
||||
|
||||
bool getValue() const { return *state; }
|
||||
|
||||
private:
|
||||
std::optional<bool> state;
|
||||
};
|
||||
|
||||
/// This analysis is intended to be loaded after `FooAnalysis` has converged.
|
||||
/// It records whether every observed `FooState` on or before a given tagged
|
||||
/// program point has been non-divisible by 4. Because the state only ever
|
||||
/// transitions from true to false, observing a transient divisible-by-4
|
||||
/// `FooState` before `FooAnalysis` converges can permanently poison the
|
||||
/// result.
|
||||
class BarAnalysis : public DataFlowAnalysis {
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BarAnalysis)
|
||||
|
||||
using DataFlowAnalysis::DataFlowAnalysis;
|
||||
|
||||
static bool classof(const DataFlowAnalysis *a) {
|
||||
return a->getTypeID() == TypeID::get<BarAnalysis>();
|
||||
}
|
||||
|
||||
LogicalResult initialize(Operation *top) override;
|
||||
LogicalResult visit(ProgramPoint *point) override;
|
||||
|
||||
@@ -90,6 +168,15 @@ struct TestFooAnalysisPass
|
||||
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
struct TestStagedAnalysesPass
|
||||
: public PassWrapper<TestStagedAnalysesPass, OperationPass<func::FuncOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStagedAnalysesPass)
|
||||
|
||||
StringRef getArgument() const override { return "test-staged-analyses"; }
|
||||
|
||||
void runOnOperation() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult FooAnalysis::initialize(Operation *top) {
|
||||
@@ -151,13 +238,76 @@ void FooAnalysis::visitOperation(Operation *op) {
|
||||
result |= state->set(*prevState);
|
||||
|
||||
// Modify the state with the attribute, if specified.
|
||||
if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
|
||||
if (auto attr = op->getAttrOfType<IntegerAttr>(kFooAttrName)) {
|
||||
uint64_t value = attr.getUInt();
|
||||
result |= state->join(value);
|
||||
}
|
||||
propagateIfChanged(state, result);
|
||||
}
|
||||
|
||||
LogicalResult BarAnalysis::initialize(Operation *top) {
|
||||
if (top->getNumRegions() != 1)
|
||||
return top->emitError("expected a single region top-level op");
|
||||
|
||||
if (top->getRegion(0).getBlocks().empty())
|
||||
return top->emitError("expected at least one block in the region");
|
||||
|
||||
// Seed the entry state to true before observing any `FooState`.
|
||||
(void)getOrCreate<BarState>(getProgramPointBefore(&top->getRegion(0).front()))
|
||||
->join(true);
|
||||
|
||||
for (Block &block : top->getRegion(0)) {
|
||||
visitBlock(&block);
|
||||
for (Operation &op : block) {
|
||||
if (op.getNumRegions())
|
||||
return op.emitError("unexpected op with regions");
|
||||
visitOperation(&op);
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult BarAnalysis::visit(ProgramPoint *point) {
|
||||
if (!point->isBlockStart())
|
||||
visitOperation(point->getPrevOp());
|
||||
else
|
||||
visitBlock(point->getBlock());
|
||||
return success();
|
||||
}
|
||||
|
||||
void BarAnalysis::visitBlock(Block *block) {
|
||||
if (block->isEntryBlock())
|
||||
return;
|
||||
|
||||
ProgramPoint *point = getProgramPointBefore(block);
|
||||
BarState *state = getOrCreate<BarState>(point);
|
||||
ChangeResult result = ChangeResult::NoChange;
|
||||
for (Block *pred : block->getPredecessors()) {
|
||||
const BarState *predState = getOrCreateFor<BarState>(
|
||||
point, getProgramPointAfter(pred->getTerminator()));
|
||||
result |= state->join(*predState);
|
||||
}
|
||||
propagateIfChanged(state, result);
|
||||
}
|
||||
|
||||
void BarAnalysis::visitOperation(Operation *op) {
|
||||
ProgramPoint *point = getProgramPointAfter(op);
|
||||
BarState *state = getOrCreate<BarState>(point);
|
||||
ChangeResult result = ChangeResult::NoChange;
|
||||
|
||||
const BarState *prevState =
|
||||
getOrCreateFor<BarState>(point, getProgramPointBefore(op));
|
||||
result |= state->join(*prevState);
|
||||
|
||||
if (op->hasAttr(kTagAttrName)) {
|
||||
const FooState *fooState = getOrCreateFor<FooState>(point, point);
|
||||
if (fooState->isUninitialized())
|
||||
return;
|
||||
result |= state->join((fooState->getValue() & 0x3) != 0);
|
||||
}
|
||||
propagateIfChanged(state, result);
|
||||
}
|
||||
|
||||
void TestFooAnalysisPass::runOnOperation() {
|
||||
func::FuncOp func = getOperation();
|
||||
DataFlowSolver solver;
|
||||
@@ -169,7 +319,7 @@ void TestFooAnalysisPass::runOnOperation() {
|
||||
os << "function: @" << func.getSymName() << "\n";
|
||||
|
||||
func.walk([&](Operation *op) {
|
||||
auto tag = op->getAttrOfType<StringAttr>("tag");
|
||||
auto tag = op->getAttrOfType<StringAttr>(kTagAttrName);
|
||||
if (!tag)
|
||||
return;
|
||||
const FooState *state =
|
||||
@@ -179,8 +329,39 @@ void TestFooAnalysisPass::runOnOperation() {
|
||||
});
|
||||
}
|
||||
|
||||
void TestStagedAnalysesPass::runOnOperation() {
|
||||
func::FuncOp func = getOperation();
|
||||
Builder builder(func.getContext());
|
||||
|
||||
DataFlowSolver solver;
|
||||
solver.load<FooAnalysis>();
|
||||
if (failed(solver.initializeAndRun(func)))
|
||||
return signalPassFailure();
|
||||
solver.load<BarAnalysis>();
|
||||
if (failed(solver.initializeAndRun(func, llvm::IsaPred<BarAnalysis>)))
|
||||
return signalPassFailure();
|
||||
|
||||
func.walk([&](Operation *op) {
|
||||
if (!op->hasAttr(kTagAttrName))
|
||||
return;
|
||||
|
||||
ProgramPoint *point = solver.getProgramPointAfter(op);
|
||||
const FooState *fooState = solver.lookupState<FooState>(point);
|
||||
const BarState *barState = solver.lookupState<BarState>(point);
|
||||
assert(fooState && !fooState->isUninitialized());
|
||||
assert(barState && !barState->isUninitialized());
|
||||
|
||||
op->setAttr(kFooStateAttrName,
|
||||
builder.getI64IntegerAttr(fooState->getValue()));
|
||||
op->setAttr(kBarStateAttrName, builder.getBoolAttr(barState->getValue()));
|
||||
});
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
|
||||
void registerTestStagedAnalysesPass() {
|
||||
PassRegistration<TestStagedAnalysesPass>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
||||
@@ -99,6 +99,7 @@ void registerTestDynamicPipelinePass();
|
||||
void registerTestRemarkPass();
|
||||
void registerTestEmulateNarrowTypePass();
|
||||
void registerTestFooAnalysisPass();
|
||||
void registerTestStagedAnalysesPass();
|
||||
void registerTestComposeSubView();
|
||||
void registerTestMultiBuffering();
|
||||
void registerTestIRVisitorsPass();
|
||||
@@ -247,6 +248,7 @@ static void registerTestPasses() {
|
||||
mlir::test::registerTestRemarkPass();
|
||||
mlir::test::registerTestEmulateNarrowTypePass();
|
||||
mlir::test::registerTestFooAnalysisPass();
|
||||
mlir::test::registerTestStagedAnalysesPass();
|
||||
mlir::test::registerTestComposeSubView();
|
||||
mlir::test::registerTestMultiBuffering();
|
||||
mlir::test::registerTestIRVisitorsPass();
|
||||
|
||||
Reference in New Issue
Block a user