Files
llvm-project/mlir/lib/Bindings/Python/TransformInterpreter.cpp
Maksim Levental 3b1a7479e8 [mlir][Python] remove stray nb::cast (#176299)
In https://github.com/llvm/llvm-project/pull/155114 we removed
`liveOperations` but forgot this line, which was used to invalidate
operations under a transform root and currently isn't used for
anything. So remove it.

FYI, this led to a subtle double-free bug after
https://github.com/llvm/llvm-project/pull/175405:

```python
@test_in_context
def check_builtin():
    module = builtin_d.ModuleOp()
    with module.context, ir.Location.unknown():
        transform_module = builtin_d.Module.create()
        transform_module.operation.attributes["transform.with_named_sequence"] = (
            ir.UnitAttr.get()
        )
        with ir.InsertionPoint(transform_module.body):
            named_sequence = NamedSequenceOp("__transform_main", [any_op_t()], [])
            with ir.InsertionPoint(named_sequence.body):
                YieldOp([])
        interp.apply_named_sequence(
            module,
            transform_module.body.operations[0],
            transform_module,
        )
```

with the error:

```
python(7436,0x1f95a93c0) malloc: *** error for object 0x6000002b0000: pointer being freed was not allocated
python(7436,0x1f95a93c0) malloc: *** set a breakpoint in malloc_error_break to debug
```

This is because

```
nb::object obj = nb::cast(payloadRoot);
```
is actually equivalent to
```
nb::object obj = nb::cast(payloadRoot, nb::rv_policy::automatic);
```
which is actually equivalent to
```
nb::object obj = nb::cast(payloadRoot, nb::rv_policy::copy);
```
because I changed the API to take `PyOperationBase &payloadRoot`, i.e., an
lvalue reference, and `nb::rv_policy::automatic` decays to
`nb::rv_policy::copy` for [lvalue
refs](https://nanobind.readthedocs.io/en/latest/api_core.html#_CPPv4N8nanobind9rv_policy9automaticE).
2026-01-16 05:37:11 -08:00

120 lines
4.4 KiB
C++

//===- TransformInterpreter.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Pybind classes for the transform dialect interpreter.
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/Transform/Interpreter.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
namespace transform_interpreter {
struct PyTransformOptions {
PyTransformOptions() { options = mlirTransformOptionsCreate(); };
PyTransformOptions(PyTransformOptions &&other) {
options = other.options;
other.options.ptr = nullptr;
}
PyTransformOptions(const PyTransformOptions &) = delete;
~PyTransformOptions() { mlirTransformOptionsDestroy(options); }
MlirTransformOptions options;
};
} // namespace transform_interpreter
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
/// Registers the transform-interpreter bindings on module `m`:
///   * class `TransformOptions`, with read/write properties
///     `expensive_checks` and `enforce_single_top_level_transform_op`;
///   * `apply_named_sequence(payload_root, transform_root, transform_module,
///     transform_options=TransformOptions())`;
///   * `copy_symbols_and_merge_into(target, other)`.
/// Both functions collect diagnostics while the underlying C API call runs and
/// throw `nb::value_error` carrying the collected text when the call fails.
static void populateTransformInterpreterSubmodule(nb::module_ &m) {
  using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
  using namespace transform_interpreter;

  // `TransformOptions`: each property forwards to the matching C API
  // getter/setter on the wrapped handle.
  nb::class_<PyTransformOptions> optionsCls(m, "TransformOptions");
  optionsCls.def(nb::init<>());
  optionsCls.def_prop_rw(
      "expensive_checks",
      [](const PyTransformOptions &opts) {
        return mlirTransformOptionsGetExpensiveChecksEnabled(opts.options);
      },
      [](PyTransformOptions &opts, bool enabled) {
        mlirTransformOptionsEnableExpensiveChecks(opts.options, enabled);
      });
  optionsCls.def_prop_rw(
      "enforce_single_top_level_transform_op",
      [](const PyTransformOptions &opts) {
        return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
            opts.options);
      },
      [](PyTransformOptions &opts, bool enabled) {
        mlirTransformOptionsEnforceSingleTopLevelTransformOp(opts.options,
                                                             enabled);
      });

  m.def(
      "apply_named_sequence",
      [](PyOperationBase &payloadRoot, PyOperationBase &transformRoot,
         PyOperationBase &transformModule, const PyTransformOptions &options) {
        // Diagnostics are captured against the transform root's context.
        mlir::python::CollectDiagnosticsToStringScope diagScope(
            mlirOperationGetContext(transformRoot.getOperation()));
        const MlirLogicalResult status = mlirTransformApplyNamedSequence(
            payloadRoot.getOperation(), transformRoot.getOperation(),
            transformModule.getOperation(), options.options);
        if (mlirLogicalResultIsFailure(status)) {
          throw nb::value_error(
              ("Failed to apply named transform sequence.\nDiagnostic message " +
               diagScope.takeMessage())
                  .c_str());
        }
        // A successful application may still have emitted diagnostics; echo
        // them to stderr instead of silently dropping them.
        const std::string diagnostics = diagScope.takeMessage();
        if (!diagnostics.empty()) {
          fprintf(stderr,
                  "Diagnostic generated while applying "
                  "transform.named_sequence:\n%s",
                  diagnostics.c_str());
        }
      },
      nb::arg("payload_root"), nb::arg("transform_root"),
      nb::arg("transform_module"),
      nb::arg("transform_options") = PyTransformOptions());

  m.def(
      "copy_symbols_and_merge_into",
      [](PyOperationBase &target, PyOperationBase &other) {
        // Diagnostics are captured against the target's context.
        mlir::python::CollectDiagnosticsToStringScope diagScope(
            mlirOperationGetContext(target.getOperation()));
        const MlirLogicalResult status = mlirMergeSymbolsIntoFromClone(
            target.getOperation(), other.getOperation());
        if (mlirLogicalResultIsSuccess(status))
          return;
        throw nb::value_error(
            ("Failed to merge symbols.\nDiagnostic message " +
             diagScope.takeMessage())
                .c_str());
      },
      nb::arg("target"), nb::arg("other"));
}
/// Extension-module entry point for `_mlirTransformInterpreter`: attaches the
/// module docstring, then registers the bindings via
/// populateTransformInterpreterSubmodule.
NB_MODULE(_mlirTransformInterpreter, m) {
  const char *const moduleDoc =
      "MLIR Transform dialect interpreter functionality.";
  m.doc() = moduleDoc;
  populateTransformInterpreterSubmodule(m);
}