Files
llvm-project/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
Hanchenng Wu 4e6ee0b674 [MLIR] Fix test failures for generate-runtime-verification pass from PR #160331 (#162533)
[MLIR] Fix test failures for generate-runtime-verification pass from PR #160331
    
PR #160331 mistakenly removed the error message emitted by the
generate-runtime-verification pass, leading to test failures during
`test-build-check-mlir-build-only-check-mlir`.
    
This patch restores the missing error message.
    
In addition, for the related tests, the op strings used in the FileCheck
directives are updated to match the op formats used in the input MLIR.
    
Verified locally.
    
Fixes post-merge regression from:
https://github.com/llvm/llvm-project/pull/160331
2025-10-08 23:31:55 +01:00

106 lines
3.5 KiB
C++

//===- GenerateRuntimeVerification.cpp - Op Verification ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/AsmState.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
// Select this pass's definition section of the generated pass header before
// including it. NOTE(review): presumably this is what provides
// impl::GenerateRuntimeVerificationBase used below; confirm against the
// generated Passes.h.inc.
namespace mlir {
#define GEN_PASS_DEF_GENERATERUNTIMEVERIFICATION
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
/// Pass that asks every op implementing RuntimeVerifiableOpInterface found
/// under the pass's root operation to generate its runtime verification IR.
/// The work itself is done in runOnOperation(), defined out of line.
struct GenerateRuntimeVerificationPass
    : public impl::GenerateRuntimeVerificationBase<
          GenerateRuntimeVerificationPass> {
  /// Entry point invoked by the pass manager.
  void runOnOperation() override;
};
/// Default error message generator for runtime verification failures.
///
/// This class generates error messages with different levels of verbosity:
/// - Level 0: Shows only the error message and operation location
/// - Level 1: Shows the full operation string, error message, and location
///
/// Clients can call getVerboseLevel() to retrieve the current verbose level
/// and use it to customize their own error message generators with similar
/// behavior patterns.
class DefaultErrMsgGenerator {
private:
unsigned vLevel;
AsmState &state;
public:
DefaultErrMsgGenerator(unsigned verboseLevel, AsmState &asmState)
: vLevel(verboseLevel), state(asmState) {}
std::string operator()(Operation *op, StringRef msg) {
std::string buffer;
llvm::raw_string_ostream stream(buffer);
stream << "ERROR: Runtime op verification failed\n";
if (vLevel == 1) {
op->print(stream, state);
stream << "\n^ " << msg;
} else {
stream << "^ " << msg;
}
stream << "\nLocation: ";
op->getLoc().print(stream);
return buffer;
}
unsigned getVerboseLevel() const { return vLevel; }
};
} // namespace
void GenerateRuntimeVerificationPass::runOnOperation() {
// Check verboseLevel is in range [0, 1].
if (verboseLevel > 1) {
getOperation()->emitError(
"generate-runtime-verification pass: set verboseLevel to 0 or 1");
signalPassFailure();
return;
}
// The implementation of the RuntimeVerifiableOpInterface may create ops that
// can be verified. We don't want to generate verification for IR that
// performs verification, so gather all runtime-verifiable ops first.
SmallVector<RuntimeVerifiableOpInterface> ops;
getOperation()->walk([&](RuntimeVerifiableOpInterface verifiableOp) {
ops.push_back(verifiableOp);
});
// We may generate a lot of error messages and so we need to ensure the
// printing is fast.
OpPrintingFlags flags;
flags.elideLargeElementsAttrs();
flags.skipRegions();
flags.useLocalScope();
AsmState state(getOperation(), flags);
// Client can call getVerboseLevel() to fetch verbose level.
DefaultErrMsgGenerator defaultErrMsgGenerator(verboseLevel.getValue(), state);
OpBuilder builder(getOperation()->getContext());
for (RuntimeVerifiableOpInterface verifiableOp : ops) {
builder.setInsertionPoint(verifiableOp);
verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc(),
defaultErrMsgGenerator);
};
}
/// Creates a new instance of the generate-runtime-verification pass.
///
/// \returns an owning pointer to a freshly constructed
///          GenerateRuntimeVerificationPass.
std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {
  return std::make_unique<GenerateRuntimeVerificationPass>();
}