The vast majority of rewrite / conversion patterns use a combined `matchAndRewrite` instead of separate `match` and `rewrite` functions. This PR optimizes the code base for the most common case where users implement a combined `matchAndRewrite`. There are no longer any `match` and `rewrite` functions in `RewritePattern`, `ConversionPattern` and their derived classes. Instead, there is a `SplitMatchAndRewriteImpl` class that implements `matchAndRewrite` in terms of `match` and `rewrite`. Details: * The `RewritePattern` and `ConversionPattern` classes are simpler (fewer functions). This is especially true for the `ConversionPattern` class, which now has 5 fewer functions. (There were various `rewrite` overloads to account for 1:1 / 1:N patterns.) * There is a new class `SplitMatchAndRewriteImpl` that derives from `RewritePattern` / `OpRewritePattern` / ..., along with a type alias `RewritePattern::SplitMatchAndRewrite` for convenience. * Fewer `llvm_unreachable` calls are needed throughout the code base. Instead, we can use pure virtual functions. (In cases where users previously had to implement `rewrite` or `matchAndRewrite`, etc.) * This PR may also reduce the number of [`-Woverload-virtual` warnings](https://discourse.llvm.org/t/matchandrewrite-hiding-virtual-functions/84933) that are produced by GCC. (To be confirmed...) Note for LLVM integration: Patterns with separate `match` / `rewrite` implementations must derive from `X::SplitMatchAndRewrite` instead of `X`. --------- Co-authored-by: River Riddle <riddleriver@gmail.com>
56 lines
1.9 KiB
C++
56 lines
1.9 KiB
C++
//===- PatternMatchTest.cpp - PatternMatch unit tests ---------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "gtest/gtest.h"
|
|
|
|
#include "../../test/lib/Dialect/Test/TestDialect.h"
|
|
#include "../../test/lib/Dialect/Test/TestOps.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
struct AnOpRewritePattern : OpRewritePattern<test::OpA> {
|
|
AnOpRewritePattern(MLIRContext *context)
|
|
: OpRewritePattern(context, /*benefit=*/1,
|
|
/*generatedNames=*/{test::OpB::getOperationName()}) {}
|
|
|
|
LogicalResult matchAndRewrite(test::OpA op,
|
|
PatternRewriter &rewriter) const override {
|
|
return failure();
|
|
}
|
|
};
|
|
TEST(OpRewritePatternTest, GetGeneratedNames) {
|
|
MLIRContext context;
|
|
AnOpRewritePattern pattern(&context);
|
|
ArrayRef<OperationName> ops = pattern.getGeneratedOps();
|
|
|
|
ASSERT_EQ(ops.size(), 1u);
|
|
ASSERT_EQ(ops.front().getStringRef(), test::OpB::getOperationName());
|
|
}
|
|
} // end anonymous namespace
|
|
|
|
namespace {
|
|
LogicalResult anOpRewritePatternFunc(test::OpA op, PatternRewriter &rewriter) {
|
|
return failure();
|
|
}
|
|
TEST(AnOpRewritePatternTest, PatternFuncAttributes) {
|
|
MLIRContext context;
|
|
RewritePatternSet patterns(&context);
|
|
|
|
patterns.add(anOpRewritePatternFunc, /*benefit=*/3,
|
|
/*generatedNames=*/{test::OpB::getOperationName()});
|
|
ASSERT_EQ(patterns.getNativePatterns().size(), 1U);
|
|
auto &pattern = patterns.getNativePatterns().front();
|
|
ASSERT_EQ(pattern->getBenefit(), 3);
|
|
ASSERT_EQ(pattern->getGeneratedOps().size(), 1U);
|
|
ASSERT_EQ(pattern->getGeneratedOps().front().getStringRef(),
|
|
test::OpB::getOperationName());
|
|
}
|
|
} // end anonymous namespace
|