//===- TransformInterpreter.cpp -------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Pybind classes for the transform dialect interpreter. // //===----------------------------------------------------------------------===// #include "mlir-c/Dialect/Transform/Interpreter.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" #include "mlir/Bindings/Python/Diagnostics.h" #include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; namespace mlir { namespace python { namespace MLIR_BINDINGS_PYTHON_DOMAIN { namespace transform_interpreter { struct PyTransformOptions { PyTransformOptions() { options = mlirTransformOptionsCreate(); }; PyTransformOptions(PyTransformOptions &&other) { options = other.options; other.options.ptr = nullptr; } PyTransformOptions(const PyTransformOptions &) = delete; ~PyTransformOptions() { mlirTransformOptionsDestroy(options); } MlirTransformOptions options; }; } // namespace transform_interpreter } // namespace MLIR_BINDINGS_PYTHON_DOMAIN } // namespace python } // namespace mlir static void populateTransformInterpreterSubmodule(nb::module_ &m) { using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN; using namespace transform_interpreter; nb::class_(m, "TransformOptions") .def(nb::init<>()) .def_prop_rw( "expensive_checks", [](const PyTransformOptions &self) { return mlirTransformOptionsGetExpensiveChecksEnabled(self.options); }, [](PyTransformOptions &self, bool value) { mlirTransformOptionsEnableExpensiveChecks(self.options, value); }) .def_prop_rw( "enforce_single_top_level_transform_op", [](const PyTransformOptions &self) { return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( self.options); }, [](PyTransformOptions &self, bool value) { mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options, value); }); m.def( "apply_named_sequence", [](PyOperationBase &payloadRoot, PyOperationBase &transformRoot, PyOperationBase &transformModule, const PyTransformOptions &options) { mlir::python::CollectDiagnosticsToStringScope scope( mlirOperationGetContext(transformRoot.getOperation())); MlirLogicalResult result = mlirTransformApplyNamedSequence( payloadRoot.getOperation(), transformRoot.getOperation(), transformModule.getOperation(), options.options); if (mlirLogicalResultIsSuccess(result)) { // Even in cases of success, we might have diagnostics to report: std::string msg; if ((msg = scope.takeMessage()).size() > 0) { fprintf(stderr, "Diagnostic generated while applying " "transform.named_sequence:\n%s", msg.data()); } return; } throw nb::value_error( ("Failed to apply named transform sequence.\nDiagnostic message " + scope.takeMessage()) .c_str()); }, nb::arg("payload_root"), nb::arg("transform_root"), nb::arg("transform_module"), nb::arg("transform_options") = PyTransformOptions()); m.def( "copy_symbols_and_merge_into", [](PyOperationBase &target, PyOperationBase &other) { mlir::python::CollectDiagnosticsToStringScope scope( mlirOperationGetContext(target.getOperation())); MlirLogicalResult result = mlirMergeSymbolsIntoFromClone( target.getOperation(), other.getOperation()); if (mlirLogicalResultIsFailure(result)) { throw nb::value_error( ("Failed to merge symbols.\nDiagnostic message " + scope.takeMessage()) .c_str()); } }, nb::arg("target"), nb::arg("other")); } NB_MODULE(_mlirTransformInterpreter, m) { m.doc() = "MLIR Transform dialect interpreter functionality."; populateTransformInterpreterSubmodule(m); }