//===- Rewrite.cpp - Rewrite ----------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "Rewrite.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/IR.h" #include "mlir-c/Rewrite.h" #include "mlir-c/Support.h" #include "mlir/Bindings/Python/Globals.h" #include "mlir/Bindings/Python/IRCore.h" #include "mlir/Config/mlir-config.h" #include "nanobind/nanobind.h" #include namespace nb = nanobind; using namespace mlir; using namespace nb::literals; using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN; namespace mlir { namespace python { namespace MLIR_BINDINGS_PYTHON_DOMAIN { // Convert the Python object to a boolean. // If it evaluates to False, treat it as success; // otherwise, treat it as failure. // Note that None is considered success. static MlirLogicalResult logicalResultFromObject(const nb::object &obj) { if (obj.is_none()) return mlirLogicalResultSuccess(); return nb::cast(obj) ? mlirLogicalResultFailure() : mlirLogicalResultSuccess(); } static std::string operationNameFromObject(nb::handle root) { if (root.is_type()) return nb::cast(root.attr("OPERATION_NAME")); if (nb::isinstance(root)) return nb::cast(root); throw nb::type_error("the root argument must be a type or a string"); } static std::string dialectNameFromObject(nb::handle root) { if (root.is_type()) return nb::cast(root.attr("DIALECT_NAMESPACE")); if (nb::isinstance(root)) return nb::cast(root); throw nb::type_error("the root argument must be a type or a string"); } class PyPatternRewriter : public PyRewriterBase { public: static constexpr const char *pyClassName = "PatternRewriter"; PyPatternRewriter(MlirPatternRewriter rewriter) : PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {} }; //===----------------------------------------------------------------------===// // PyRewritePatternSet //===----------------------------------------------------------------------===// PyRewritePatternSet::PyRewritePatternSet(MlirContext ctx) : patterns(mlirRewritePatternSetCreate(ctx)), owned(true) {} PyRewritePatternSet::PyRewritePatternSet(MlirRewritePatternSet patterns) : patterns(patterns), owned(false) {} PyRewritePatternSet::~PyRewritePatternSet() { if (owned && patterns.ptr) mlirRewritePatternSetDestroy(patterns); } MlirRewritePatternSet PyRewritePatternSet::get() const { return patterns; } bool PyRewritePatternSet::isOwned() const { return owned; } void PyRewritePatternSet::add(nb::handle root, const nb::callable &matchAndRewrite, unsigned benefit) { std::string opName = operationNameFromObject(root); MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size()); MlirRewritePatternCallbacks callbacks; callbacks.construct = [](void *userData) { nb::handle(static_cast(userData)).inc_ref(); }; callbacks.destruct = [](void *userData) { nb::handle(static_cast(userData)).dec_ref(); }; callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op, MlirPatternRewriter rewriter, void *userData) -> MlirLogicalResult { nb::handle f(static_cast(userData)); PyMlirContextRef context = PyMlirContext::forContext(mlirOperationGetContext(op)); nb::object opView = PyOperation::forOperation(context, op)->createOpView(); nb::object res = f(opView, PyPatternRewriter(rewriter)); return logicalResultFromObject(res); }; MlirRewritePattern pattern = mlirOpRewritePatternCreate( rootName, benefit, mlirRewritePatternSetGetContext(patterns), callbacks, matchAndRewrite.ptr(), /* nGeneratedNames */ 0, /* generatedNames */ nullptr); mlirRewritePatternSetAdd(patterns, pattern); } //===----------------------------------------------------------------------===// // PyConversionPatternRewriter //===----------------------------------------------------------------------===// class PyConversionPatternRewriter : public PyPatternRewriter { public: PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter) : PyPatternRewriter( mlirConversionPatternRewriterAsPatternRewriter(rewriter)), rewriter(rewriter) {} MlirConversionPatternRewriter rewriter; }; class PyConversionTarget { public: PyConversionTarget(MlirContext context) : target(mlirConversionTargetCreate(context)) {} ~PyConversionTarget() { mlirConversionTargetDestroy(target); } void addLegalOp(const std::string &opName) { mlirConversionTargetAddLegalOp( target, mlirStringRefCreate(opName.data(), opName.size())); } void addIllegalOp(const std::string &opName) { mlirConversionTargetAddIllegalOp( target, mlirStringRefCreate(opName.data(), opName.size())); } void addLegalDialect(const std::string &dialectName) { mlirConversionTargetAddLegalDialect( target, mlirStringRefCreate(dialectName.data(), dialectName.size())); } void addIllegalDialect(const std::string &dialectName) { mlirConversionTargetAddIllegalDialect( target, mlirStringRefCreate(dialectName.data(), dialectName.size())); } MlirConversionTarget get() { return target; } private: MlirConversionTarget target; }; class PyTypeConverter { public: PyTypeConverter() : typeConverter(mlirTypeConverterCreate()), owner(true) {} PyTypeConverter(MlirTypeConverter typeConverter) : typeConverter(typeConverter), owner(false) {} ~PyTypeConverter() { if (owner) mlirTypeConverterDestroy(typeConverter); } void addConversion(const nb::callable &convert) { mlirTypeConverterAddConversion( typeConverter, [](MlirType type, MlirType *converted, void *userData) -> MlirLogicalResult { nb::handle f = nb::handle(static_cast(userData)); auto ctx = PyMlirContext::forContext(mlirTypeGetContext(type)); nb::object res = f(PyType(ctx, type).maybeDownCast()); if (res.is_none()) return mlirLogicalResultFailure(); *converted = nb::cast(res).get(); return mlirLogicalResultSuccess(); }, convert.ptr()); } nb::typed> convertType(PyType &type) { MlirType converted = mlirTypeConverterConvertType(typeConverter, type); if (mlirTypeIsNull(converted)) return nb::none(); return PyType(PyMlirContext::forContext(mlirTypeGetContext(converted)), converted) .maybeDownCast(); } MlirTypeConverter get() { return typeConverter; } private: MlirTypeConverter typeConverter; bool owner; }; class PyConversionPattern { public: PyConversionPattern(MlirConversionPattern pattern) : pattern(pattern) {} PyTypeConverter getTypeConverter() { return PyTypeConverter(mlirConversionPatternGetTypeConverter(pattern)); } private: MlirConversionPattern pattern; }; void PyRewritePatternSet::addConversion(nb::handle root, const nb::callable &matchAndRewrite, PyTypeConverter &typeConverter, unsigned benefit) { std::string opName = operationNameFromObject(root); MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size()); MlirConversionPatternCallbacks callbacks; callbacks.construct = [](void *userData) { nb::handle(static_cast(userData)).inc_ref(); }; callbacks.destruct = [](void *userData) { nb::handle(static_cast(userData)).dec_ref(); }; callbacks.matchAndRewrite = [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands, MlirValue *operands, MlirConversionPatternRewriter rewriter, void *userData) -> MlirLogicalResult { nb::handle f(static_cast(userData)); PyMlirContextRef ctx = PyMlirContext::forContext(mlirOperationGetContext(op)); nb::object opView = PyOperation::forOperation(ctx, op)->createOpView(); std::vector operandsVec(operands, operands + nOperands); nb::object adaptorCls = PyGlobals::get() .lookupOpAdaptorClass([&] { MlirStringRef ref = mlirIdentifierStr(mlirOperationGetName(op)); return std::string_view(ref.data, ref.length); }()) .value_or(nb::borrow(nb::type())); nb::object res = f(opView, adaptorCls(operandsVec, opView), PyConversionPattern(pattern).getTypeConverter(), PyConversionPatternRewriter(rewriter)); return logicalResultFromObject(res); }; MlirConversionPattern pattern = mlirOpConversionPatternCreate( rootName, benefit, mlirRewritePatternSetGetContext(patterns), typeConverter.get(), callbacks, matchAndRewrite.ptr(), /* nGeneratedNames */ 0, /* generatedNames */ nullptr); mlirRewritePatternSetAdd(patterns, mlirConversionPatternAsRewritePattern(pattern)); } #if MLIR_ENABLE_PDL_IN_PATTERNMATCH struct PyMlirPDLResultList : MlirPDLResultList {}; static nb::object objectFromPDLValue(MlirPDLValue value) { if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v)) return nb::cast(v); if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v)) return nb::cast(v); if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v)) return nb::cast(v); if (MlirType v = mlirPDLValueAsType(value); !mlirTypeIsNull(v)) return nb::cast(v); throw std::runtime_error("unsupported PDL value type"); } static std::vector objectsFromPDLValues(size_t nValues, MlirPDLValue *values) { std::vector args; args.reserve(nValues); for (size_t i = 0; i < nValues; ++i) args.push_back(objectFromPDLValue(values[i])); return args; } /// Owning Wrapper around a PDLPatternModule. class PyPDLPatternModule { public: PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {} PyPDLPatternModule(PyPDLPatternModule &&other) noexcept : module(other.module) { other.module.ptr = nullptr; } ~PyPDLPatternModule() { if (module.ptr != nullptr) mlirPDLPatternModuleDestroy(module); } MlirPDLPatternModule get() { return module; } void registerRewriteFunction(const std::string &name, const nb::callable &fn) { mlirPDLPatternModuleRegisterRewriteFunction( get(), mlirStringRefCreate(name.data(), name.size()), [](MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues, MlirPDLValue *values, void *userData) -> MlirLogicalResult { nb::handle f = nb::handle(static_cast(userData)); return logicalResultFromObject( f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr}, objectsFromPDLValues(nValues, values))); }, fn.ptr()); } void registerConstraintFunction(const std::string &name, const nb::callable &fn) { mlirPDLPatternModuleRegisterConstraintFunction( get(), mlirStringRefCreate(name.data(), name.size()), [](MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues, MlirPDLValue *values, void *userData) -> MlirLogicalResult { nb::handle f = nb::handle(static_cast(userData)); return logicalResultFromObject( f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr}, objectsFromPDLValues(nValues, values))); }, fn.ptr()); } private: MlirPDLPatternModule module; }; #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH /// Owning Wrapper around a FrozenRewritePatternSet. class PyFrozenRewritePatternSet { public: PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {} PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept : set(other.set) { other.set.ptr = nullptr; } ~PyFrozenRewritePatternSet() { if (set.ptr != nullptr) mlirFrozenRewritePatternSetDestroy(set); } MlirFrozenRewritePatternSet get() { return set; } nb::object getCapsule() { return nb::steal( mlirPythonFrozenRewritePatternSetToCapsule(get())); } static nb::object createFromCapsule(const nb::object &capsule) { MlirFrozenRewritePatternSet rawPm = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); if (rawPm.ptr == nullptr) throw nb::python_error(); return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move); } private: MlirFrozenRewritePatternSet set; }; void PyRewritePatternSet::bind(nb::module_ &m) { nb::class_(m, "RewritePatternSet") .def( "__init__", [](PyRewritePatternSet &self, DefaultingPyMlirContext context) { new (&self) PyRewritePatternSet(context.get()->get()); }, "context"_a = nb::none()) .def("add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"), nb::arg("benefit") = 1, R"(Add a new rewrite pattern on the specified root operation, using the provided callable for matching and rewriting, and assign it the given benefit. Args: root: The root operation to which this pattern applies. This may be either an OpView subclass or an operation name. fn: The callable to use for matching and rewriting, which takes an operation and a pattern rewriter. The match is considered successful iff the callable returns a falsy value. benefit: The benefit of the pattern, defaulting to 1.)") .def("add_conversion", &PyRewritePatternSet::addConversion, nb::arg("root"), nb::arg("fn"), nb::arg("type_converter"), nb::arg("benefit") = 1, R"( Add a new conversion pattern on the specified root operation, using the provided callable for matching and rewriting, and assign it the given benefit. Args: root: The root operation to which this pattern applies. This may be either an OpView subclass or an operation name. fn: The callable to use for matching and rewriting, which takes an operation, its adaptor, the type converter and a pattern rewriter. The match is considered successful iff the callable returns a falsy value. type_converter: The type converter to convert types in the IR. benefit: The benefit of the pattern, defaulting to 1.)") .def( "freeze", [](PyRewritePatternSet &self) { if (!self.isOwned()) throw std::runtime_error( "cannot freeze a non-owning pattern set"); MlirRewritePatternSet s = self.get(); return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(s)); }, "Freeze the pattern set into a frozen one."); } enum class PyGreedyRewriteStrictness : std::underlying_type_t< MlirGreedyRewriteStrictness> { ANY_OP = MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP, EXISTING_AND_NEW_OPS = MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS, EXISTING_OPS = MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS, }; enum class PyGreedySimplifyRegionLevel : std::underlying_type_t< MlirGreedySimplifyRegionLevel> { DISABLED = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED, NORMAL = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL, AGGRESSIVE = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE }; /// Owning Wrapper around a GreedyRewriteDriverConfig. class PyGreedyRewriteConfig { public: PyGreedyRewriteConfig() : config(mlirGreedyRewriteDriverConfigCreate().ptr, PyGreedyRewriteConfig::customDeleter) {} PyGreedyRewriteConfig(PyGreedyRewriteConfig &&other) noexcept : config(std::move(other.config)) {} PyGreedyRewriteConfig(const PyGreedyRewriteConfig &other) noexcept : config(other.config) {} MlirGreedyRewriteDriverConfig get() { return MlirGreedyRewriteDriverConfig{config.get()}; } void setMaxIterations(int64_t maxIterations) { mlirGreedyRewriteDriverConfigSetMaxIterations(get(), maxIterations); } void setMaxNumRewrites(int64_t maxNumRewrites) { mlirGreedyRewriteDriverConfigSetMaxNumRewrites(get(), maxNumRewrites); } void setUseTopDownTraversal(bool useTopDownTraversal) { mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(get(), useTopDownTraversal); } void enableFolding(bool enable) { mlirGreedyRewriteDriverConfigEnableFolding(get(), enable); } void setStrictness(PyGreedyRewriteStrictness strictness) { mlirGreedyRewriteDriverConfigSetStrictness( get(), static_cast(strictness)); } void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level) { mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel( get(), static_cast(level)); } void enableConstantCSE(bool enable) { mlirGreedyRewriteDriverConfigEnableConstantCSE(get(), enable); } int64_t getMaxIterations() { return mlirGreedyRewriteDriverConfigGetMaxIterations(get()); } int64_t getMaxNumRewrites() { return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(get()); } bool getUseTopDownTraversal() { return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(get()); } bool isFoldingEnabled() { return mlirGreedyRewriteDriverConfigIsFoldingEnabled(get()); } PyGreedyRewriteStrictness getStrictness() { return static_cast( mlirGreedyRewriteDriverConfigGetStrictness(get())); } PyGreedySimplifyRegionLevel getRegionSimplificationLevel() { return static_cast( mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(get())); } bool isConstantCSEEnabled() { return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(get()); } private: std::shared_ptr config; static void customDeleter(void *c) { mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig{c}); } }; enum class PyDialectConversionFoldingMode : std::underlying_type_t< MlirDialectConversionFoldingMode> { Never = MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER, BeforePatterns = MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS, AfterPatterns = MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS, }; class PyConversionConfig { public: PyConversionConfig() : config(mlirConversionConfigCreate().ptr, PyConversionConfig::customDeleter) {} MlirConversionConfig get() { return MlirConversionConfig{config.get()}; } void setFoldingMode(PyDialectConversionFoldingMode mode) { mlirConversionConfigSetFoldingMode(get(), MlirDialectConversionFoldingMode(mode)); } PyDialectConversionFoldingMode getFoldingMode() { return PyDialectConversionFoldingMode( mlirConversionConfigGetFoldingMode(get())); } void enableBuildMaterializations(bool enabled) { mlirConversionConfigEnableBuildMaterializations(get(), enabled); } bool isBuildMaterializationsEnabled() { return mlirConversionConfigIsBuildMaterializationsEnabled(get()); } private: std::shared_ptr config; static void customDeleter(void *c) { mlirConversionConfigDestroy(MlirConversionConfig{c}); } }; /// Create the `mlir.rewrite` here. void populateRewriteSubmodule(nb::module_ &m) { // Enum definitions nb::enum_(m, "GreedyRewriteStrictness") .value("ANY_OP", PyGreedyRewriteStrictness::ANY_OP) .value("EXISTING_AND_NEW_OPS", PyGreedyRewriteStrictness::EXISTING_AND_NEW_OPS) .value("EXISTING_OPS", PyGreedyRewriteStrictness::EXISTING_OPS); nb::enum_(m, "GreedySimplifyRegionLevel") .value("DISABLED", PyGreedySimplifyRegionLevel::DISABLED) .value("NORMAL", PyGreedySimplifyRegionLevel::NORMAL) .value("AGGRESSIVE", PyGreedySimplifyRegionLevel::AGGRESSIVE); nb::enum_(m, "DialectConversionFoldingMode") .value("NEVER", PyDialectConversionFoldingMode::Never) .value("BEFORE_PATTERNS", PyDialectConversionFoldingMode::BeforePatterns) .value("AFTER_PATTERNS", PyDialectConversionFoldingMode::AfterPatterns); //---------------------------------------------------------------------------- // Mapping of the PatternRewriter //---------------------------------------------------------------------------- PyPatternRewriter::bind(m); //---------------------------------------------------------------------------- // Mapping of the RewritePatternSet //---------------------------------------------------------------------------- PyRewritePatternSet::bind(m); nb::class_( m, "ConversionPatternRewriter") .def("convert_region_types", [](PyConversionPatternRewriter &self, PyRegion ®ion, PyTypeConverter &typeConverter) { mlirConversionPatternRewriterConvertRegionTypes( self.rewriter, region.get(), typeConverter.get()); }); nb::class_(m, "ConversionTarget") .def( "__init__", [](PyConversionTarget &self, DefaultingPyMlirContext context) { new (&self) PyConversionTarget(context.get()->get()); }, "context"_a = nb::none()) .def( "add_legal_op", [](PyConversionTarget &self, const nb::args &ops) { for (auto op : ops) { self.addLegalOp(operationNameFromObject(op)); } }, "ops"_a, "Mark the given operations as legal.") .def( "add_illegal_op", [](PyConversionTarget &self, const nb::args &ops) { for (auto op : ops) { self.addIllegalOp(operationNameFromObject(op)); } }, "ops"_a, "Mark the given operations as illegal.") .def( "add_legal_dialect", [](PyConversionTarget &self, const nb::args &dialects) { for (auto dialect : dialects) { self.addLegalDialect(dialectNameFromObject(dialect)); } }, "dialects"_a, "Mark the given dialects as legal.") .def( "add_illegal_dialect", [](PyConversionTarget &self, const nb::args &dialects) { for (auto dialect : dialects) { self.addIllegalDialect(dialectNameFromObject(dialect)); } }, "dialects"_a, "Mark the given dialect as illegal."); nb::class_(m, "TypeConverter") .def(nb::init<>(), "Create a new TypeConverter.") .def("add_conversion", &PyTypeConverter::addConversion, "convert"_a, nb::keep_alive<0, 1>(), "Register a type conversion function.") .def("convert_type", &PyTypeConverter::convertType, "type"_a, "Convert the given type. Returns None if conversion fails."); //---------------------------------------------------------------------------- // Mapping of the PDLResultList and PDLModule //---------------------------------------------------------------------------- #if MLIR_ENABLE_PDL_IN_PATTERNMATCH nb::class_(m, "PDLResultList") .def("append", [](PyMlirPDLResultList results, const PyValue &value) { mlirPDLResultListPushBackValue(results, value); }) .def("append", [](PyMlirPDLResultList results, const PyOperation &op) { mlirPDLResultListPushBackOperation(results, op); }) .def("append", [](PyMlirPDLResultList results, const PyType &type) { mlirPDLResultListPushBackType(results, type); }) .def("append", [](PyMlirPDLResultList results, const PyAttribute &attr) { mlirPDLResultListPushBackAttribute(results, attr); }); nb::class_(m, "PDLModule") .def( "__init__", [](PyPDLPatternModule &self, PyModule &module) { new (&self) PyPDLPatternModule( mlirPDLPatternModuleFromModule(module.get())); }, "module"_a, "Create a PDL module from the given module.") .def( "__init__", [](PyPDLPatternModule &self, PyModule &module) { new (&self) PyPDLPatternModule( mlirPDLPatternModuleFromModule(module.get())); }, "module"_a, "Create a PDL module from the given module.") .def( "freeze", [](PyPDLPatternModule &self) { return PyFrozenRewritePatternSet(mlirFreezeRewritePattern( mlirRewritePatternSetFromPDLPatternModule(self.get()))); }, nb::keep_alive<0, 1>()) .def( "register_rewrite_function", [](PyPDLPatternModule &self, const std::string &name, const nb::callable &fn) { self.registerRewriteFunction(name, fn); }, nb::keep_alive<1, 3>()) .def( "register_constraint_function", [](PyPDLPatternModule &self, const std::string &name, const nb::callable &fn) { self.registerConstraintFunction(name, fn); }, nb::keep_alive<1, 3>()); #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH nb::class_(m, "GreedyRewriteConfig") .def(nb::init<>(), "Create a greedy rewrite driver config with defaults") .def_prop_rw("max_iterations", &PyGreedyRewriteConfig::getMaxIterations, &PyGreedyRewriteConfig::setMaxIterations, "Maximum number of iterations") .def_prop_rw("max_num_rewrites", &PyGreedyRewriteConfig::getMaxNumRewrites, &PyGreedyRewriteConfig::setMaxNumRewrites, "Maximum number of rewrites per iteration") .def_prop_rw("use_top_down_traversal", &PyGreedyRewriteConfig::getUseTopDownTraversal, &PyGreedyRewriteConfig::setUseTopDownTraversal, "Whether to use top-down traversal") .def_prop_rw("enable_folding", &PyGreedyRewriteConfig::isFoldingEnabled, &PyGreedyRewriteConfig::enableFolding, "Enable or disable folding") .def_prop_rw("strictness", &PyGreedyRewriteConfig::getStrictness, &PyGreedyRewriteConfig::setStrictness, "Rewrite strictness level") .def_prop_rw("region_simplification_level", &PyGreedyRewriteConfig::getRegionSimplificationLevel, &PyGreedyRewriteConfig::setRegionSimplificationLevel, "Region simplification level") .def_prop_rw("enable_constant_cse", &PyGreedyRewriteConfig::isConstantCSEEnabled, &PyGreedyRewriteConfig::enableConstantCSE, "Enable or disable constant CSE"); nb::class_(m, "ConversionConfig") .def(nb::init<>(), "Create a conversion config with defaults") .def_prop_rw("folding_mode", &PyConversionConfig::getFoldingMode, &PyConversionConfig::setFoldingMode, "folding behavior during dialect conversion") .def_prop_rw("build_materializations", &PyConversionConfig::isBuildMaterializationsEnabled, &PyConversionConfig::enableBuildMaterializations, "Whether the dialect conversion attempts to build " "source/target materializations"); nb::class_(m, "FrozenRewritePatternSet") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyFrozenRewritePatternSet::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyFrozenRewritePatternSet::createFromCapsule); m.def( "apply_patterns_and_fold_greedily", [](PyModule &module, PyFrozenRewritePatternSet &set, std::optional config) { MlirLogicalResult status = mlirApplyPatternsAndFoldGreedily( module.get(), set.get(), config.has_value() ? config->get() : mlirGreedyRewriteDriverConfigCreate()); if (mlirLogicalResultIsFailure(status)) throw std::runtime_error("pattern application failed to converge"); }, "module"_a, "set"_a, "config"_a = nb::none(), "Applys the given patterns to the given module greedily while folding " "results.") .def( "apply_patterns_and_fold_greedily", [](PyOperationBase &op, PyFrozenRewritePatternSet &set, std::optional config) { MlirLogicalResult status = mlirApplyPatternsAndFoldGreedilyWithOp( op.getOperation(), set.get(), config.has_value() ? config->get() : mlirGreedyRewriteDriverConfigCreate()); if (mlirLogicalResultIsFailure(status)) throw std::runtime_error( "pattern application failed to converge"); }, "op"_a, "set"_a, "config"_a = nb::none(), "Applys the given patterns to the given op greedily while folding " "results.") .def( "walk_and_apply_patterns", [](PyOperationBase &op, PyFrozenRewritePatternSet &set) { mlirWalkAndApplyPatterns(op.getOperation(), set.get()); }, "op"_a, "set"_a, "Applies the given patterns to the given op by a fast walk-based " "driver.") .def( "apply_partial_conversion", [](PyOperationBase &op, PyConversionTarget &target, PyFrozenRewritePatternSet &set, std::optional config) { if (!config) config.emplace(PyConversionConfig()); PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); MlirLogicalResult status = mlirApplyPartialConversion( op.getOperation(), target.get(), set.get(), config->get()); if (mlirLogicalResultIsFailure(status)) throw MLIRError("partial conversion failed", errors.take()); }, "op"_a, "target"_a, "set"_a, "config"_a = nb::none(), "Applies a partial conversion on the given operation.") .def( "apply_full_conversion", [](PyOperationBase &op, PyConversionTarget &target, PyFrozenRewritePatternSet &set, std::optional config) { if (!config) config.emplace(PyConversionConfig()); PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); MlirLogicalResult status = mlirApplyFullConversion( op.getOperation(), target.get(), set.get(), config->get()); if (mlirLogicalResultIsFailure(status)) throw MLIRError("full conversion failed", errors.take()); }, "op"_a, "target"_a, "set"_a, "config"_a = nb::none(), "Applies a full conversion on the given operation."); } } // namespace MLIR_BINDINGS_PYTHON_DOMAIN } // namespace python } // namespace mlir