Files
Twice f0142833b6 [MLIR][Python] Move operation/dialect name retrieving as a util function (#184605)
We have a common pattern that retrieves an operation name or dialect name
from a `type` or `str` in the rewrite nanobind module, so it is better to
make it a common utility function.

---------

Co-authored-by: Rolf Morel <rolfmorel@gmail.com>
2026-03-05 10:37:12 +08:00

820 lines
32 KiB
C++

//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Rewrite.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Globals.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Config/mlir-config.h"
#include "nanobind/nanobind.h"
#include <type_traits>
namespace nb = nanobind;
using namespace mlir;
using namespace nb::literals;
using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Maps the value returned by a Python callback onto an MlirLogicalResult.
///
/// `None`, and any value that casts to `false`, is reported as success; a
/// value that casts to `true` is reported as failure.
static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
  // The cast is only attempted for non-None objects.
  const bool reportedFailure = !obj.is_none() && nb::cast<bool>(obj);
  if (reportedFailure)
    return mlirLogicalResultFailure();
  return mlirLogicalResultSuccess();
}
/// Resolves `root` to an operation name.
///
/// A Python type contributes its `OPERATION_NAME` attribute and a Python `str`
/// contributes its own contents; anything else throws nb::type_error.
static std::string operationNameFromObject(nb::handle root) {
  if (!root.is_type()) {
    if (!nb::isinstance<nb::str>(root))
      throw nb::type_error("the root argument must be a type or a string");
    return nb::cast<std::string>(root);
  }
  return nb::cast<std::string>(root.attr("OPERATION_NAME"));
}
/// Resolves `root` to a dialect name.
///
/// A Python type contributes its `DIALECT_NAMESPACE` attribute and a Python
/// `str` contributes its own contents; anything else throws nb::type_error.
static std::string dialectNameFromObject(nb::handle root) {
  const bool isType = root.is_type();
  if (!isType && !nb::isinstance<nb::str>(root))
    throw nb::type_error("the root argument must be a type or a string");
  if (isType)
    return nb::cast<std::string>(root.attr("DIALECT_NAMESPACE"));
  return nb::cast<std::string>(root);
}
/// Wrapper for an MlirPatternRewriter, built on the PyRewriterBase CRTP base.
class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
public:
// NOTE(review): presumably the Python-visible class name consumed by
// PyRewriterBase when binding -- confirm in Rewrite.h.
static constexpr const char *pyClassName = "PatternRewriter";
// Converts the pattern rewriter to its base rewriter handle with
// mlirPatternRewriterAsBase and forwards that to the base-class constructor.
PyPatternRewriter(MlirPatternRewriter rewriter)
: PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
};
//===----------------------------------------------------------------------===//
// PyRewritePatternSet
//===----------------------------------------------------------------------===//
/// Creates a new pattern set in `ctx`; this wrapper owns it.
PyRewritePatternSet::PyRewritePatternSet(MlirContext ctx)
    : patterns(mlirRewritePatternSetCreate(ctx)), owned(true) {}

/// Wraps an already existing pattern set without taking ownership of it.
PyRewritePatternSet::PyRewritePatternSet(MlirRewritePatternSet patterns)
    : patterns(patterns), owned(false) {}

/// Destroys the underlying pattern set, but only if it is owned and non-null.
PyRewritePatternSet::~PyRewritePatternSet() {
  const bool mustDestroy = owned && patterns.ptr;
  if (!mustDestroy)
    return;
  mlirRewritePatternSetDestroy(patterns);
}

/// Returns the wrapped C API handle.
MlirRewritePatternSet PyRewritePatternSet::get() const { return patterns; }

