In https://github.com/llvm/llvm-project/pull/155114 we removed `liveOperations`, but forgot to remove this line, which was used to invalidate operations under a transform root and is currently not used for anything. So remove it.

FYI, this led to a subtle double-free bug after https://github.com/llvm/llvm-project/pull/175405:

```python
@test_in_context
def check_builtin():
    module = builtin_d.ModuleOp()
    with module.context, ir.Location.unknown():
        transform_module = builtin_d.Module.create()
        transform_module.operation.attributes["transform.with_named_sequence"] = (
            ir.UnitAttr.get()
        )
        with ir.InsertionPoint(transform_module.body):
            named_sequence = NamedSequenceOp("__transform_main", [any_op_t()], [])
            with ir.InsertionPoint(named_sequence.body):
                YieldOp([])
        interp.apply_named_sequence(
            module,
            transform_module.body.operations[0],
            transform_module,
        )
```

with error

```
python(7436,0x1f95a93c0) malloc: *** error for object 0x6000002b0000: pointer being freed was not allocated
python(7436,0x1f95a93c0) malloc: *** set a breakpoint in malloc_error_break to debug
```

This is because

```
nb::object obj = nb::cast(payloadRoot);
```

is actually equivalent to

```
nb::object obj = nb::cast(payloadRoot, nb::rv_policy::automatic);
```

which in turn is equivalent to

```
nb::object obj = nb::cast(payloadRoot, nb::rv_policy::copy);
```

because I changed the API to take `PyOperationBase &payloadRoot`, i.e., an lvalue reference, and `nb::rv_policy::automatic` decays to `nb::rv_policy::copy` for [lvalue refs](https://nanobind.readthedocs.io/en/latest/api_core.html#_CPPv4N8nanobind9rv_policy9automaticE).
120 lines
4.4 KiB
C++
120 lines
4.4 KiB
C++
//===- TransformInterpreter.cpp -------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Pybind classes for the transform dialect interpreter.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir-c/Dialect/Transform/Interpreter.h"
|
|
#include "mlir-c/IR.h"
|
|
#include "mlir-c/Support.h"
|
|
#include "mlir/Bindings/Python/Diagnostics.h"
|
|
#include "mlir/Bindings/Python/IRCore.h"
|
|
#include "mlir/Bindings/Python/Nanobind.h"
|
|
#include "mlir/Bindings/Python/NanobindAdaptors.h"
|
|
|
|
namespace nb = nanobind;
|
|
|
|
namespace mlir {
|
|
namespace python {
|
|
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
|
|
namespace transform_interpreter {
|
|
struct PyTransformOptions {
|
|
PyTransformOptions() { options = mlirTransformOptionsCreate(); };
|
|
PyTransformOptions(PyTransformOptions &&other) {
|
|
options = other.options;
|
|
other.options.ptr = nullptr;
|
|
}
|
|
PyTransformOptions(const PyTransformOptions &) = delete;
|
|
|
|
~PyTransformOptions() { mlirTransformOptionsDestroy(options); }
|
|
|
|
MlirTransformOptions options;
|
|
};
|
|
} // namespace transform_interpreter
|
|
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
|
|
} // namespace python
|
|
} // namespace mlir
|
|
|
|
/// Registers the transform-interpreter bindings on `m`:
///   - `TransformOptions`: wraps `MlirTransformOptions`, with read/write
///     properties `expensive_checks` and
///     `enforce_single_top_level_transform_op`.
///   - `apply_named_sequence(payload_root, transform_root, transform_module,
///     transform_options=TransformOptions())`: applies the named sequence.
///     Raises `ValueError` carrying the collected diagnostics on failure;
///     diagnostics collected on success are printed to stderr.
///   - `copy_symbols_and_merge_into(target, other)`: calls
///     `mlirMergeSymbolsIntoFromClone(target, other)`. Raises `ValueError`
///     carrying the collected diagnostics on failure.
static void populateTransformInterpreterSubmodule(nb::module_ &m) {
  using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
  using namespace transform_interpreter;

  nb::class_<PyTransformOptions>(m, "TransformOptions")
      .def(nb::init<>())
      .def_prop_rw(
          "expensive_checks",
          [](const PyTransformOptions &self) {
            return mlirTransformOptionsGetExpensiveChecksEnabled(self.options);
          },
          [](PyTransformOptions &self, bool value) {
            mlirTransformOptionsEnableExpensiveChecks(self.options, value);
          })
      .def_prop_rw(
          "enforce_single_top_level_transform_op",
          [](const PyTransformOptions &self) {
            return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
                self.options);
          },
          [](PyTransformOptions &self, bool value) {
            mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options,
                                                                 value);
          });

  m.def(
      "apply_named_sequence",
      [](PyOperationBase &payloadRoot, PyOperationBase &transformRoot,
         PyOperationBase &transformModule, const PyTransformOptions &options) {
        // Collect diagnostics emitted in the transform root's context so they
        // can be reported to the Python caller.
        mlir::python::CollectDiagnosticsToStringScope scope(
            mlirOperationGetContext(transformRoot.getOperation()));
        MlirLogicalResult result = mlirTransformApplyNamedSequence(
            payloadRoot.getOperation(), transformRoot.getOperation(),
            transformModule.getOperation(), options.options);
        if (mlirLogicalResultIsFailure(result)) {
          throw nb::value_error(
              ("Failed to apply named transform sequence.\nDiagnostic message " +
               scope.takeMessage())
                  .c_str());
        }

        // Even in cases of success, we might have diagnostics to report.
        std::string msg = scope.takeMessage();
        if (!msg.empty()) {
          fprintf(stderr,
                  "Diagnostic generated while applying "
                  "transform.named_sequence:\n%s",
                  msg.c_str());
        }
      },
      nb::arg("payload_root"), nb::arg("transform_root"),
      nb::arg("transform_module"),
      nb::arg("transform_options") = PyTransformOptions());

  m.def(
      "copy_symbols_and_merge_into",
      [](PyOperationBase &target, PyOperationBase &other) {
        // Collect diagnostics from the target's context for the error message.
        mlir::python::CollectDiagnosticsToStringScope scope(
            mlirOperationGetContext(target.getOperation()));

        MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(
            target.getOperation(), other.getOperation());
        if (mlirLogicalResultIsFailure(result)) {
          throw nb::value_error(
              ("Failed to merge symbols.\nDiagnostic message " +
               scope.takeMessage())
                  .c_str());
        }
      },
      nb::arg("target"), nb::arg("other"));
}
|
|
|
|
/// Entry point of the `_mlirTransformInterpreter` Python extension module.
NB_MODULE(_mlirTransformInterpreter, m) {
  // Install all classes and functions first, then attach the module-level
  // docstring; the two steps are independent of each other.
  populateTransformInterpreterSubmodule(m);
  m.doc() = "MLIR Transform dialect interpreter functionality.";
}
|