diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp index 3f49f60089db..cfbd273b88ea 100644 --- a/flang/lib/Optimizer/Transforms/StackArrays.cpp +++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp @@ -156,6 +156,8 @@ public: class AllocationAnalysis : public mlir::dataflow::DenseForwardDataFlowAnalysis { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AllocationAnalysis) + using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis; mlir::LogicalResult visitOperation(mlir::Operation *op, @@ -535,8 +537,7 @@ StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) { candidateOps.insert({allocmem, insertionPoint}); } - LLVM_DEBUG(for (auto [allocMemOp, _] - : candidateOps) { + LLVM_DEBUG(for (auto [allocMemOp, _] : candidateOps) { llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n'; }); return mlir::success(); diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h index 87ec01a918d9..25506645f2f2 100644 --- a/mlir/include/mlir/Analysis/DataFlowFramework.h +++ b/mlir/include/mlir/Analysis/DataFlowFramework.h @@ -20,6 +20,7 @@ #include "mlir/Support/StorageUniquer.h" #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/TypeName.h" @@ -331,9 +332,16 @@ public: template AnalysisT *load(Args &&...args); - /// Initialize the children analyses starting from the provided top-level - /// operation and run the analysis until fixpoint. - LogicalResult initializeAndRun(Operation *top); + /// Initialize analyses starting from the provided top-level operation and + /// run the analysis until fixpoint. + /// + /// An optional \p analysisFilter predicate restricts which analyses are + /// initialized. When no filter is given every loaded analysis is + /// (re-)initialized. 
The fixpoint loop always processes all enqueued work + /// items regardless of the filter. + LogicalResult initializeAndRun( + Operation *top, + llvm::function_ref analysisFilter = nullptr); /// Lookup an analysis state for the given lattice anchor. Returns null if one /// does not exist. @@ -574,6 +582,12 @@ void DataFlowSolver::eraseState(AnchorT anchor) { /// an initial dependency graph (and optionally provide an initial state) when /// initialized and define transfer functions when visiting program points. /// +/// Subclasses defined in anonymous namespaces must provide an explicit TypeID +/// via `MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID` in their class body. +/// This is required because `DataFlowSolver::load` resolves the analysis +/// TypeID at load time, and the implicit TypeID fallback is not supported for +/// classes in anonymous namespaces. +/// /// In classical data-flow analysis, the dependency graph is fixed and analyses /// define explicit transfer functions between input states and output states. /// In this framework, however, the dependency graph can change during the @@ -630,6 +644,12 @@ public: /// necessarily identical under the corrensponding lattice type. virtual void initializeEquivalentLatticeAnchor(Operation *top) {} + /// Return the TypeID of the concrete analysis class. Valid only after + /// `DataFlowSolver::load` has returned; must not be called from + /// the analysis constructor body because the TypeID is set by `load` after + /// construction. + TypeID getTypeID() const { return analysisTypeID; } + protected: /// Create a dependency between the given analysis state and lattice anchor /// on this analysis. @@ -705,6 +725,11 @@ private: /// The parent data-flow solver. DataFlowSolver &solver; + /// The TypeID of the concrete analysis class. Set by + /// `DataFlowSolver::load` after construction; not available during the + /// analysis constructor. 
+ TypeID analysisTypeID; + /// Allow the data-flow solver to access the internals of this class. friend class DataFlowSolver; }; @@ -712,6 +737,7 @@ private: template AnalysisT *DataFlowSolver::load(Args &&...args) { childAnalyses.emplace_back(new AnalysisT(*this, std::forward(args)...)); + childAnalyses.back()->analysisTypeID = TypeID::get(); #if LLVM_ENABLE_ABI_BREAKING_CHECKS childAnalyses.back()->debugName = llvm::getTypeName(); #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp index 258bcf312afc..e51ae7a1d7ca 100644 --- a/mlir/lib/Analysis/DataFlowFramework.cpp +++ b/mlir/lib/Analysis/DataFlowFramework.cpp @@ -109,7 +109,9 @@ Location LatticeAnchor::getLoc() const { // DataFlowSolver //===----------------------------------------------------------------------===// -LogicalResult DataFlowSolver::initializeAndRun(Operation *top) { +LogicalResult DataFlowSolver::initializeAndRun( + Operation *top, + llvm::function_ref analysisFilter) { // Enable enqueue to the worklist. isRunning = true; llvm::scope_exit guard([&]() { isRunning = false; }); @@ -120,13 +122,21 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) { if (isInterprocedural && !top->hasTrait()) config.setInterprocedural(false); + auto shouldInitialize = [&](DataFlowAnalysis &analysis) { + return !analysisFilter || analysisFilter(analysis); + }; + // Initialize equivalent lattice anchors. for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) { + if (!shouldInitialize(analysis)) + continue; analysis.initializeEquivalentLatticeAnchor(top); } // Initialize the analyses. 
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) { + if (!shouldInitialize(analysis)) + continue; DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName); if (failed(analysis.initialize(top))) return failure(); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index 43998ed41f7a..88341e120267 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -310,6 +310,9 @@ static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty, /// (other) consumers. class LayoutInfoPropagation : public SparseBackwardDataFlowAnalysis { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LayoutInfoPropagation) + private: xegpu::LayoutKind layoutKind; unsigned indexBitWidth; diff --git a/mlir/test/Analysis/DataFlow/test-staged-analyses.mlir b/mlir/test/Analysis/DataFlow/test-staged-analyses.mlir new file mode 100644 index 000000000000..5da7ad8bf127 --- /dev/null +++ b/mlir/test/Analysis/DataFlow/test-staged-analyses.mlir @@ -0,0 +1,50 @@ +// RUN: mlir-opt -pass-pipeline='builtin.module(func.func(test-staged-analyses))' %s | FileCheck %s + +// CHECK-LABEL: func.func @linear() +func.func @linear() { + // CHECK: "test.foo"() {bar_state = true, foo = 1 : ui64, foo_state = 1 : i64, tag = "annotate"} : () -> () + "test.foo"() {tag = "annotate", foo = 1 : ui64} : () -> () + // CHECK: "test.foo"() {bar_state = true, foo = 2 : ui64, foo_state = 3 : i64, tag = "annotate"} : () -> () + "test.foo"() {tag = "annotate", foo = 2 : ui64} : () -> () + // CHECK: "test.foo"() {bar_state = true, foo_state = 3 : i64, tag = "annotate"} : () -> () + "test.foo"() {tag = "annotate"} : () -> () + return +} + +// This demonstrates why `BarAnalysis` should be run only after `FooAnalysis` +// converges. 
+// +// Under the current `FooAnalysis` implementation: +// - entry op after-state is 0 xor 7 = 7 +// - bb0 terminator after-state is 7 xor 1 = 6 +// - when the join block is first visited, only bb0 has contributed, so the +// join op transiently sees 6 xor 2 = 4 +// - once the other predecessor arrives, revisiting the join updates the +// final staged `foo_state` to 7 for the first op in the join block and it +// stays 7 for the following op +// +// But if a non-staged `BarAnalysis` observed bb2 after only bb0 had reached +// it, bb2's first tagged op would transiently see 6 xor 2 = 4 and latch +// `bar_state = false`, poisoning later points. The staged run below must use +// only the converged `FooState`, so `bar_state` stays true. +// +// CHECK-LABEL: func.func @requires_staged_bar() +func.func @requires_staged_bar() { + // CHECK: "test.branch"()[^bb{{[0-9]+}}, ^bb{{[0-9]+}}] {bar_state = true, foo = 7 : ui64, foo_state = 7 : i64, tag = "annotate"} : () -> () + "test.branch"() [^bb0, ^bb2] {tag = "annotate", foo = 7 : ui64} : () -> () + +^bb0: + // CHECK: "test.branch"()[^bb{{[0-9]+}}] {bar_state = true, foo = 1 : ui64, foo_state = 6 : i64, tag = "annotate"} : () -> () + "test.branch"() [^bb1] {tag = "annotate", foo = 1 : ui64} : () -> () + +^bb1: + // CHECK: "test.foo"() {bar_state = true, foo = 2 : ui64, foo_state = 7 : i64, tag = "annotate"} : () -> () + "test.foo"() {tag = "annotate", foo = 2 : ui64} : () -> () + // CHECK: "test.foo"() {bar_state = true, foo_state = 7 : i64, tag = "annotate"} : () -> () + "test.foo"() {tag = "annotate"} : () -> () + return + +^bb2: + // CHECK: "test.branch"()[^bb{{[0-9]+}}] {bar_state = true, foo = 2 : ui64, foo_state = 5 : i64, tag = "annotate"} : () -> () + "test.branch"() [^bb1] {tag = "annotate", foo = 2 : ui64} : () -> () +} diff --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp index 2dc77c9705d3..327e80787371 100644 --- 
a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp @@ -67,6 +67,8 @@ namespace { /// This is a simple analysis that implements a transfer function for constant /// operations. struct ConstantAnalysis : public DataFlowAnalysis { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConstantAnalysis) + using DataFlowAnalysis::DataFlowAnalysis; LogicalResult initialize(Operation *top) override { diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp index 232bf1482755..f09df47e5f65 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp @@ -51,6 +51,8 @@ public: class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccessAnalysis) + NextAccessAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable, bool assumeFuncReads = false) : DenseBackwardDataFlowAnalysis(solver, symbolTable), diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp index 9236e9816888..f2384f32948a 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp @@ -49,6 +49,8 @@ public: class LastModifiedAnalysis : public DenseForwardDataFlowAnalysis { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModifiedAnalysis) + explicit LastModifiedAnalysis(DataFlowSolver &solver, bool assumeFuncWrites) : DenseForwardDataFlowAnalysis(solver), assumeFuncWrites(assumeFuncWrites) {} diff --git a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp index 4f19cc7144af..b1978880e2bd 
100644 --- a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp @@ -70,6 +70,8 @@ struct WrittenTo : public Lattice { /// is eventually written to. class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis { public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenToAnalysis) + WrittenToAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable, bool assumeFuncWrites) : SparseBackwardDataFlowAnalysis(solver, symbolTable), diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp index 4267fb42266c..9af7e205aaee 100644 --- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp +++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp @@ -8,12 +8,18 @@ #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" #include using namespace mlir; namespace { +constexpr char kTagAttrName[] = "tag"; +constexpr char kFooAttrName[] = "foo"; +constexpr char kFooStateAttrName[] = "foo_state"; +constexpr char kBarStateAttrName[] = "bar_state"; + /// This analysis state represents an integer that is XOR'd with other states. class FooState : public AnalysisState { public: @@ -74,6 +80,78 @@ public: using DataFlowAnalysis::DataFlowAnalysis; + static bool classof(const DataFlowAnalysis *a) { + return a->getTypeID() == TypeID::get(); + } + + LogicalResult initialize(Operation *top) override; + LogicalResult visit(ProgramPoint *point) override; + +private: + void visitBlock(Block *block); + void visitOperation(Operation *op); +}; + +/// This analysis state stores whether all previously observed `FooState` +/// values at tagged program points along the CFG leading to the current point +/// have been non-multiples of 4. Once the state becomes false at some point, +/// all later points reachable from it also remain false. 
+class BarState : public AnalysisState {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BarState)
+
+  using AnalysisState::AnalysisState;
+
+  bool isUninitialized() const { return !state; } // No value joined yet.
+
+  void print(raw_ostream &os) const override {
+    if (!state) { // Shown as "none" until a value has been joined.
+      os << "none";
+      return;
+    }
+    os << (*state ? "true" : "false");
+  }
+
+  ChangeResult join(const BarState &rhs) {
+    if (rhs.isUninitialized()) // An empty rhs contributes nothing.
+      return ChangeResult::NoChange;
+    return join(rhs.getValue());
+  }
+
+  ChangeResult join(bool value) {
+    if (isUninitialized()) { // The first joined value is adopted as-is.
+      state = value;
+      return ChangeResult::Change;
+    }
+    bool newValue = *state && value; // Logical AND, so false is sticky.
+    if (newValue == *state)
+      return ChangeResult::NoChange;
+    state = newValue;
+    return ChangeResult::Change;
+  }
+
+  bool getValue() const { return *state; } // Requires !isUninitialized().
+
+private:
+  std::optional state; // Empty until the first join.
+};
+
+/// This analysis is intended to be loaded after `FooAnalysis` has converged.
+/// It records whether every observed `FooState` on or before a given tagged
+/// program point has been non-divisible by 4. Because the state only ever
+/// transitions from true to false, observing a transient divisible-by-4
+/// `FooState` before `FooAnalysis` converges can permanently poison the
+/// result.
+class BarAnalysis : public DataFlowAnalysis {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BarAnalysis)
+
+  using DataFlowAnalysis::DataFlowAnalysis;
+  /// Identifies instances of this analysis by comparing TypeIDs.
+  static bool classof(const DataFlowAnalysis *a) {
+    return a->getTypeID() == TypeID::get();
+  }
+
   LogicalResult initialize(Operation *top) override;
   LogicalResult visit(ProgramPoint *point) override;
 
@@ -90,6 +168,15 @@ struct TestFooAnalysisPass
 
   void runOnOperation() override;
 };
+
+struct TestStagedAnalysesPass
+    : public PassWrapper> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStagedAnalysesPass)
+  /// Name under which the pass is invoked from a pass pipeline.
+  StringRef getArgument() const override { return "test-staged-analyses"; }
+
+  void runOnOperation() override;
+};
 } // namespace
 
 LogicalResult FooAnalysis::initialize(Operation *top) {
@@ -151,13 +238,76 @@ void FooAnalysis::visitOperation(Operation *op) {
   result |= state->set(*prevState);
 
   // Modify the state with the attribute, if specified.
-  if (auto attr = op->getAttrOfType("foo")) {
+  if (auto attr = op->getAttrOfType(kFooAttrName)) {
     uint64_t value = attr.getUInt();
     result |= state->join(value);
   }
   propagateIfChanged(state, result);
 }
 
+LogicalResult BarAnalysis::initialize(Operation *top) {
+  if (top->getNumRegions() != 1)
+    return top->emitError("expected a single region top-level op");
+
+  if (top->getRegion(0).getBlocks().empty())
+    return top->emitError("expected at least one block in the region");
+
+  // Seed the entry state to true before observing any `FooState`.
+  (void)getOrCreate(getProgramPointBefore(&top->getRegion(0).front()))
+      ->join(true);
+
+  for (Block &block : top->getRegion(0)) {
+    visitBlock(&block);
+    for (Operation &op : block) {
+      if (op.getNumRegions())
+        return op.emitError("unexpected op with regions");
+      visitOperation(&op);
+    }
+  }
+  return success();
+}
+
+LogicalResult BarAnalysis::visit(ProgramPoint *point) {
+  if (!point->isBlockStart())
+    visitOperation(point->getPrevOp());
+  else
+    visitBlock(point->getBlock());
+  return success();
+}
+
+void BarAnalysis::visitBlock(Block *block) {
+  if (block->isEntryBlock())
+    return;
+  // Join the state after each predecessor's terminator into the block start.
+  ProgramPoint *point = getProgramPointBefore(block);
+  BarState *state = getOrCreate(point);
+  ChangeResult result = ChangeResult::NoChange;
+  for (Block *pred : block->getPredecessors()) {
+    const BarState *predState = getOrCreateFor(
+        point, getProgramPointAfter(pred->getTerminator()));
+    result |= state->join(*predState);
+  }
+  propagateIfChanged(state, result);
+}
+
+void BarAnalysis::visitOperation(Operation *op) {
+  ProgramPoint *point = getProgramPointAfter(op);
+  BarState *state = getOrCreate(point);
+  ChangeResult result = ChangeResult::NoChange;
+  // Carry the state from just before `op` over to just after it.
+  const BarState *prevState =
+      getOrCreateFor(point, getProgramPointBefore(op));
+  result |= state->join(*prevState);
+  // Tagged ops also fold in whether `FooState` is a non-multiple of 4.
+  if (op->hasAttr(kTagAttrName)) {
+    const FooState *fooState = getOrCreateFor(point, point);
+    // No Foo value yet: skip it, but still publish the join made above.
+    if (!fooState->isUninitialized())
+      result |= state->join((fooState->getValue() & 0x3) != 0);
+  }
+  propagateIfChanged(state, result);
+}
+
 void TestFooAnalysisPass::runOnOperation() {
   func::FuncOp func = getOperation();
   DataFlowSolver solver;
@@ -169,7 +319,7 @@ void TestFooAnalysisPass::runOnOperation() {
   os << "function: @" << func.getSymName() << "\n";
 
   func.walk([&](Operation *op) {
-    auto tag = op->getAttrOfType("tag");
+    auto tag = op->getAttrOfType(kTagAttrName);
     if (!tag)
       return;
     const FooState *state =
@@ -179,8 +329,39 @@ void TestFooAnalysisPass::runOnOperation() {
   });
 }
 
+void 
TestStagedAnalysesPass::runOnOperation() {
+  func::FuncOp func = getOperation();
+  Builder builder(func.getContext());
+  // Stage 1 runs everything loaded so far; stage 2 passes an analysis filter.
+  DataFlowSolver solver;
+  solver.load();
+  if (failed(solver.initializeAndRun(func)))
+    return signalPassFailure();
+  solver.load();
+  if (failed(solver.initializeAndRun(func, llvm::IsaPred)))
+    return signalPassFailure();
+  // Annotate each tagged op with the states recorded right after it.
+  func.walk([&](Operation *op) {
+    if (!op->hasAttr(kTagAttrName))
+      return;
+
+    ProgramPoint *point = solver.getProgramPointAfter(op);
+    const FooState *fooState = solver.lookupState(point);
+    const BarState *barState = solver.lookupState(point);
+    assert(fooState && !fooState->isUninitialized());
+    assert(barState && !barState->isUninitialized());
+    // NOTE(review): unchecked in builds without assertions.
+    op->setAttr(kFooStateAttrName,
+                builder.getI64IntegerAttr(fooState->getValue()));
+    op->setAttr(kBarStateAttrName, builder.getBoolAttr(barState->getValue()));
+  });
+}
+
 namespace mlir {
 namespace test {
 void registerTestFooAnalysisPass() { PassRegistration(); }
+void registerTestStagedAnalysesPass() {
+  PassRegistration();
+}
 } // namespace test
 } // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 48b8c179bd1b..c4754b3a0855 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -99,6 +99,7 @@ void registerTestDynamicPipelinePass();
 void registerTestRemarkPass();
 void registerTestEmulateNarrowTypePass();
 void registerTestFooAnalysisPass();
+void registerTestStagedAnalysesPass();
 void registerTestComposeSubView();
 void registerTestMultiBuffering();
 void registerTestIRVisitorsPass();
@@ -247,6 +248,7 @@ static void registerTestPasses() {
   mlir::test::registerTestRemarkPass();
   mlir::test::registerTestEmulateNarrowTypePass();
   mlir::test::registerTestFooAnalysisPass();
+  mlir::test::registerTestStagedAnalysesPass();
   mlir::test::registerTestComposeSubView();
   mlir::test::registerTestMultiBuffering();
   mlir::test::registerTestIRVisitorsPass();