/// Reports whether this wrapper owns the underlying pattern set.
bool PyRewritePatternSet::isOwned() const { return owned; }
/// Adds a rewrite pattern rooted at `root` that dispatches to the Python
/// callable `matchAndRewrite`.
///
/// `root` is resolved to an operation name with operationNameFromObject. The
/// callable is handed to the C API as opaque user data; the `construct` and
/// `destruct` callbacks increment and decrement its reference count.
void PyRewritePatternSet::add(nb::handle root,
                              const nb::callable &matchAndRewrite,
                              unsigned benefit) {
  const std::string opName = operationNameFromObject(root);

  MlirRewritePatternCallbacks callbacks;
  callbacks.construct = [](void *userData) {
    nb::handle callable(static_cast<PyObject *>(userData));
    callable.inc_ref();
  };
  callbacks.destruct = [](void *userData) {
    nb::handle callable(static_cast<PyObject *>(userData));
    callable.dec_ref();
  };
  callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
                                 MlirPatternRewriter rewriter,
                                 void *userData) -> MlirLogicalResult {
    nb::handle callable(static_cast<PyObject *>(userData));
    // Wrap the matched operation in its OpView before calling into Python.
    PyMlirContextRef context =
        PyMlirContext::forContext(mlirOperationGetContext(op));
    nb::object opView = PyOperation::forOperation(context, op)->createOpView();
    return logicalResultFromObject(
        callable(opView, PyPatternRewriter(rewriter)));
  };

  MlirContext ctx = mlirRewritePatternSetGetContext(patterns);
  MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size());
  MlirRewritePattern pattern = mlirOpRewritePatternCreate(
      rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
      /* nGeneratedNames */ 0,
      /* generatedNames */ nullptr);
  mlirRewritePatternSetAdd(patterns, pattern);
}
//===----------------------------------------------------------------------===//
// PyConversionPatternRewriter
//===----------------------------------------------------------------------===//
/// Rewriter wrapper for conversion patterns.
///
/// The PyPatternRewriter base is initialized from the conversion rewriter via
/// mlirConversionPatternRewriterAsPatternRewriter, and the original
/// MlirConversionPatternRewriter handle is additionally kept in the public
/// `rewriter` member.
class PyConversionPatternRewriter : public PyPatternRewriter {
public:
PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
: PyPatternRewriter(
mlirConversionPatternRewriterAsPatternRewriter(rewriter)),
rewriter(rewriter) {}
// The unconverted conversion-rewriter handle passed to the constructor.
MlirConversionPatternRewriter rewriter;
};
/// Owning wrapper around an MlirConversionTarget.
///
/// The target is created in the constructor and destroyed in the destructor.
/// Copying is disabled: a copy would share the same handle and both objects
/// would destroy it.
class PyConversionTarget {
public:
  PyConversionTarget(MlirContext context)
      : target(mlirConversionTargetCreate(context)) {}
  PyConversionTarget(const PyConversionTarget &) = delete;
  PyConversionTarget &operator=(const PyConversionTarget &) = delete;
  ~PyConversionTarget() { mlirConversionTargetDestroy(target); }

  /// Marks the operation named `opName` as legal.
  void addLegalOp(const std::string &opName) {
    mlirConversionTargetAddLegalOp(
        target, mlirStringRefCreate(opName.data(), opName.size()));
  }

  /// Marks the operation named `opName` as illegal.
  void addIllegalOp(const std::string &opName) {
    mlirConversionTargetAddIllegalOp(
        target, mlirStringRefCreate(opName.data(), opName.size()));
  }

  /// Marks the dialect named `dialectName` as legal.
  void addLegalDialect(const std::string &dialectName) {
    mlirConversionTargetAddLegalDialect(
        target, mlirStringRefCreate(dialectName.data(), dialectName.size()));
  }

  /// Marks the dialect named `dialectName` as illegal.
  void addIllegalDialect(const std::string &dialectName) {
    mlirConversionTargetAddIllegalDialect(
        target, mlirStringRefCreate(dialectName.data(), dialectName.size()));
  }

