//===- Rewrite.cpp - Rewrite ----------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "Rewrite.h" #include "mlir-c/IR.h" #include "mlir-c/Rewrite.h" #include "mlir-c/Support.h" #include "mlir/Bindings/Python/Globals.h" #include "mlir/Bindings/Python/IRCore.h" // clang-format off #include "mlir/Bindings/Python/Nanobind.h" #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. // clang-format on #include "mlir/Config/mlir-config.h" #include "nanobind/nanobind.h" #include namespace nb = nanobind; using namespace mlir; using namespace nb::literals; using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN; namespace mlir { namespace python { namespace MLIR_BINDINGS_PYTHON_DOMAIN { class PyPatternRewriter { public: PyPatternRewriter(MlirPatternRewriter rewriter) : base(mlirPatternRewriterAsBase(rewriter)), ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {} PyInsertionPoint getInsertionPoint() const { MlirBlock block = mlirRewriterBaseGetInsertionBlock(base); MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base); if (mlirOperationIsNull(op)) { MlirOperation owner = mlirBlockGetParentOperation(block); auto parent = PyOperation::forOperation(ctx, owner); return PyInsertionPoint(PyBlock(parent, block)); } return PyInsertionPoint(PyOperation::forOperation(ctx, op)); } void replaceOp(MlirOperation op, MlirOperation newOp) { mlirRewriterBaseReplaceOpWithOperation(base, op, newOp); } void replaceOp(MlirOperation op, const std::vector &values) { mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data()); } void eraseOp(const PyOperation &op) { mlirRewriterBaseEraseOp(base, op); } private: MlirRewriterBase base; PyMlirContextRef ctx; }; class PyConversionPatternRewriter : PyPatternRewriter { public: PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter) : PyPatternRewriter( mlirConversionPatternRewriterAsPatternRewriter(rewriter)) {} }; class PyConversionTarget { public: PyConversionTarget(MlirContext context) : target(mlirConversionTargetCreate(context)) {} ~PyConversionTarget() { mlirConversionTargetDestroy(target); } void addLegalOp(const std::string &opName) { mlirConversionTargetAddLegalOp( target, mlirStringRefCreate(opName.data(), opName.size())); } void addIllegalOp(const std::string &opName) { mlirConversionTargetAddIllegalOp( target, mlirStringRefCreate(opName.data(), opName.size())); } void addLegalDialect(const std::string &dialectName) { mlirConversionTargetAddLegalDialect( target, mlirStringRefCreate(dialectName.data(), dialectName.size())); } void addIllegalDialect(const std::string &dialectName) { mlirConversionTargetAddIllegalDialect( target, mlirStringRefCreate(dialectName.data(), dialectName.size())); } MlirConversionTarget get() { return target; } private: MlirConversionTarget target; }; class PyTypeConverter { public: PyTypeConverter() : typeConverter(mlirTypeConverterCreate()), owner(true) {} PyTypeConverter(MlirTypeConverter typeConverter) : typeConverter(typeConverter), owner(false) {} ~PyTypeConverter() { if (owner) mlirTypeConverterDestroy(typeConverter); } void addConversion(const nb::callable &convert) { mlirTypeConverterAddConversion( typeConverter, [](MlirType type, MlirType *converted, void *userData) -> MlirLogicalResult { nb::handle f = nb::handle(static_cast(userData)); auto ctx = PyMlirContext::forContext(mlirTypeGetContext(type)); nb::object res = f(PyType(ctx, type).maybeDownCast()); if (res.is_none()) return mlirLogicalResultFailure(); *converted = nb::cast(res).get(); return mlirLogicalResultSuccess(); }, convert.ptr()); } MlirTypeConverter get() { return typeConverter; } private: MlirTypeConverter typeConverter; bool owner; }; class PyConversionPattern { public: PyConversionPattern(MlirConversionPattern pattern) : pattern(pattern) {} PyTypeConverter getTypeConverter() { return PyTypeConverter(mlirConversionPatternGetTypeConverter(pattern)); } private: MlirConversionPattern pattern; }; #if MLIR_ENABLE_PDL_IN_PATTERNMATCH struct PyMlirPDLResultList : MlirPDLResultList {}; static nb::object objectFromPDLValue(MlirPDLValue value) { if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v)) return nb::cast(v); if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v)) return nb::cast(v); if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v)) return nb::cast(v); if (MlirType v = mlirPDLValueAsType(value); !mlirTypeIsNull(v)) return nb::cast(v); throw std::runtime_error("unsupported PDL value type"); } static std::vector objectsFromPDLValues(size_t nValues, MlirPDLValue *values) { std::vector args; args.reserve(nValues); for (size_t i = 0; i < nValues; ++i) args.push_back(objectFromPDLValue(values[i])); return args; } // Convert the Python object to a boolean. // If it evaluates to False, treat it as success; // otherwise, treat it as failure. // Note that None is considered success. static MlirLogicalResult logicalResultFromObject(const nb::object &obj) { if (obj.is_none()) return mlirLogicalResultSuccess(); return nb::cast(obj) ? mlirLogicalResultFailure() : mlirLogicalResultSuccess(); } /// Owning Wrapper around a PDLPatternModule. class PyPDLPatternModule { public: PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {} PyPDLPatternModule(PyPDLPatternModule &&other) noexcept : module(other.module) { other.module.ptr = nullptr; } ~PyPDLPatternModule() { if (module.ptr != nullptr) mlirPDLPatternModuleDestroy(module); } MlirPDLPatternModule get() { return module; } void registerRewriteFunction(const std::string &name, const nb::callable &fn) { mlirPDLPatternModuleRegisterRewriteFunction( get(), mlirStringRefCreate(name.data(), name.size()), [](MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues, MlirPDLValue *values, void *userData) -> MlirLogicalResult { nb::handle f = nb::handle(static_cast(userData)); return logicalResultFromObject( f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr}, objectsFromPDLValues(nValues, values))); }, fn.ptr()); } void registerConstraintFunction(const std::string &name, const nb::callable &fn) { mlirPDLPatternModuleRegisterConstraintFunction( get(), mlirStringRefCreate(name.data(), name.size()), [](MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues, MlirPDLValue *values, void *userData) -> MlirLogicalResult { nb::handle f = nb::handle(static_cast(userData)); return logicalResultFromObject( f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr}, objectsFromPDLValues(nValues, values))); }, fn.ptr()); } private: MlirPDLPatternModule module; }; #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH /// Owning Wrapper around a FrozenRewritePatternSet. class PyFrozenRewritePatternSet { public: PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {} PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept : set(other.set) { other.set.ptr = nullptr; } ~PyFrozenRewritePatternSet() { if (set.ptr != nullptr) mlirFrozenRewritePatternSetDestroy(set); } MlirFrozenRewritePatternSet get() { return set; } nb::object getCapsule() { return nb::steal( mlirPythonFrozenRewritePatternSetToCapsule(get())); } static nb::object createFromCapsule(const nb::object &capsule) { MlirFrozenRewritePatternSet rawPm = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); if (rawPm.ptr == nullptr) throw nb::python_error(); return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move); } private: MlirFrozenRewritePatternSet set; }; class PyRewritePatternSet { public: PyRewritePatternSet(MlirContext ctx) : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {} ~PyRewritePatternSet() { if (set.ptr) mlirRewritePatternSetDestroy(set); } void add(MlirStringRef rootName, unsigned benefit, const nb::callable &matchAndRewrite) { MlirRewritePatternCallbacks callbacks; callbacks.construct = [](void *userData) { nb::handle(static_cast(userData)).inc_ref(); }; callbacks.destruct = [](void *userData) { nb::handle(static_cast(userData)).dec_ref(); }; callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op, MlirPatternRewriter rewriter, void *userData) -> MlirLogicalResult { nb::handle f(static_cast(userData)); PyMlirContextRef ctx = PyMlirContext::forContext(mlirOperationGetContext(op)); nb::object opView = PyOperation::forOperation(ctx, op)->createOpView(); nb::object res = f(opView, PyPatternRewriter(rewriter)); return logicalResultFromObject(res); }; MlirRewritePattern pattern = mlirOpRewritePatternCreate( rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(), /* nGeneratedNames */ 0, /* generatedNames */ nullptr); mlirRewritePatternSetAdd(set, pattern); } void addConversion(MlirStringRef rootName, unsigned benefit, const nb::callable &matchAndRewrite, PyTypeConverter &typeConverter) { MlirConversionPatternCallbacks callbacks; callbacks.construct = [](void *userData) { nb::handle(static_cast(userData)).inc_ref(); }; callbacks.destruct = [](void *userData) { nb::handle(static_cast(userData)).dec_ref(); }; callbacks.matchAndRewrite = [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands, MlirValue *operands, MlirConversionPatternRewriter rewriter, void *userData) -> MlirLogicalResult { nb::handle f(static_cast(userData)); PyMlirContextRef ctx = PyMlirContext::forContext(mlirOperationGetContext(op)); nb::object opView = PyOperation::forOperation(ctx, op)->createOpView(); std::vector operandsVec(operands, operands + nOperands); nb::object adaptorCls = PyGlobals::get() .lookupOpAdaptorClass( unwrap(mlirIdentifierStr(mlirOperationGetName(op)))) .value_or(nb::borrow(nb::type())); nb::object res = f(opView, adaptorCls(operandsVec, opView), PyConversionPattern(pattern).getTypeConverter(), PyConversionPatternRewriter(rewriter)); return logicalResultFromObject(res); }; MlirConversionPattern pattern = mlirOpConversionPatternCreate( rootName, benefit, ctx, typeConverter.get(), callbacks, matchAndRewrite.ptr(), /* nGeneratedNames */ 0, /* generatedNames */ nullptr); mlirRewritePatternSetAdd(set, mlirConversionPatternAsRewritePattern(pattern)); } PyFrozenRewritePatternSet freeze() { MlirRewritePatternSet s = set; set.ptr = nullptr; return mlirFreezeRewritePattern(s); } private: MlirRewritePatternSet set; MlirContext ctx; }; enum class PyGreedyRewriteStrictness : std::underlying_type_t< MlirGreedyRewriteStrictness> { ANY_OP = MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP, EXISTING_AND_NEW_OPS = MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS, EXISTING_OPS = MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS, }; enum class PyGreedySimplifyRegionLevel : std::underlying_type_t< MlirGreedySimplifyRegionLevel> { DISABLED = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED, NORMAL = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL, AGGRESSIVE = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE }; /// Owning Wrapper around a GreedyRewriteDriverConfig. class PyGreedyRewriteConfig { public: PyGreedyRewriteConfig() : config(mlirGreedyRewriteDriverConfigCreate().ptr, PyGreedyRewriteConfig::customDeleter) {} PyGreedyRewriteConfig(PyGreedyRewriteConfig &&other) noexcept : config(std::move(other.config)) {} PyGreedyRewriteConfig(const PyGreedyRewriteConfig &other) noexcept : config(other.config) {} MlirGreedyRewriteDriverConfig get() { return MlirGreedyRewriteDriverConfig{config.get()}; } void setMaxIterations(int64_t maxIterations) { mlirGreedyRewriteDriverConfigSetMaxIterations(get(), maxIterations); } void setMaxNumRewrites(int64_t maxNumRewrites) { mlirGreedyRewriteDriverConfigSetMaxNumRewrites(get(), maxNumRewrites); } void setUseTopDownTraversal(bool useTopDownTraversal) { mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(get(), useTopDownTraversal); } void enableFolding(bool enable) { mlirGreedyRewriteDriverConfigEnableFolding(get(), enable); } void setStrictness(PyGreedyRewriteStrictness strictness) { mlirGreedyRewriteDriverConfigSetStrictness( get(), static_cast(strictness)); } void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level) { mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel( get(), static_cast(level)); } void enableConstantCSE(bool enable) { mlirGreedyRewriteDriverConfigEnableConstantCSE(get(), enable); } int64_t getMaxIterations() { return mlirGreedyRewriteDriverConfigGetMaxIterations(get()); } int64_t getMaxNumRewrites() { return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(get()); } bool getUseTopDownTraversal() { return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(get()); } bool isFoldingEnabled() { return mlirGreedyRewriteDriverConfigIsFoldingEnabled(get()); } PyGreedyRewriteStrictness getStrictness() { return static_cast( mlirGreedyRewriteDriverConfigGetStrictness(get())); } PyGreedySimplifyRegionLevel getRegionSimplificationLevel() { return static_cast( mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(get())); } bool isConstantCSEEnabled() { return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(get()); } private: std::shared_ptr config; static void customDeleter(void *c) { mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig{c}); } }; enum class PyDialectConversionFoldingMode : std::underlying_type_t< MlirDialectConversionFoldingMode> { Never = MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER, BeforePatterns = MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS, AfterPatterns = MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS, }; class PyConversionConfig { public: PyConversionConfig() : config(mlirConversionConfigCreate().ptr, PyConversionConfig::customDeleter) {} MlirConversionConfig get() { return MlirConversionConfig{config.get()}; } void setFoldingMode(PyDialectConversionFoldingMode mode) { mlirConversionConfigSetFoldingMode(get(), MlirDialectConversionFoldingMode(mode)); } PyDialectConversionFoldingMode getFoldingMode() { return PyDialectConversionFoldingMode( mlirConversionConfigGetFoldingMode(get())); } void enableBuildMaterializations(bool enabled) { mlirConversionConfigEnableBuildMaterializations(get(), enabled); } bool isBuildMaterializationsEnabled() { return mlirConversionConfigIsBuildMaterializationsEnabled(get()); } private: std::shared_ptr config; static void customDeleter(void *c) { mlirConversionConfigDestroy(MlirConversionConfig{c}); } }; /// Create the `mlir.rewrite` here. void populateRewriteSubmodule(nb::module_ &m) { // Enum definitions nb::enum_(m, "GreedyRewriteStrictness") .value("ANY_OP", PyGreedyRewriteStrictness::ANY_OP) .value("EXISTING_AND_NEW_OPS", PyGreedyRewriteStrictness::EXISTING_AND_NEW_OPS) .value("EXISTING_OPS", PyGreedyRewriteStrictness::EXISTING_OPS); nb::enum_(m, "GreedySimplifyRegionLevel") .value("DISABLED", PyGreedySimplifyRegionLevel::DISABLED) .value("NORMAL", PyGreedySimplifyRegionLevel::NORMAL) .value("AGGRESSIVE", PyGreedySimplifyRegionLevel::AGGRESSIVE); nb::enum_(m, "DialectConversionFoldingMode") .value("NEVER", PyDialectConversionFoldingMode::Never) .value("BEFORE_PATTERNS", PyDialectConversionFoldingMode::BeforePatterns) .value("AFTER_PATTERNS", PyDialectConversionFoldingMode::AfterPatterns); //---------------------------------------------------------------------------- // Mapping of the PatternRewriter //---------------------------------------------------------------------------- nb::class_(m, "PatternRewriter") .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, "The current insertion point of the PatternRewriter.") .def( "replace_op", [](PyPatternRewriter &self, PyOperationBase &op, PyOperationBase &newOp) { self.replaceOp(op.getOperation(), newOp.getOperation()); }, "Replace an operation with a new operation.", nb::arg("op"), nb::arg("new_op")) .def( "replace_op", [](PyPatternRewriter &self, PyOperationBase &op, const std::vector &values) { std::vector values_(values.size()); std::copy(values.begin(), values.end(), values_.begin()); self.replaceOp(op.getOperation(), values_); }, "Replace an operation with a list of values.", nb::arg("op"), nb::arg("values")) .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.", nb::arg("op")); //---------------------------------------------------------------------------- // Mapping of the RewritePatternSet //---------------------------------------------------------------------------- nb::class_(m, "RewritePatternSet") .def( "__init__", [](PyRewritePatternSet &self, DefaultingPyMlirContext context) { new (&self) PyRewritePatternSet(context.get()->get()); }, "context"_a = nb::none()) .def( "add", [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn, unsigned benefit) { std::string opName; if (root.is_type()) { opName = nb::cast(root.attr("OPERATION_NAME")); } else if (nb::isinstance(root)) { opName = nb::cast(root); } else { throw nb::type_error( "the root argument must be a type or a string"); } self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit, fn); }, "root"_a, "fn"_a, "benefit"_a = 1, // clang-format off nb::sig("def add(self, root: type | str, fn: typing.Callable[[" MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", PatternRewriter], typing.Any], benefit: int = 1) -> None"), // clang-format on R"( Add a new rewrite pattern on the specified root operation, using the provided callable for matching and rewriting, and assign it the given benefit. Args: root: The root operation to which this pattern applies. This may be either an OpView subclass (e.g., ``arith.AddIOp``) or an operation name string (e.g., ``"arith.addi"``). fn: The callable to use for matching and rewriting, which takes an operation and a pattern rewriter as arguments. The match is considered successful iff the callable returns a value where ``bool(value)`` is ``False`` (e.g. ``None``). If possible, the operation is cast to its corresponding OpView subclass before being passed to the callable. benefit: The benefit of the pattern, defaulting to 1.)") .def( "add_conversion", [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn, PyTypeConverter &typeConverter, unsigned benefit) { std::string opName = nb::cast(root.attr("OPERATION_NAME")); self.addConversion( mlirStringRefCreate(opName.data(), opName.size()), benefit, fn, typeConverter); }, "root"_a, "fn"_a, "type_converter"_a, "benefit"_a = 1, R"( Add a new conversion pattern on the specified root operation, using the provided callable for matching and rewriting, and assign it the given benefit. Args: root: The root operation to which this pattern applies. This may be either an OpView subclass (e.g., ``arith.AddIOp``) or an operation name string (e.g., ``"arith.addi"``). fn: The callable to use for matching and rewriting, which takes an operation, its adaptor, the type converter and a pattern rewriter as arguments. The match is considered successful iff the callable returns a value where ``bool(value)`` is ``False`` (e.g. ``None``). If possible, the operation is cast to its corresponding OpView subclass before being passed to the callable. type_converter: The type converter to convert types in the IR. benefit: The benefit of the pattern, defaulting to 1.)") .def("freeze", &PyRewritePatternSet::freeze, "Freeze the pattern set into a frozen one."); nb::class_( m, "ConversionPatternRewriter"); nb::class_(m, "ConversionTarget") .def( "__init__", [](PyConversionTarget &self, DefaultingPyMlirContext context) { new (&self) PyConversionTarget(context.get()->get()); }, "context"_a = nb::none()) .def( "add_legal_op", [](PyConversionTarget &self, const nb::args &ops) { for (auto op : ops) { std::string opName = nb::cast(op.attr("OPERATION_NAME")); self.addLegalOp(opName); } }, "ops"_a, "Mark the given operations as legal.") .def( "add_illegal_op", [](PyConversionTarget &self, const nb::args &ops) { for (auto op : ops) { std::string opName = nb::cast(op.attr("OPERATION_NAME")); self.addIllegalOp(opName); } }, "ops"_a, "Mark the given operations as illegal.") .def( "add_legal_dialect", [](PyConversionTarget &self, const nb::args &dialects) { for (auto dialect : dialects) { std::string dialectName = nb::cast(dialect.attr("DIALECT_NAMESPACE")); self.addLegalDialect(dialectName); } }, "dialects"_a, "Mark the given dialects as legal.") .def( "add_illegal_dialect", [](PyConversionTarget &self, const nb::args &dialects) { for (auto dialect : dialects) { std::string dialectName = nb::cast(dialect.attr("DIALECT_NAMESPACE")); self.addIllegalDialect(dialectName); } }, "dialects"_a, "Mark the given dialect as illegal."); nb::class_(m, "TypeConverter") .def(nb::init<>(), "Create a new TypeConverter.") .def("add_conversion", &PyTypeConverter::addConversion, "convert"_a, nb::keep_alive<0, 1>(), "Register a type conversion function."); //---------------------------------------------------------------------------- // Mapping of the PDLResultList and PDLModule //---------------------------------------------------------------------------- #if MLIR_ENABLE_PDL_IN_PATTERNMATCH nb::class_(m, "PDLResultList") .def("append", [](PyMlirPDLResultList results, const PyValue &value) { mlirPDLResultListPushBackValue(results, value); }) .def("append", [](PyMlirPDLResultList results, const PyOperation &op) { mlirPDLResultListPushBackOperation(results, op); }) .def("append", [](PyMlirPDLResultList results, const PyType &type) { mlirPDLResultListPushBackType(results, type); }) .def("append", [](PyMlirPDLResultList results, const PyAttribute &attr) { mlirPDLResultListPushBackAttribute(results, attr); }); nb::class_(m, "PDLModule") .def( "__init__", [](PyPDLPatternModule &self, PyModule &module) { new (&self) PyPDLPatternModule( mlirPDLPatternModuleFromModule(module.get())); }, "module"_a, "Create a PDL module from the given module.") .def( "__init__", [](PyPDLPatternModule &self, PyModule &module) { new (&self) PyPDLPatternModule( mlirPDLPatternModuleFromModule(module.get())); }, "module"_a, "Create a PDL module from the given module.") .def( "freeze", [](PyPDLPatternModule &self) { return PyFrozenRewritePatternSet(mlirFreezeRewritePattern( mlirRewritePatternSetFromPDLPatternModule(self.get()))); }, nb::keep_alive<0, 1>()) .def( "register_rewrite_function", [](PyPDLPatternModule &self, const std::string &name, const nb::callable &fn) { self.registerRewriteFunction(name, fn); }, nb::keep_alive<1, 3>()) .def( "register_constraint_function", [](PyPDLPatternModule &self, const std::string &name, const nb::callable &fn) { self.registerConstraintFunction(name, fn); }, nb::keep_alive<1, 3>()); #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH nb::class_(m, "GreedyRewriteConfig") .def(nb::init<>(), "Create a greedy rewrite driver config with defaults") .def_prop_rw("max_iterations", &PyGreedyRewriteConfig::getMaxIterations, &PyGreedyRewriteConfig::setMaxIterations, "Maximum number of iterations") .def_prop_rw("max_num_rewrites", &PyGreedyRewriteConfig::getMaxNumRewrites, &PyGreedyRewriteConfig::setMaxNumRewrites, "Maximum number of rewrites per iteration") .def_prop_rw("use_top_down_traversal", &PyGreedyRewriteConfig::getUseTopDownTraversal, &PyGreedyRewriteConfig::setUseTopDownTraversal, "Whether to use top-down traversal") .def_prop_rw("enable_folding", &PyGreedyRewriteConfig::isFoldingEnabled, &PyGreedyRewriteConfig::enableFolding, "Enable or disable folding") .def_prop_rw("strictness", &PyGreedyRewriteConfig::getStrictness, &PyGreedyRewriteConfig::setStrictness, "Rewrite strictness level") .def_prop_rw("region_simplification_level", &PyGreedyRewriteConfig::getRegionSimplificationLevel, &PyGreedyRewriteConfig::setRegionSimplificationLevel, "Region simplification level") .def_prop_rw("enable_constant_cse", &PyGreedyRewriteConfig::isConstantCSEEnabled, &PyGreedyRewriteConfig::enableConstantCSE, "Enable or disable constant CSE"); nb::class_(m, "ConversionConfig") .def(nb::init<>(), "Create a conversion config with defaults") .def_prop_rw("folding_mode", &PyConversionConfig::getFoldingMode, &PyConversionConfig::setFoldingMode, "folding behavior during dialect conversion") .def_prop_rw("build_materializations", &PyConversionConfig::isBuildMaterializationsEnabled, &PyConversionConfig::enableBuildMaterializations, "Whether the dialect conversion attempts to build " "source/target materializations"); nb::class_(m, "FrozenRewritePatternSet") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyFrozenRewritePatternSet::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyFrozenRewritePatternSet::createFromCapsule); m.def( "apply_patterns_and_fold_greedily", [](PyModule &module, PyFrozenRewritePatternSet &set, std::optional config) { MlirLogicalResult status = mlirApplyPatternsAndFoldGreedily( module.get(), set.get(), config.has_value() ? config->get() : mlirGreedyRewriteDriverConfigCreate()); if (mlirLogicalResultIsFailure(status)) throw std::runtime_error("pattern application failed to converge"); }, "module"_a, "set"_a, "config"_a = nb::none(), "Applys the given patterns to the given module greedily while folding " "results.") .def( "apply_patterns_and_fold_greedily", [](PyOperationBase &op, PyFrozenRewritePatternSet &set, std::optional config) { MlirLogicalResult status = mlirApplyPatternsAndFoldGreedilyWithOp( op.getOperation(), set.get(), config.has_value() ? config->get() : mlirGreedyRewriteDriverConfigCreate()); if (mlirLogicalResultIsFailure(status)) throw std::runtime_error( "pattern application failed to converge"); }, "op"_a, "set"_a, "config"_a = nb::none(), "Applys the given patterns to the given op greedily while folding " "results.") .def( "walk_and_apply_patterns", [](PyOperationBase &op, PyFrozenRewritePatternSet &set) { mlirWalkAndApplyPatterns(op.getOperation(), set.get()); }, "op"_a, "set"_a, "Applies the given patterns to the given op by a fast walk-based " "driver.") .def( "apply_partial_conversion", [](PyOperationBase &op, PyConversionTarget &target, PyFrozenRewritePatternSet &set, std::optional config) { if (!config) config.emplace(PyConversionConfig()); MlirLogicalResult status = mlirApplyPartialConversion( op.getOperation(), target.get(), set.get(), config->get()); if (mlirLogicalResultIsFailure(status)) throw std::runtime_error("partial conversion failed"); }, "op"_a, "target"_a, "set"_a, "config"_a = nb::none(), "Applies a partial conversion on the given operation.") .def( "apply_full_conversion", [](PyOperationBase &op, PyConversionTarget &target, PyFrozenRewritePatternSet &set, std::optional config) { if (!config) config.emplace(PyConversionConfig()); MlirLogicalResult status = mlirApplyFullConversion( op.getOperation(), target.get(), set.get(), config->get()); if (mlirLogicalResultIsFailure(status)) throw std::runtime_error("full conversion failed"); }, "op"_a, "target"_a, "set"_a, "config"_a = nb::none(), "Applies a full conversion on the given operation."); } } // namespace MLIR_BINDINGS_PYTHON_DOMAIN } // namespace python } // namespace mlir