115 lines
4.1 KiB
C++
115 lines
4.1 KiB
C++
//===- TestTransformsOps.cpp - Test Transforms ----------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file defines transform dialect operations for testing MLIR
|
|
// transformations
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
|
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
|
#include "mlir/IR/OpDefinition.h"
|
|
#include "mlir/Transforms/RegionUtils.h"
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "TestTransformsOps.h.inc"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::transform;
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "TestTransformsOps.cpp.inc"
|
|
|
|
DiagnosedSilenceableFailure
|
|
transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
|
|
TransformResults &TransformResults,
|
|
TransformState &state) {
|
|
Operation *op = *state.getPayloadOps(getOp()).begin();
|
|
Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin();
|
|
if (failed(moveOperationDependencies(rewriter, op, moveBefore))) {
|
|
auto *listener =
|
|
cast<ErrorCheckingTrackingListener>(rewriter.getListener());
|
|
std::string errorMsg = listener->getLatestMatchFailureMessage();
|
|
(void)emitRemark(errorMsg);
|
|
}
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
DiagnosedSilenceableFailure
|
|
transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
|
|
TransformResults &TransformResults,
|
|
TransformState &state) {
|
|
SmallVector<Value> values;
|
|
for (auto tdValue : getValues()) {
|
|
values.push_back(*state.getPayloadValues(tdValue).begin());
|
|
}
|
|
Operation *moveBefore = *state.getPayloadOps(getInsertionPoint()).begin();
|
|
if (failed(moveValueDefinitions(rewriter, values, moveBefore))) {
|
|
auto *listener =
|
|
cast<ErrorCheckingTrackingListener>(rewriter.getListener());
|
|
std::string errorMsg = listener->getLatestMatchFailureMessage();
|
|
(void)emitRemark(errorMsg);
|
|
}
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Test affine functionality.
|
|
//===----------------------------------------------------------------------===//
|
|
DiagnosedSilenceableFailure
|
|
transform::TestMakeComposedFoldedAffineApply::applyToOne(
|
|
TransformRewriter &rewriter, affine::AffineApplyOp affineApplyOp,
|
|
ApplyToEachResultList &results, TransformState &state) {
|
|
Location loc = affineApplyOp.getLoc();
|
|
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
|
|
rewriter, loc, affineApplyOp.getAffineMap(),
|
|
getAsOpFoldResult(affineApplyOp.getOperands()),
|
|
/*composeAffineMin=*/true);
|
|
Value result;
|
|
if (auto v = dyn_cast<Value>(ofr)) {
|
|
result = v;
|
|
} else {
|
|
result = arith::ConstantIndexOp::create(rewriter, loc,
|
|
getConstantIntValue(ofr).value());
|
|
}
|
|
results.push_back(result.getDefiningOp());
|
|
rewriter.replaceOp(affineApplyOp, result);
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Extension
|
|
//===----------------------------------------------------------------------===//
|
|
namespace {
|
|
|
|
class TestTransformsDialectExtension
|
|
: public transform::TransformDialectExtension<
|
|
TestTransformsDialectExtension> {
|
|
public:
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformsDialectExtension)
|
|
|
|
using Base::Base;
|
|
|
|
void init() {
|
|
registerTransformOps<
|
|
#define GET_OP_LIST
|
|
#include "TestTransformsOps.cpp.inc"
|
|
>();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace test {
|
|
void registerTestTransformsTransformDialectExtension(
|
|
DialectRegistry ®istry) {
|
|
registry.addExtensions<TestTransformsDialectExtension>();
|
|
}
|
|
} // namespace test
|