  /// Returns the wrapped C API handle; ownership stays with this object.
  MlirConversionTarget get() { return target; }

private:
  MlirConversionTarget target;
};
class PyTypeConverter {
public:
PyTypeConverter() : typeConverter(mlirTypeConverterCreate()), owner(true) {}
PyTypeConverter(MlirTypeConverter typeConverter)
: typeConverter(typeConverter), owner(false) {}
~PyTypeConverter() {
if (owner)
mlirTypeConverterDestroy(typeConverter);
}
void addConversion(const nb::callable &convert) {
mlirTypeConverterAddConversion(
typeConverter,
[](MlirType type, MlirType *converted,
void *userData) -> MlirLogicalResult {
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
auto ctx = PyMlirContext::forContext(mlirTypeGetContext(type));
nb::object res = f(PyType(ctx, type).maybeDownCast());
if (res.is_none())
return mlirLogicalResultFailure();
*converted = nb::cast<PyType>(res).get();
return mlirLogicalResultSuccess();
},
convert.ptr());
}
nb::typed<nb::object, std::optional<PyType>> convertType(PyType &type) {
MlirType converted = mlirTypeConverterConvertType(typeConverter, type);
if (mlirTypeIsNull(converted))
return nb::none();
return PyType(PyMlirContext::forContext(mlirTypeGetContext(converted)),
converted)
.maybeDownCast();
}
MlirTypeConverter get() { return typeConverter; }
private:
MlirTypeConverter typeConverter;
bool owner;
};
/// Thin wrapper around an MlirConversionPattern handle. It stores the handle
/// only; no destroy call is made on it in this class.
class PyConversionPattern {
public:
PyConversionPattern(MlirConversionPattern pattern) : pattern(pattern) {}
// Returns a PyTypeConverter built from the pattern's type converter handle.
// That PyTypeConverter constructor records the handle as not owned, so the
// returned wrapper does not destroy the converter.
PyTypeConverter getTypeConverter() {
return PyTypeConverter(mlirConversionPatternGetTypeConverter(pattern));
}
private:
MlirConversionPattern pattern;
};
/// Adds a conversion pattern rooted at `root` that dispatches to the Python
/// callable `matchAndRewrite`.
///
/// The callable is invoked with the op view, an adaptor built from the
/// operands supplied by the driver, the pattern's type converter and a
/// conversion rewriter. The `construct`/`destruct` callbacks increment and
/// decrement the callable's reference count.
void PyRewritePatternSet::addConversion(nb::handle root,
                                        const nb::callable &matchAndRewrite,
                                        PyTypeConverter &typeConverter,
                                        unsigned benefit) {
  const std::string opName = operationNameFromObject(root);

  MlirConversionPatternCallbacks callbacks;
  callbacks.construct = [](void *userData) {
    nb::handle callable(static_cast<PyObject *>(userData));
    callable.inc_ref();
  };
  callbacks.destruct = [](void *userData) {
    nb::handle callable(static_cast<PyObject *>(userData));
    callable.dec_ref();
  };
  callbacks.matchAndRewrite =
      [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
         MlirValue *operands, MlirConversionPatternRewriter rewriter,
         void *userData) -> MlirLogicalResult {
    nb::handle callable(static_cast<PyObject *>(userData));
    PyMlirContextRef ctx =
        PyMlirContext::forContext(mlirOperationGetContext(op));
    nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
    std::vector<MlirValue> operandsVec(operands, operands + nOperands);

    // Use the adaptor class registered for this op name, falling back to the
    // generic PyOpAdaptor when none is registered.
    MlirStringRef nameRef = mlirIdentifierStr(mlirOperationGetName(op));
    std::string_view registeredName(nameRef.data, nameRef.length);
    nb::object adaptorCls =
        PyGlobals::get()
            .lookupOpAdaptorClass(registeredName)
            .value_or(nb::borrow(nb::type<PyOpAdaptor>()));
    nb::object adaptor = adaptorCls(operandsVec, opView);

    return logicalResultFromObject(
        callable(opView, adaptor,
                 PyConversionPattern(pattern).getTypeConverter(),
                 PyConversionPatternRewriter(rewriter)));
  };

  MlirContext ctx = mlirRewritePatternSetGetContext(patterns);
  MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size());
  MlirConversionPattern pattern = mlirOpConversionPatternCreate(
      rootName, benefit, ctx, typeConverter.get(), callbacks,
      matchAndRewrite.ptr(),
      /* nGeneratedNames */ 0,
      /* generatedNames */ nullptr);
  mlirRewritePatternSetAdd(patterns,
                           mlirConversionPatternAsRewritePattern(pattern));
}
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
/// Distinct C++ type deriving from MlirPDLResultList and adding no members;
/// it is the type registered with nanobind as `PDLResultList`.
struct PyMlirPDLResultList : MlirPDLResultList {};
/// Converts a PDL value into a Python object.
///
/// The value is probed, in order, as a Value, an Operation, an Attribute and
/// a Type; the first non-null interpretation is cast to Python. If none of
/// them applies, std::runtime_error is thrown.
static nb::object objectFromPDLValue(MlirPDLValue value) {
  MlirValue asValue = mlirPDLValueAsValue(value);
  if (!mlirValueIsNull(asValue))
    return nb::cast(asValue);

  MlirOperation asOperation = mlirPDLValueAsOperation(value);
  if (!mlirOperationIsNull(asOperation))
    return nb::cast(asOperation);

  MlirAttribute asAttribute = mlirPDLValueAsAttribute(value);
  if (!mlirAttributeIsNull(asAttribute))
    return nb::cast(asAttribute);

  MlirType asType = mlirPDLValueAsType(value);
  if (!mlirTypeIsNull(asType))
    return nb::cast(asType);

  throw std::runtime_error("unsupported PDL value type");
}
static std::vector<nb::object> objectsFromPDLValues(size_t nValues,
MlirPDLValue *values) {
std::vector<nb::object> args;
args.reserve(nValues);
for (size_t i = 0; i < nValues; ++i)
args.push_back(objectFromPDLValue(values[i]));
return args;
}
/// Owning Wrapper around a PDLPatternModule.
///
/// The module is destroyed in the destructor unless it has been moved from,
/// in which case the handle pointer has been reset to null.
class PyPDLPatternModule {
public:
  PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}

  PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
      : module(other.module) {
    other.module.ptr = nullptr;
  }

  ~PyPDLPatternModule() {
    if (module.ptr == nullptr)
      return;
    mlirPDLPatternModuleDestroy(module);
  }

  /// Returns the wrapped C API handle.
  MlirPDLPatternModule get() { return module; }

  /// Registers the Python callable `fn` as the PDL rewrite function `name`.
  /// `fn` is called with a pattern rewriter, the result list and the list of
  /// converted PDL values; its return value is mapped by
  /// logicalResultFromObject.
  void registerRewriteFunction(const std::string &name,
                               const nb::callable &fn) {
    auto trampoline = [](MlirPatternRewriter rewriter,
                         MlirPDLResultList results, size_t nValues,
                         MlirPDLValue *values,
                         void *userData) -> MlirLogicalResult {
      nb::handle callable = nb::handle(static_cast<PyObject *>(userData));
      return logicalResultFromObject(
          callable(PyPatternRewriter(rewriter),
                   PyMlirPDLResultList{results.ptr},
                   objectsFromPDLValues(nValues, values)));
    };
    MlirStringRef nameRef = mlirStringRefCreate(name.data(), name.size());
    mlirPDLPatternModuleRegisterRewriteFunction(get(), nameRef, trampoline,
                                                fn.ptr());
  }

  /// Registers the Python callable `fn` as the PDL constraint function
  /// `name`. The calling convention matches registerRewriteFunction.
  void registerConstraintFunction(const std::string &name,
                                  const nb::callable &fn) {
    auto trampoline = [](MlirPatternRewriter rewriter,
                         MlirPDLResultList results, size_t nValues,
                         MlirPDLValue *values,
                         void *userData) -> MlirLogicalResult {
      nb::handle callable = nb::handle(static_cast<PyObject *>(userData));
      return logicalResultFromObject(
          callable(PyPatternRewriter(rewriter),
                   PyMlirPDLResultList{results.ptr},
                   objectsFromPDLValues(nValues, values)));
    };
    MlirStringRef nameRef = mlirStringRefCreate(name.data(), name.size());
    mlirPDLPatternModuleRegisterConstraintFunction(get(), nameRef, trampoline,
                                                   fn.ptr());
  }

