Files
llvm-project/mlir/test/lib/Transforms/TestSingleFold.cpp
Mehdi Amini fcf79e5276 [MLIR] Improve in-place folding to iterate until fixed-point (#160615)
When executed in the context of canonicalization, the folders are
invoked in a fixed-point iterative process. However in the context of an
API like `createOrFold()` or in DialectConversion for example, we expect
a "one-shot" call to fold to be as "folded" as possible. However, even
when folders themselves are idempotent, folders on a given operation
interact with each other. For example:

```
// X = 0 + Y
%X = arith.addi %c_0, %Y : i32
```

should fold to %Y, but doing so first requires the folder provided by the
IsCommutative trait to move the constant to the right-hand side. However,
the trait folder only runs after the operation's own folder has already been
attempted, and the operation folder isn't attempted again after the trait
folder has been applied.

This commit makes sure we iterate until fixed point on folder
applications.

Fixes #159844
2025-09-27 10:29:42 +02:00

94 lines
3.5 KiB
C++

//===- TestSingleFold.cpp - Pass to test single-pass folding --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
using namespace mlir;
namespace {
/// Test pass that folds every operation in the anchored IR with
/// `OperationFolder::tryToFold`, bounded by the `max-iterations` option.
///
/// With the default `max-iterations=1`, the folder gets a single iteration per
/// operation. This is useful for testing that a fold implementation handles
/// complex cases in one call, without relying on the repeated rounds of
/// folding performed by canonicalization. Larger values allow the folder to
/// re-apply the folders of an operation up to that many times.
///
/// Operations inserted by the folder are tracked through the rewriter listener
/// interface. Those left without uses after folding are erased, which cleans up
/// unused intermediate constants.
struct TestSingleFold : public PassWrapper<TestSingleFold, OperationPass<>>,
                        public RewriterBase::Listener {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSingleFold)

  TestSingleFold() = default;
  TestSingleFold(const TestSingleFold &pass) : PassWrapper(pass) {}

  StringRef getArgument() const final { return "test-single-fold"; }
  StringRef getDescription() const final {
    return "Test single-pass operation folding and dead constant elimination";
  }

  /// Operations reported as inserted by the folder and not yet reported as
  /// erased (presumably only materialized constants). Each operation is
  /// recorded at most once.
  SmallVector<Operation *> existingConstants;

  void foldOperation(Operation *op, OperationFolder &helper);
  void runOnOperation() override;

  /// Record an operation reported as inserted. The listener contract also
  /// routes *moved* operations through this callback (with `previous` set),
  /// so an operation already being tracked is not appended again. A duplicate
  /// entry would make the dead-constant cleanup erase the same operation twice.
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override {
    if (llvm::find(existingConstants, op) == existingConstants.end())
      existingConstants.push_back(op);
  }

  /// Stop tracking an operation reported as erased, so that the cleanup in
  /// `runOnOperation` never touches it.
  void notifyOperationErased(Operation *op) override {
    auto *it = llvm::find(existingConstants, op);
    if (it != existingConstants.end())
      existingConstants.erase(it);
  }

  /// Iteration budget forwarded to `OperationFolder::tryToFold`.
  Option<int> maxIterations{*this, "max-iterations",
                            llvm::cl::desc("Max iterations in the tryToFold"),
                            llvm::cl::init(1)};
};
} // namespace
/// Hands `op` to the shared folder, allowing at most `maxIterations` rounds of
/// folder application on it. Per the original intent, `OperationFolder` is
/// relied on to deal with unused or duplicated constants. A failure to fold is
/// deliberately ignored: the operation simply stays as it is.
void TestSingleFold::foldOperation(Operation *op, OperationFolder &helper) {
  // The folder reports whether `op` was updated in place; this pass does not
  // need that information.
  bool updatedInPlace = false;
  int iterationBudget = maxIterations;
  (void)helper.tryToFold(op, &updatedInPlace, iterationBudget);
}
void TestSingleFold::runOnOperation() {
existingConstants.clear();
// Collect and fold the operations within the operation.
SmallVector<Operation *, 8> ops;
getOperation()->walk<mlir::WalkOrder::PreOrder>(
[&](Operation *op) { ops.push_back(op); });
// Fold the constants in reverse so that the last generated constants from
// folding are at the beginning. This creates somewhat of a linear ordering to
// the newly generated constants that matches the operation order and improves
// the readability of test cases.
OperationFolder helper(&getContext(), /*listener=*/this);
for (Operation *op : llvm::reverse(ops))
foldOperation(op, helper);
// By the time we are done, we may have simplified a bunch of code, leaving
// around dead constants. Check for them now and remove them.
for (auto *cst : existingConstants) {
if (cst->use_empty())
cst->erase();
}
}
namespace mlir::test {
/// Registers the `test-single-fold` pass so that it can be selected from the
/// command line of tools that call this function.
void registerTestSingleFold() { PassRegistration<TestSingleFold>(); }
} // namespace mlir::test