private:
  MlirPDLPatternModule module;
};
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
/// Owning Wrapper around a FrozenRewritePatternSet.
///
/// The set is destroyed in the destructor unless it has been moved from, in
/// which case the handle pointer has been reset to null.
class PyFrozenRewritePatternSet {
public:
  PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}

  PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
      : set(other.set) {
    other.set.ptr = nullptr;
  }

  ~PyFrozenRewritePatternSet() {
    if (set.ptr == nullptr)
      return;
    mlirFrozenRewritePatternSetDestroy(set);
  }

  /// Returns the wrapped C API handle.
  MlirFrozenRewritePatternSet get() { return set; }

  /// Returns the capsule produced for the wrapped handle.
  nb::object getCapsule() {
    PyObject *capsule = mlirPythonFrozenRewritePatternSetToCapsule(get());
    return nb::steal<nb::object>(capsule);
  }

  /// Builds a wrapper from `capsule`; throws nb::python_error when the
  /// capsule does not yield a non-null handle.
  static nb::object createFromCapsule(const nb::object &capsule) {
    MlirFrozenRewritePatternSet unwrapped =
        mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
    if (!unwrapped.ptr)
      throw nb::python_error();
    return nb::cast(PyFrozenRewritePatternSet(unwrapped),
                    nb::rv_policy::move);
  }

private:
  MlirFrozenRewritePatternSet set;
};
/// Registers the `RewritePatternSet` Python class on module `m`, exposing its
/// constructor plus the `add`, `add_conversion` and `freeze` methods.
void PyRewritePatternSet::bind(nb::module_ &m) {
  nb::class_<PyRewritePatternSet> cls(m, "RewritePatternSet");

  cls.def(
      "__init__",
      [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
        new (&self) PyRewritePatternSet(context.get()->get());
      },
      "context"_a = nb::none());

  cls.def("add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"),
          nb::arg("benefit") = 1,
          R"(Add a new rewrite pattern on the specified root operation, using
the provided callable for matching and rewriting, and assign it
the given benefit.
Args:
root: The root operation to which this pattern applies. This may
be either an OpView subclass or an operation name.
fn: The callable to use for matching and rewriting, which takes
an operation and a pattern rewriter. The match is considered
successful iff the callable returns a falsy value.
benefit: The benefit of the pattern, defaulting to 1.)");

  cls.def("add_conversion", &PyRewritePatternSet::addConversion,
          nb::arg("root"), nb::arg("fn"), nb::arg("type_converter"),
          nb::arg("benefit") = 1,
          R"(
Add a new conversion pattern on the specified root operation,
using the provided callable for matching and rewriting,
and assign it the given benefit.
Args:
root: The root operation to which this pattern applies.
This may be either an OpView subclass or an operation name.
fn: The callable to use for matching and rewriting, which takes an
operation, its adaptor, the type converter and a pattern
rewriter. The match is considered successful iff the callable
returns a falsy value.
type_converter: The type converter to convert types in the IR.
benefit: The benefit of the pattern, defaulting to 1.)");

  cls.def(
      "freeze",
      [](PyRewritePatternSet &self) {
        // Only a pattern set owned by its wrapper may be frozen.
        if (self.isOwned())
          return PyFrozenRewritePatternSet(
              mlirFreezeRewritePattern(self.get()));
        throw std::runtime_error("cannot freeze a non-owning pattern set");
      },
      "Freeze the pattern set into a frozen one.");
}
/// Scoped counterpart of the C enum MlirGreedyRewriteStrictness. It uses the
/// C enum's underlying type and each enumerator takes the value of the
/// matching MLIR_GREEDY_REWRITE_STRICTNESS_* constant, so values convert
/// between the two with a static_cast.
enum class PyGreedyRewriteStrictness : std::underlying_type_t<
MlirGreedyRewriteStrictness> {
ANY_OP = MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP,
EXISTING_AND_NEW_OPS = MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS,
EXISTING_OPS = MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS,
};
/// Scoped counterpart of the C enum MlirGreedySimplifyRegionLevel. It uses
/// the C enum's underlying type and each enumerator takes the value of the
/// matching MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_* constant.
enum class PyGreedySimplifyRegionLevel : std::underlying_type_t<
MlirGreedySimplifyRegionLevel> {
DISABLED = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED,
NORMAL = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL,
AGGRESSIVE = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE
};
/// Owning Wrapper around a GreedyRewriteDriverConfig.
///
/// The handle is held in a shared_ptr whose deleter calls
/// mlirGreedyRewriteDriverConfigDestroy, so copies of this wrapper refer to
/// the same underlying config.
class PyGreedyRewriteConfig {
public:
  PyGreedyRewriteConfig()
      : config(mlirGreedyRewriteDriverConfigCreate().ptr,
               PyGreedyRewriteConfig::customDeleter) {}
  PyGreedyRewriteConfig(PyGreedyRewriteConfig &&other) noexcept = default;
  PyGreedyRewriteConfig(const PyGreedyRewriteConfig &other) noexcept = default;

  /// Rebuilds the C API handle from the stored pointer.
  MlirGreedyRewriteDriverConfig get() {
    MlirGreedyRewriteDriverConfig handle{config.get()};
    return handle;
  }

  // Setters: each forwards its argument to the matching C API call.
  void setMaxIterations(int64_t maxIterations) {
    mlirGreedyRewriteDriverConfigSetMaxIterations(get(), maxIterations);
  }
  void setMaxNumRewrites(int64_t maxNumRewrites) {
    mlirGreedyRewriteDriverConfigSetMaxNumRewrites(get(), maxNumRewrites);
  }
  void setUseTopDownTraversal(bool useTopDownTraversal) {
    mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(get(),
                                                        useTopDownTraversal);
  }
  void enableFolding(bool enable) {
    mlirGreedyRewriteDriverConfigEnableFolding(get(), enable);
  }
  void setStrictness(PyGreedyRewriteStrictness strictness) {
    const auto raw = static_cast<MlirGreedyRewriteStrictness>(strictness);
    mlirGreedyRewriteDriverConfigSetStrictness(get(), raw);
  }
  void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level) {
    const auto raw = static_cast<MlirGreedySimplifyRegionLevel>(level);
    mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(get(), raw);
  }
  void enableConstantCSE(bool enable) {
    mlirGreedyRewriteDriverConfigEnableConstantCSE(get(), enable);
  }

  // Getters: each returns the value reported by the matching C API call.
  int64_t getMaxIterations() {
    return mlirGreedyRewriteDriverConfigGetMaxIterations(get());
  }
  int64_t getMaxNumRewrites() {
    return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(get());
  }
  bool getUseTopDownTraversal() {
    return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(get());
  }
  bool isFoldingEnabled() {
    return mlirGreedyRewriteDriverConfigIsFoldingEnabled(get());
  }
  PyGreedyRewriteStrictness getStrictness() {
    const auto raw = mlirGreedyRewriteDriverConfigGetStrictness(get());
    return static_cast<PyGreedyRewriteStrictness>(raw);
  }
  PyGreedySimplifyRegionLevel getRegionSimplificationLevel() {
    const auto raw =
        mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(get());
    return static_cast<PyGreedySimplifyRegionLevel>(raw);
  }
  bool isConstantCSEEnabled() {
    return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(get());
  }

private:
  std::shared_ptr<void> config;

  /// shared_ptr deleter: destroys the config identified by raw pointer `c`.
  static void customDeleter(void *c) {
    mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig{c});
  }
};
/// Scoped counterpart of the C enum MlirDialectConversionFoldingMode. It uses
/// the C enum's underlying type and each enumerator takes the value of the
/// matching MLIR_DIALECT_CONVERSION_FOLDING_MODE_* constant.
enum class PyDialectConversionFoldingMode : std::underlying_type_t<
MlirDialectConversionFoldingMode> {
Never = MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER,
BeforePatterns = MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS,
AfterPatterns = MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS,
};
class PyConversionConfig {
public:
PyConversionConfig()
: config(mlirConversionConfigCreate().ptr,
PyConversionConfig::customDeleter) {}
MlirConversionConfig get() { return MlirConversionConfig{config.get()}; }
void setFoldingMode(PyDialectConversionFoldingMode mode) {
mlirConversionConfigSetFoldingMode(get(),
MlirDialectConversionFoldingMode(mode));
}
PyDialectConversionFoldingMode getFoldingMode() {
return PyDialectConversionFoldingMode(
mlirConversionConfigGetFoldingMode(get()));
}
void enableBuildMaterializations(bool enabled) {
mlirConversionConfigEnableBuildMaterializations(get(), enabled);
}
bool isBuildMaterializationsEnabled() {
return mlirConversionConfigIsBuildMaterializationsEnabled(get());
}
private:
std::shared_ptr<void> config;
static void customDeleter(void *c) {
mlirConversionConfigDestroy(MlirConversionConfig{c});
}
};
/// Create the `mlir.rewrite` here.
///
/// Registers on module `m` the enums, wrapper classes and driver entry points
/// defined in this file.
void populateRewriteSubmodule(nb::module_ &m) {
  // Enum definitions
  nb::enum_<PyGreedyRewriteStrictness>(m, "GreedyRewriteStrictness")
      .value("ANY_OP", PyGreedyRewriteStrictness::ANY_OP)
      .value("EXISTING_AND_NEW_OPS",
             PyGreedyRewriteStrictness::EXISTING_AND_NEW_OPS)
      .value("EXISTING_OPS", PyGreedyRewriteStrictness::EXISTING_OPS);

  nb::enum_<PyGreedySimplifyRegionLevel>(m, "GreedySimplifyRegionLevel")
      .value("DISABLED", PyGreedySimplifyRegionLevel::DISABLED)
      .value("NORMAL", PyGreedySimplifyRegionLevel::NORMAL)
      .value("AGGRESSIVE", PyGreedySimplifyRegionLevel::AGGRESSIVE);

  nb::enum_<PyDialectConversionFoldingMode>(m, "DialectConversionFoldingMode")
      .value("NEVER", PyDialectConversionFoldingMode::Never)
      .value("BEFORE_PATTERNS", PyDialectConversionFoldingMode::BeforePatterns)
      .value("AFTER_PATTERNS", PyDialectConversionFoldingMode::AfterPatterns);

  //----------------------------------------------------------------------------
  // Mapping of the PatternRewriter
  //----------------------------------------------------------------------------
  PyPatternRewriter::bind(m);

  //----------------------------------------------------------------------------
  // Mapping of the RewritePatternSet
  //----------------------------------------------------------------------------
  PyRewritePatternSet::bind(m);

  nb::class_<PyConversionPatternRewriter, PyPatternRewriter>(
      m, "ConversionPatternRewriter")
      .def("convert_region_types",
           [](PyConversionPatternRewriter &self, PyRegion &region,
              PyTypeConverter &typeConverter) {
             mlirConversionPatternRewriterConvertRegionTypes(
                 self.rewriter, region.get(), typeConverter.get());
           });

  nb::class_<PyConversionTarget>(m, "ConversionTarget")
      .def(
          "__init__",
          [](PyConversionTarget &self, DefaultingPyMlirContext context) {
            new (&self) PyConversionTarget(context.get()->get());
          },
          "context"_a = nb::none())
      .def(
          "add_legal_op",
          [](PyConversionTarget &self, const nb::args &ops) {
            for (auto op : ops) {
              self.addLegalOp(operationNameFromObject(op));
            }
          },
          "ops"_a, "Mark the given operations as legal.")
      .def(
          "add_illegal_op",
          [](PyConversionTarget &self, const nb::args &ops) {
            for (auto op : ops) {
              self.addIllegalOp(operationNameFromObject(op));
            }
          },
          "ops"_a, "Mark the given operations as illegal.")
      .def(
          "add_legal_dialect",
          [](PyConversionTarget &self, const nb::args &dialects) {
            for (auto dialect : dialects) {
              self.addLegalDialect(dialectNameFromObject(dialect));
            }
          },
          "dialects"_a, "Mark the given dialects as legal.")
      .def(
          "add_illegal_dialect",
          [](PyConversionTarget &self, const nb::args &dialects) {
            for (auto dialect : dialects) {
              self.addIllegalDialect(dialectNameFromObject(dialect));
            }
          },
          "dialects"_a, "Mark the given dialect as illegal.");

  nb::class_<PyTypeConverter>(m, "TypeConverter")
      .def(nb::init<>(), "Create a new TypeConverter.")
      // The converter stores a raw pointer to the callback, so the callback
      // (argument 2) is tied to the lifetime of the TypeConverter object
      // (argument 1). Index 0 would be the `None` return value, which cannot
      // keep anything alive.
      .def("add_conversion", &PyTypeConverter::addConversion, "convert"_a,
           nb::keep_alive<1, 2>(), "Register a type conversion function.")
      .def("convert_type", &PyTypeConverter::convertType, "type"_a,
           "Convert the given type. Returns None if conversion fails.");

  //----------------------------------------------------------------------------
  // Mapping of the PDLResultList and PDLModule
  //----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
  nb::class_<PyMlirPDLResultList>(m, "PDLResultList")
      .def("append",
           [](PyMlirPDLResultList results, const PyValue &value) {
             mlirPDLResultListPushBackValue(results, value);
           })
      .def("append",
           [](PyMlirPDLResultList results, const PyOperation &op) {
             mlirPDLResultListPushBackOperation(results, op);
           })
      .def("append",
           [](PyMlirPDLResultList results, const PyType &type) {
             mlirPDLResultListPushBackType(results, type);
           })
      .def("append", [](PyMlirPDLResultList results, const PyAttribute &attr) {
        mlirPDLResultListPushBackAttribute(results, attr);
      });

  nb::class_<PyPDLPatternModule>(m, "PDLModule")
      // A single constructor overload; a second registration with the same
      // signature would never be selected.
      .def(
          "__init__",
          [](PyPDLPatternModule &self, PyModule &module) {
            new (&self) PyPDLPatternModule(
                mlirPDLPatternModuleFromModule(module.get()));
          },
          "module"_a, "Create a PDL module from the given module.")
      .def(
          "freeze",
          [](PyPDLPatternModule &self) {
            return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
                mlirRewritePatternSetFromPDLPatternModule(self.get())));
          },
          // Keep the PDL module alive for as long as the frozen set exists.
          nb::keep_alive<0, 1>())
      .def(
          "register_rewrite_function",
          [](PyPDLPatternModule &self, const std::string &name,
             const nb::callable &fn) {
            self.registerRewriteFunction(name, fn);
          },
          // Keep the callback alive for as long as the PDL module exists.
          nb::keep_alive<1, 3>())
      .def(
          "register_constraint_function",
          [](PyPDLPatternModule &self, const std::string &name,
             const nb::callable &fn) {
            self.registerConstraintFunction(name, fn);
          },
          // Keep the callback alive for as long as the PDL module exists.
          nb::keep_alive<1, 3>());
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

  nb::class_<PyGreedyRewriteConfig>(m, "GreedyRewriteConfig")
      .def(nb::init<>(), "Create a greedy rewrite driver config with defaults")
      .def_prop_rw("max_iterations", &PyGreedyRewriteConfig::getMaxIterations,
                   &PyGreedyRewriteConfig::setMaxIterations,
                   "Maximum number of iterations")
      .def_prop_rw("max_num_rewrites",
                   &PyGreedyRewriteConfig::getMaxNumRewrites,
                   &PyGreedyRewriteConfig::setMaxNumRewrites,
                   "Maximum number of rewrites per iteration")
      .def_prop_rw("use_top_down_traversal",
                   &PyGreedyRewriteConfig::getUseTopDownTraversal,
                   &PyGreedyRewriteConfig::setUseTopDownTraversal,
                   "Whether to use top-down traversal")
      .def_prop_rw("enable_folding", &PyGreedyRewriteConfig::isFoldingEnabled,
                   &PyGreedyRewriteConfig::enableFolding,
                   "Enable or disable folding")
      .def_prop_rw("strictness", &PyGreedyRewriteConfig::getStrictness,
                   &PyGreedyRewriteConfig::setStrictness,
                   "Rewrite strictness level")
      .def_prop_rw("region_simplification_level",
                   &PyGreedyRewriteConfig::getRegionSimplificationLevel,
                   &PyGreedyRewriteConfig::setRegionSimplificationLevel,
                   "Region simplification level")
      .def_prop_rw("enable_constant_cse",
                   &PyGreedyRewriteConfig::isConstantCSEEnabled,
                   &PyGreedyRewriteConfig::enableConstantCSE,
                   "Enable or disable constant CSE");

  nb::class_<PyConversionConfig>(m, "ConversionConfig")
      .def(nb::init<>(), "Create a conversion config with defaults")
      .def_prop_rw("folding_mode", &PyConversionConfig::getFoldingMode,
                   &PyConversionConfig::setFoldingMode,
                   "folding behavior during dialect conversion")
      .def_prop_rw("build_materializations",
                   &PyConversionConfig::isBuildMaterializationsEnabled,
                   &PyConversionConfig::enableBuildMaterializations,
                   "Whether the dialect conversion attempts to build "
                   "source/target materializations");

  nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
      .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
                   &PyFrozenRewritePatternSet::getCapsule)
      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
           &PyFrozenRewritePatternSet::createFromCapsule);

  m.def(
       "apply_patterns_and_fold_greedily",
       [](PyModule &module, PyFrozenRewritePatternSet &set,
          std::optional<PyGreedyRewriteConfig> config) {
         // Build the default config inside the optional so that it is
         // released when the lambda returns; a bare
         // mlirGreedyRewriteDriverConfigCreate() temporary is never
         // destroyed.
         if (!config)
           config.emplace();
         MlirLogicalResult status = mlirApplyPatternsAndFoldGreedily(
             module.get(), set.get(), config->get());
         if (mlirLogicalResultIsFailure(status))
           throw std::runtime_error("pattern application failed to converge");
       },
       "module"_a, "set"_a, "config"_a = nb::none(),
       "Applys the given patterns to the given module greedily while folding "
       "results.")
      .def(
          "apply_patterns_and_fold_greedily",
          [](PyOperationBase &op, PyFrozenRewritePatternSet &set,
             std::optional<PyGreedyRewriteConfig> config) {
            // Same ownership handling as the module overload above.
            if (!config)
              config.emplace();
            MlirLogicalResult status = mlirApplyPatternsAndFoldGreedilyWithOp(
                op.getOperation(), set.get(), config->get());
            if (mlirLogicalResultIsFailure(status))
              throw std::runtime_error(
                  "pattern application failed to converge");
          },
          "op"_a, "set"_a, "config"_a = nb::none(),
          "Applys the given patterns to the given op greedily while folding "
          "results.")
      .def(
          "walk_and_apply_patterns",
          [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
            mlirWalkAndApplyPatterns(op.getOperation(), set.get());
          },
          "op"_a, "set"_a,
          "Applies the given patterns to the given op by a fast walk-based "
          "driver.")
      .def(
          "apply_partial_conversion",
          [](PyOperationBase &op, PyConversionTarget &target,
             PyFrozenRewritePatternSet &set,
             std::optional<PyConversionConfig> config) {
            if (!config)
              config.emplace(PyConversionConfig());
            // Collect diagnostics emitted during conversion so they can be
            // attached to the raised error.
            PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
            MlirLogicalResult status = mlirApplyPartialConversion(
                op.getOperation(), target.get(), set.get(), config->get());
            if (mlirLogicalResultIsFailure(status))
              throw MLIRError("partial conversion failed", errors.take());
          },
          "op"_a, "target"_a, "set"_a, "config"_a = nb::none(),
          "Applies a partial conversion on the given operation.")
      .def(
          "apply_full_conversion",
          [](PyOperationBase &op, PyConversionTarget &target,
             PyFrozenRewritePatternSet &set,
             std::optional<PyConversionConfig> config) {
            if (!config)
              config.emplace(PyConversionConfig());
            PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
            MlirLogicalResult status = mlirApplyFullConversion(
                op.getOperation(), target.get(), set.get(), config->get());
            if (mlirLogicalResultIsFailure(status))
              throw MLIRError("full conversion failed", errors.take());
          },
          "op"_a, "target"_a, "set"_a, "config"_a = nb::none(),
          "Applies a full conversion on the given operation.");
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir