We have a common pattern that retrieves an operation name or dialect name from a `type` or `str` in the rewrite nanobind module, so it is better to factor it out into a common utility function. --------- Co-authored-by: Rolf Morel <rolfmorel@gmail.com>
820 lines
32 KiB
C++
820 lines
32 KiB
C++
//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "Rewrite.h"
|
|
|
|
#include "mlir-c/Bindings/Python/Interop.h"
|
|
#include "mlir-c/IR.h"
|
|
#include "mlir-c/Rewrite.h"
|
|
#include "mlir-c/Support.h"
|
|
#include "mlir/Bindings/Python/Globals.h"
|
|
#include "mlir/Bindings/Python/IRCore.h"
|
|
#include "mlir/Config/mlir-config.h"
|
|
#include "nanobind/nanobind.h"
|
|
#include <type_traits>
|
|
|
|
namespace nb = nanobind;
|
|
using namespace mlir;
|
|
using namespace nb::literals;
|
|
using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
|
|
|
|
namespace mlir {
|
|
namespace python {
|
|
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
|
|
|
|
// Convert the Python object to a boolean.
|
|
// If it evaluates to False, treat it as success;
|
|
// otherwise, treat it as failure.
|
|
// Note that None is considered success.
|
|
static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
|
|
if (obj.is_none())
|
|
return mlirLogicalResultSuccess();
|
|
|
|
return nb::cast<bool>(obj) ? mlirLogicalResultFailure()
|
|
: mlirLogicalResultSuccess();
|
|
}
|
|
|
|
/// Resolve `root` to an operation name.
///
/// A Python type contributes the string held in its `OPERATION_NAME`
/// attribute; a Python `str` is used verbatim. Any other object is rejected
/// with nb::type_error.
static std::string operationNameFromObject(nb::handle root) {
  const bool isClass = root.is_type();
  if (!isClass && !nb::isinstance<nb::str>(root))
    throw nb::type_error("the root argument must be a type or a string");

  if (isClass)
    return nb::cast<std::string>(root.attr("OPERATION_NAME"));
  return nb::cast<std::string>(root);
}
|
|
|
|
/// Resolve `dialect` to a dialect namespace string.
///
/// A Python type contributes the string held in its `DIALECT_NAMESPACE`
/// attribute; a Python `str` is used verbatim.
///
/// Throws nb::type_error when `dialect` is neither a type nor a string. The
/// message names the dialect argument (rather than a "root" argument) because
/// this helper is used for dialect parameters, which have no root.
static std::string dialectNameFromObject(nb::handle dialect) {
  if (dialect.is_type())
    return nb::cast<std::string>(dialect.attr("DIALECT_NAMESPACE"));
  if (nb::isinstance<nb::str>(dialect))
    return nb::cast<std::string>(dialect);

  throw nb::type_error("the dialect argument must be a type or a string");
}
|
|
|
|
/// Wrapper for an MlirPatternRewriter, built on the CRTP base PyRewriterBase.
class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
public:
  // NOTE(review): presumably the Python-visible class name consumed by
  // PyRewriterBase when binding; confirm against the base class.
  static constexpr const char *pyClassName = "PatternRewriter";

  /// Wraps `rewriter` by handing its base-rewriter view to PyRewriterBase.
  PyPatternRewriter(MlirPatternRewriter rewriter)
      : PyRewriterBase<PyPatternRewriter>(
            mlirPatternRewriterAsBase(rewriter)) {}
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// PyRewritePatternSet
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
PyRewritePatternSet::PyRewritePatternSet(MlirContext ctx)
|
|
: patterns(mlirRewritePatternSetCreate(ctx)), owned(true) {}
|
|
|
|
PyRewritePatternSet::PyRewritePatternSet(MlirRewritePatternSet patterns)
|
|
: patterns(patterns), owned(false) {}
|
|
|
|
PyRewritePatternSet::~PyRewritePatternSet() {
|
|
if (owned && patterns.ptr)
|
|
mlirRewritePatternSetDestroy(patterns);
|
|
}
|
|
|
|
MlirRewritePatternSet PyRewritePatternSet::get() const { return patterns; }
|
|
|
|
bool PyRewritePatternSet::isOwned() const { return owned; }
|
|
|
|
/// Adds a rewrite pattern to this set.
///
/// \param root            Python type or string naming the root operation
///                        (resolved through operationNameFromObject).
/// \param matchAndRewrite Python callable invoked as
///                        `matchAndRewrite(opView, PatternRewriter)`; its
///                        result is mapped through logicalResultFromObject.
/// \param benefit         Benefit forwarded to mlirOpRewritePatternCreate.
void PyRewritePatternSet::add(nb::handle root,
                              const nb::callable &matchAndRewrite,
                              unsigned benefit) {
  const std::string rootOpName = operationNameFromObject(root);

  MlirRewritePatternCallbacks callbacks;

  // The user data is the raw PyObject* of the callable. `construct` takes a
  // reference on it and `destruct` releases one.
  // NOTE(review): presumably these run when the pattern is created/destroyed;
  // confirm against the C API.
  callbacks.construct = [](void *userData) {
    nb::handle callable(static_cast<PyObject *>(userData));
    callable.inc_ref();
  };
  callbacks.destruct = [](void *userData) {
    nb::handle callable(static_cast<PyObject *>(userData));
    callable.dec_ref();
  };

  callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
                                 MlirPatternRewriter rewriter,
                                 void *userData) -> MlirLogicalResult {
    nb::handle callable(static_cast<PyObject *>(userData));

    // Build the Python view of the matched operation in its own context.
    PyMlirContextRef ctx =
        PyMlirContext::forContext(mlirOperationGetContext(op));
    nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();

    nb::object outcome = callable(opView, PyPatternRewriter(rewriter));
    return logicalResultFromObject(outcome);
  };

  MlirStringRef rootName =
      mlirStringRefCreate(rootOpName.data(), rootOpName.size());
  MlirContext ctx = mlirRewritePatternSetGetContext(patterns);
  MlirRewritePattern pattern =
      mlirOpRewritePatternCreate(rootName, benefit, ctx, callbacks,
                                 matchAndRewrite.ptr(),
                                 /* nGeneratedNames */ 0,
                                 /* generatedNames */ nullptr);
  mlirRewritePatternSetAdd(patterns, pattern);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// PyConversionPatternRewriter
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
class PyConversionPatternRewriter : public PyPatternRewriter {
|
|
public:
|
|
PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
|
|
: PyPatternRewriter(
|
|
mlirConversionPatternRewriterAsPatternRewriter(rewriter)),
|
|
rewriter(rewriter) {}
|
|
|
|
MlirConversionPatternRewriter rewriter;
|
|
};
|
|
|
|
class PyConversionTarget {
|
|
public:
|
|
PyConversionTarget(MlirContext context)
|
|
: target(mlirConversionTargetCreate(context)) {}
|
|
~PyConversionTarget() { mlirConversionTargetDestroy(target); }
|
|
|
|
void addLegalOp(const std::string &opName) {
|
|
mlirConversionTargetAddLegalOp(
|
|
target, mlirStringRefCreate(opName.data(), opName.size()));
|
|
}
|
|
|
|
void addIllegalOp(const std::string &opName) {
|
|
mlirConversionTargetAddIllegalOp(
|
|
target, mlirStringRefCreate(opName.data(), opName.size()));
|
|
}
|
|
|
|
void addLegalDialect(const std::string &dialectName) {
|
|
mlirConversionTargetAddLegalDialect(
|
|
target, mlirStringRefCreate(dialectName.data(), dialectName.size()));
|
|
}
|
|
|
|
void addIllegalDialect(const std::string &dialectName) {
|
|
mlirConversionTargetAddIllegalDialect(
|
|
target, mlirStringRefCreate(dialectName.data(), dialectName.size()));
|
|
}
|
|
|
|
MlirConversionTarget get() { return target; }
|
|
|
|
private:
|
|
MlirConversionTarget target;
|
|
};
|
|
|
|
class PyTypeConverter {
|
|
public:
|
|
PyTypeConverter() : typeConverter(mlirTypeConverterCreate()), owner(true) {}
|
|
PyTypeConverter(MlirTypeConverter typeConverter)
|
|
: typeConverter(typeConverter), owner(false) {}
|
|
~PyTypeConverter() {
|
|
if (owner)
|
|
mlirTypeConverterDestroy(typeConverter);
|
|
}
|
|
|
|
void addConversion(const nb::callable &convert) {
|
|
mlirTypeConverterAddConversion(
|
|
typeConverter,
|
|
[](MlirType type, MlirType *converted,
|
|
void *userData) -> MlirLogicalResult {
|
|
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
|
|
auto ctx = PyMlirContext::forContext(mlirTypeGetContext(type));
|
|
nb::object res = f(PyType(ctx, type).maybeDownCast());
|
|
if (res.is_none())
|
|
return mlirLogicalResultFailure();
|
|
|
|
*converted = nb::cast<PyType>(res).get();
|
|
return mlirLogicalResultSuccess();
|
|
},
|
|
convert.ptr());
|
|
}
|
|
|
|
nb::typed<nb::object, std::optional<PyType>> convertType(PyType &type) {
|
|
MlirType converted = mlirTypeConverterConvertType(typeConverter, type);
|
|
if (mlirTypeIsNull(converted))
|
|
return nb::none();
|
|
return PyType(PyMlirContext::forContext(mlirTypeGetContext(converted)),
|
|
converted)
|
|
.maybeDownCast();
|
|
}
|
|
|
|
MlirTypeConverter get() { return typeConverter; }
|
|
|
|
private:
|
|
MlirTypeConverter typeConverter;
|
|
bool owner;
|
|
};
|
|
|
|
class PyConversionPattern {
|
|
public:
|
|
PyConversionPattern(MlirConversionPattern pattern) : pattern(pattern) {}
|
|
|
|
PyTypeConverter getTypeConverter() {
|
|
return PyTypeConverter(mlirConversionPatternGetTypeConverter(pattern));
|
|
}
|
|
|
|
private:
|
|
MlirConversionPattern pattern;
|
|
};
|
|
|
|
/// Adds a conversion pattern to this set.
///
/// \param root            Python type or string naming the root operation
///                        (resolved through operationNameFromObject).
/// \param matchAndRewrite Python callable invoked as
///                        `matchAndRewrite(opView, adaptor, typeConverter,
///                        ConversionPatternRewriter)`; its result is mapped
///                        through logicalResultFromObject.
/// \param typeConverter   Converter whose handle is attached to the pattern.
/// \param benefit         Benefit forwarded to mlirOpConversionPatternCreate.
void PyRewritePatternSet::addConversion(nb::handle root,
                                        const nb::callable &matchAndRewrite,
                                        PyTypeConverter &typeConverter,
                                        unsigned benefit) {
  const std::string rootOpName = operationNameFromObject(root);

  MlirConversionPatternCallbacks callbacks;

  // The user data is the raw PyObject* of the callable. `construct` takes a
  // reference on it and `destruct` releases one.
  callbacks.construct = [](void *userData) {
    nb::handle callable(static_cast<PyObject *>(userData));
    callable.inc_ref();
  };
  callbacks.destruct = [](void *userData) {
    nb::handle callable(static_cast<PyObject *>(userData));
    callable.dec_ref();
  };

  callbacks.matchAndRewrite =
      [](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
         MlirValue *operands, MlirConversionPatternRewriter rewriter,
         void *userData) -> MlirLogicalResult {
    nb::handle callable(static_cast<PyObject *>(userData));

    PyMlirContextRef context =
        PyMlirContext::forContext(mlirOperationGetContext(op));
    nb::object opView = PyOperation::forOperation(context, op)->createOpView();

    // Copy the operand array handed over by the driver into a vector.
    std::vector<MlirValue> operandsVec(operands, operands + nOperands);

    // Use the adaptor class registered for this operation name, falling back
    // to the generic PyOpAdaptor when none is registered.
    MlirStringRef opNameRef = mlirIdentifierStr(mlirOperationGetName(op));
    nb::object adaptorCls =
        PyGlobals::get()
            .lookupOpAdaptorClass(
                std::string_view(opNameRef.data, opNameRef.length))
            .value_or(nb::borrow(nb::type<PyOpAdaptor>()));

    nb::object outcome =
        callable(opView, adaptorCls(operandsVec, opView),
                 PyConversionPattern(pattern).getTypeConverter(),
                 PyConversionPatternRewriter(rewriter));
    return logicalResultFromObject(outcome);
  };

  MlirStringRef rootName =
      mlirStringRefCreate(rootOpName.data(), rootOpName.size());
  MlirConversionPattern pattern = mlirOpConversionPatternCreate(
      rootName, benefit, mlirRewritePatternSetGetContext(patterns),
      typeConverter.get(), callbacks, matchAndRewrite.ptr(),
      /* nGeneratedNames */ 0,
      /* generatedNames */ nullptr);

  // The set stores rewrite patterns, so add the pattern through that view.
  mlirRewritePatternSetAdd(patterns,
                           mlirConversionPatternAsRewritePattern(pattern));
}
|
|
|
|
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
|
|
/// Distinct C++ type layered over MlirPDLResultList; it adds no members and
/// converts implicitly to its base when passed to the C API.
struct PyMlirPDLResultList : MlirPDLResultList {};
|
|
|
|
/// Converts a PDL value into a Python object.
///
/// The value is probed, in order, as a Value, an Operation, an Attribute and
/// a Type; the first non-null interpretation is cast with nb::cast. Later
/// probes are only attempted when the earlier ones yield null.
/// Throws std::runtime_error when none of the four interpretations applies.
static nb::object objectFromPDLValue(MlirPDLValue value) {
  MlirValue asValue = mlirPDLValueAsValue(value);
  if (!mlirValueIsNull(asValue))
    return nb::cast(asValue);

  MlirOperation asOperation = mlirPDLValueAsOperation(value);
  if (!mlirOperationIsNull(asOperation))
    return nb::cast(asOperation);

  MlirAttribute asAttribute = mlirPDLValueAsAttribute(value);
  if (!mlirAttributeIsNull(asAttribute))
    return nb::cast(asAttribute);

  MlirType asType = mlirPDLValueAsType(value);
  if (!mlirTypeIsNull(asType))
    return nb::cast(asType);

  throw std::runtime_error("unsupported PDL value type");
}
|
|
|
|
static std::vector<nb::object> objectsFromPDLValues(size_t nValues,
|
|
MlirPDLValue *values) {
|
|
std::vector<nb::object> args;
|
|
args.reserve(nValues);
|
|
for (size_t i = 0; i < nValues; ++i)
|
|
args.push_back(objectFromPDLValue(values[i]));
|
|
return args;
|
|
}
|
|
|
|
/// Owning Wrapper around a PDLPatternModule.
|
|
class PyPDLPatternModule {
|
|
public:
|
|
PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
|
|
PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
|
|
: module(other.module) {
|
|
other.module.ptr = nullptr;
|
|
}
|
|
~PyPDLPatternModule() {
|
|
if (module.ptr != nullptr)
|
|
mlirPDLPatternModuleDestroy(module);
|
|
}
|
|
MlirPDLPatternModule get() { return module; }
|
|
|
|
void registerRewriteFunction(const std::string &name,
|
|
const nb::callable &fn) {
|
|
mlirPDLPatternModuleRegisterRewriteFunction(
|
|
get(), mlirStringRefCreate(name.data(), name.size()),
|
|
[](MlirPatternRewriter rewriter, MlirPDLResultList results,
|
|
size_t nValues, MlirPDLValue *values,
|
|
void *userData) -> MlirLogicalResult {
|
|
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
|
|
return logicalResultFromObject(
|
|
f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
|
|
objectsFromPDLValues(nValues, values)));
|
|
},
|
|
fn.ptr());
|
|
}
|
|
|
|
void registerConstraintFunction(const std::string &name,
|
|
const nb::callable &fn) {
|
|
mlirPDLPatternModuleRegisterConstraintFunction(
|
|
get(), mlirStringRefCreate(name.data(), name.size()),
|
|
[](MlirPatternRewriter rewriter, MlirPDLResultList results,
|
|
size_t nValues, MlirPDLValue *values,
|
|
void *userData) -> MlirLogicalResult {
|
|
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
|
|
return logicalResultFromObject(
|
|
f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
|
|
objectsFromPDLValues(nValues, values)));
|
|
},
|
|
fn.ptr());
|
|
}
|
|
|
|
private:
|
|
MlirPDLPatternModule module;
|
|
};
|
|
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
|
|
|
|
/// Owning Wrapper around a FrozenRewritePatternSet.
|
|
class PyFrozenRewritePatternSet {
|
|
public:
|
|
PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
|
|
PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
|
|
: set(other.set) {
|
|
other.set.ptr = nullptr;
|
|
}
|
|
~PyFrozenRewritePatternSet() {
|
|
if (set.ptr != nullptr)
|
|
mlirFrozenRewritePatternSetDestroy(set);
|
|
}
|
|
MlirFrozenRewritePatternSet get() { return set; }
|
|
|
|
nb::object getCapsule() {
|
|
return nb::steal<nb::object>(
|
|
mlirPythonFrozenRewritePatternSetToCapsule(get()));
|
|
}
|
|
|
|
static nb::object createFromCapsule(const nb::object &capsule) {
|
|
MlirFrozenRewritePatternSet rawPm =
|
|
mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
|
|
if (rawPm.ptr == nullptr)
|
|
throw nb::python_error();
|
|
return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
|
|
}
|
|
|
|
private:
|
|
MlirFrozenRewritePatternSet set;
|
|
};
|
|
|
|
/// Registers the `RewritePatternSet` Python class on module `m`, exposing
/// construction from an optional context plus `add`, `add_conversion` and
/// `freeze`.
void PyRewritePatternSet::bind(nb::module_ &m) {
  // Constructs the C++ object in place inside the storage provided by
  // nanobind, using the given (or default) context.
  auto initFromContext = [](PyRewritePatternSet &self,
                            DefaultingPyMlirContext context) {
    new (&self) PyRewritePatternSet(context.get()->get());
  };

  // Freezing is only offered for sets owned by their wrapper.
  auto freezeOwned = [](PyRewritePatternSet &self) {
    if (!self.isOwned())
      throw std::runtime_error("cannot freeze a non-owning pattern set");
    MlirRewritePatternSet s = self.get();
    return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(s));
  };

  nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
      .def("__init__", initFromContext, "context"_a = nb::none())
      .def("add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"),
           nb::arg("benefit") = 1,
           R"(Add a new rewrite pattern on the specified root operation, using
              the provided callable for matching and rewriting, and assign it
              the given benefit.

              Args:
                root: The root operation to which this pattern applies. This may
                      be either an OpView subclass or an operation name.
                fn: The callable to use for matching and rewriting, which takes
                    an operation and a pattern rewriter. The match is considered
                    successful iff the callable returns a falsy value.
                benefit: The benefit of the pattern, defaulting to 1.)")
      .def("add_conversion", &PyRewritePatternSet::addConversion,
           nb::arg("root"), nb::arg("fn"), nb::arg("type_converter"),
           nb::arg("benefit") = 1,
           R"(
            Add a new conversion pattern on the specified root operation,
            using the provided callable for matching and rewriting,
            and assign it the given benefit.

            Args:
              root: The root operation to which this pattern applies.
                    This may be either an OpView subclass or an operation name.
              fn: The callable to use for matching and rewriting, which takes an
                  operation, its adaptor, the type converter and a pattern
                  rewriter. The match is considered successful iff the callable
                  returns a falsy value.
              type_converter: The type converter to convert types in the IR.
              benefit: The benefit of the pattern, defaulting to 1.)")
      .def("freeze", freezeOwned, "Freeze the pattern set into a frozen one.");
}
|
|
|
|
enum class PyGreedyRewriteStrictness : std::underlying_type_t<
|
|
MlirGreedyRewriteStrictness> {
|
|
ANY_OP = MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP,
|
|
EXISTING_AND_NEW_OPS = MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS,
|
|
EXISTING_OPS = MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS,
|
|
};
|
|
|
|
enum class PyGreedySimplifyRegionLevel : std::underlying_type_t<
|
|
MlirGreedySimplifyRegionLevel> {
|
|
DISABLED = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED,
|
|
NORMAL = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL,
|
|
AGGRESSIVE = MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE
|
|
};
|
|
|
|
/// Owning Wrapper around a GreedyRewriteDriverConfig.
|
|
class PyGreedyRewriteConfig {
|
|
public:
|
|
PyGreedyRewriteConfig()
|
|
: config(mlirGreedyRewriteDriverConfigCreate().ptr,
|
|
PyGreedyRewriteConfig::customDeleter) {}
|
|
PyGreedyRewriteConfig(PyGreedyRewriteConfig &&other) noexcept
|
|
: config(std::move(other.config)) {}
|
|
PyGreedyRewriteConfig(const PyGreedyRewriteConfig &other) noexcept
|
|
: config(other.config) {}
|
|
|
|
MlirGreedyRewriteDriverConfig get() {
|
|
return MlirGreedyRewriteDriverConfig{config.get()};
|
|
}
|
|
|
|
void setMaxIterations(int64_t maxIterations) {
|
|
mlirGreedyRewriteDriverConfigSetMaxIterations(get(), maxIterations);
|
|
}
|
|
|
|
void setMaxNumRewrites(int64_t maxNumRewrites) {
|
|
mlirGreedyRewriteDriverConfigSetMaxNumRewrites(get(), maxNumRewrites);
|
|
}
|
|
|
|
void setUseTopDownTraversal(bool useTopDownTraversal) {
|
|
mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(get(),
|
|
useTopDownTraversal);
|
|
}
|
|
|
|
void enableFolding(bool enable) {
|
|
mlirGreedyRewriteDriverConfigEnableFolding(get(), enable);
|
|
}
|
|
|
|
void setStrictness(PyGreedyRewriteStrictness strictness) {
|
|
mlirGreedyRewriteDriverConfigSetStrictness(
|
|
get(), static_cast<MlirGreedyRewriteStrictness>(strictness));
|
|
}
|
|
|
|
void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level) {
|
|
mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
|
|
get(), static_cast<MlirGreedySimplifyRegionLevel>(level));
|
|
}
|
|
|
|
void enableConstantCSE(bool enable) {
|
|
mlirGreedyRewriteDriverConfigEnableConstantCSE(get(), enable);
|
|
}
|
|
|
|
int64_t getMaxIterations() {
|
|
return mlirGreedyRewriteDriverConfigGetMaxIterations(get());
|
|
}
|
|
|
|
int64_t getMaxNumRewrites() {
|
|
return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(get());
|
|
}
|
|
|
|
bool getUseTopDownTraversal() {
|
|
return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(get());
|
|
}
|
|
|
|
bool isFoldingEnabled() {
|
|
return mlirGreedyRewriteDriverConfigIsFoldingEnabled(get());
|
|
}
|
|
|
|
PyGreedyRewriteStrictness getStrictness() {
|
|
return static_cast<PyGreedyRewriteStrictness>(
|
|
mlirGreedyRewriteDriverConfigGetStrictness(get()));
|
|
}
|
|
|
|
PyGreedySimplifyRegionLevel getRegionSimplificationLevel() {
|
|
return static_cast<PyGreedySimplifyRegionLevel>(
|
|
mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(get()));
|
|
}
|
|
|
|
bool isConstantCSEEnabled() {
|
|
return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(get());
|
|
}
|
|
|
|
private:
|
|
std::shared_ptr<void> config;
|
|
static void customDeleter(void *c) {
|
|
mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig{c});
|
|
}
|
|
};
|
|
|
|
enum class PyDialectConversionFoldingMode : std::underlying_type_t<
|
|
MlirDialectConversionFoldingMode> {
|
|
Never = MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER,
|
|
BeforePatterns = MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS,
|
|
AfterPatterns = MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS,
|
|
};
|
|
|
|
class PyConversionConfig {
|
|
public:
|
|
PyConversionConfig()
|
|
: config(mlirConversionConfigCreate().ptr,
|
|
PyConversionConfig::customDeleter) {}
|
|
|
|
MlirConversionConfig get() { return MlirConversionConfig{config.get()}; }
|
|
|
|
void setFoldingMode(PyDialectConversionFoldingMode mode) {
|
|
mlirConversionConfigSetFoldingMode(get(),
|
|
MlirDialectConversionFoldingMode(mode));
|
|
}
|
|
|
|
PyDialectConversionFoldingMode getFoldingMode() {
|
|
return PyDialectConversionFoldingMode(
|
|
mlirConversionConfigGetFoldingMode(get()));
|
|
}
|
|
|
|
void enableBuildMaterializations(bool enabled) {
|
|
mlirConversionConfigEnableBuildMaterializations(get(), enabled);
|
|
}
|
|
|
|
bool isBuildMaterializationsEnabled() {
|
|
return mlirConversionConfigIsBuildMaterializationsEnabled(get());
|
|
}
|
|
|
|
private:
|
|
std::shared_ptr<void> config;
|
|
static void customDeleter(void *c) {
|
|
mlirConversionConfigDestroy(MlirConversionConfig{c});
|
|
}
|
|
};
|
|
|
|
/// Create the `mlir.rewrite` here.
|
|
void populateRewriteSubmodule(nb::module_ &m) {
|
|
// Enum definitions
|
|
nb::enum_<PyGreedyRewriteStrictness>(m, "GreedyRewriteStrictness")
|
|
.value("ANY_OP", PyGreedyRewriteStrictness::ANY_OP)
|
|
.value("EXISTING_AND_NEW_OPS",
|
|
PyGreedyRewriteStrictness::EXISTING_AND_NEW_OPS)
|
|
.value("EXISTING_OPS", PyGreedyRewriteStrictness::EXISTING_OPS);
|
|
|
|
nb::enum_<PyGreedySimplifyRegionLevel>(m, "GreedySimplifyRegionLevel")
|
|
.value("DISABLED", PyGreedySimplifyRegionLevel::DISABLED)
|
|
.value("NORMAL", PyGreedySimplifyRegionLevel::NORMAL)
|
|
.value("AGGRESSIVE", PyGreedySimplifyRegionLevel::AGGRESSIVE);
|
|
|
|
nb::enum_<PyDialectConversionFoldingMode>(m, "DialectConversionFoldingMode")
|
|
.value("NEVER", PyDialectConversionFoldingMode::Never)
|
|
.value("BEFORE_PATTERNS", PyDialectConversionFoldingMode::BeforePatterns)
|
|
.value("AFTER_PATTERNS", PyDialectConversionFoldingMode::AfterPatterns);
|
|
|
|
//----------------------------------------------------------------------------
|
|
// Mapping of the PatternRewriter
|
|
//----------------------------------------------------------------------------
|
|
|
|
PyPatternRewriter::bind(m);
|
|
|
|
//----------------------------------------------------------------------------
|
|
// Mapping of the RewritePatternSet
|
|
//----------------------------------------------------------------------------
|
|
PyRewritePatternSet::bind(m);
|
|
|
|
nb::class_<PyConversionPatternRewriter, PyPatternRewriter>(
|
|
m, "ConversionPatternRewriter")
|
|
.def("convert_region_types",
|
|
[](PyConversionPatternRewriter &self, PyRegion ®ion,
|
|
PyTypeConverter &typeConverter) {
|
|
mlirConversionPatternRewriterConvertRegionTypes(
|
|
self.rewriter, region.get(), typeConverter.get());
|
|
});
|
|
|
|
nb::class_<PyConversionTarget>(m, "ConversionTarget")
|
|
.def(
|
|
"__init__",
|
|
[](PyConversionTarget &self, DefaultingPyMlirContext context) {
|
|
new (&self) PyConversionTarget(context.get()->get());
|
|
},
|
|
"context"_a = nb::none())
|
|
.def(
|
|
"add_legal_op",
|
|
[](PyConversionTarget &self, const nb::args &ops) {
|
|
for (auto op : ops) {
|
|
self.addLegalOp(operationNameFromObject(op));
|
|
}
|
|
},
|
|
"ops"_a, "Mark the given operations as legal.")
|
|
.def(
|
|
"add_illegal_op",
|
|
[](PyConversionTarget &self, const nb::args &ops) {
|
|
for (auto op : ops) {
|
|
self.addIllegalOp(operationNameFromObject(op));
|
|
}
|
|
},
|
|
"ops"_a, "Mark the given operations as illegal.")
|
|
.def(
|
|
"add_legal_dialect",
|
|
[](PyConversionTarget &self, const nb::args &dialects) {
|
|
for (auto dialect : dialects) {
|
|
self.addLegalDialect(dialectNameFromObject(dialect));
|
|
}
|
|
},
|
|
"dialects"_a, "Mark the given dialects as legal.")
|
|
.def(
|
|
"add_illegal_dialect",
|
|
[](PyConversionTarget &self, const nb::args &dialects) {
|
|
for (auto dialect : dialects) {
|
|
self.addIllegalDialect(dialectNameFromObject(dialect));
|
|
}
|
|
},
|
|
"dialects"_a, "Mark the given dialect as illegal.");
|
|
|
|
nb::class_<PyTypeConverter>(m, "TypeConverter")
|
|
.def(nb::init<>(), "Create a new TypeConverter.")
|
|
.def("add_conversion", &PyTypeConverter::addConversion, "convert"_a,
|
|
nb::keep_alive<0, 1>(), "Register a type conversion function.")
|
|
.def("convert_type", &PyTypeConverter::convertType, "type"_a,
|
|
"Convert the given type. Returns None if conversion fails.");
|
|
|
|
//----------------------------------------------------------------------------
|
|
// Mapping of the PDLResultList and PDLModule
|
|
//----------------------------------------------------------------------------
|
|
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
|
|
nb::class_<PyMlirPDLResultList>(m, "PDLResultList")
|
|
.def("append",
|
|
[](PyMlirPDLResultList results, const PyValue &value) {
|
|
mlirPDLResultListPushBackValue(results, value);
|
|
})
|
|
.def("append",
|
|
[](PyMlirPDLResultList results, const PyOperation &op) {
|
|
mlirPDLResultListPushBackOperation(results, op);
|
|
})
|
|
.def("append",
|
|
[](PyMlirPDLResultList results, const PyType &type) {
|
|
mlirPDLResultListPushBackType(results, type);
|
|
})
|
|
.def("append", [](PyMlirPDLResultList results, const PyAttribute &attr) {
|
|
mlirPDLResultListPushBackAttribute(results, attr);
|
|
});
|
|
nb::class_<PyPDLPatternModule>(m, "PDLModule")
|
|
.def(
|
|
"__init__",
|
|
[](PyPDLPatternModule &self, PyModule &module) {
|
|
new (&self) PyPDLPatternModule(
|
|
mlirPDLPatternModuleFromModule(module.get()));
|
|
},
|
|
"module"_a, "Create a PDL module from the given module.")
|
|
.def(
|
|
"__init__",
|
|
[](PyPDLPatternModule &self, PyModule &module) {
|
|
new (&self) PyPDLPatternModule(
|
|
mlirPDLPatternModuleFromModule(module.get()));
|
|
},
|
|
"module"_a, "Create a PDL module from the given module.")
|
|
.def(
|
|
"freeze",
|
|
[](PyPDLPatternModule &self) {
|
|
return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
|
|
mlirRewritePatternSetFromPDLPatternModule(self.get())));
|
|
},
|
|
nb::keep_alive<0, 1>())
|
|
.def(
|
|
"register_rewrite_function",
|
|
[](PyPDLPatternModule &self, const std::string &name,
|
|
const nb::callable &fn) {
|
|
self.registerRewriteFunction(name, fn);
|
|
},
|
|
nb::keep_alive<1, 3>())
|
|
.def(
|
|
"register_constraint_function",
|
|
[](PyPDLPatternModule &self, const std::string &name,
|
|
const nb::callable &fn) {
|
|
self.registerConstraintFunction(name, fn);
|
|
},
|
|
nb::keep_alive<1, 3>());
|
|
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
|
|
|
|
nb::class_<PyGreedyRewriteConfig>(m, "GreedyRewriteConfig")
|
|
.def(nb::init<>(), "Create a greedy rewrite driver config with defaults")
|
|
.def_prop_rw("max_iterations", &PyGreedyRewriteConfig::getMaxIterations,
|
|
&PyGreedyRewriteConfig::setMaxIterations,
|
|
"Maximum number of iterations")
|
|
.def_prop_rw("max_num_rewrites",
|
|
&PyGreedyRewriteConfig::getMaxNumRewrites,
|
|
&PyGreedyRewriteConfig::setMaxNumRewrites,
|
|
"Maximum number of rewrites per iteration")
|
|
.def_prop_rw("use_top_down_traversal",
|
|
&PyGreedyRewriteConfig::getUseTopDownTraversal,
|
|
&PyGreedyRewriteConfig::setUseTopDownTraversal,
|
|
"Whether to use top-down traversal")
|
|
.def_prop_rw("enable_folding", &PyGreedyRewriteConfig::isFoldingEnabled,
|
|
&PyGreedyRewriteConfig::enableFolding,
|
|
"Enable or disable folding")
|
|
.def_prop_rw("strictness", &PyGreedyRewriteConfig::getStrictness,
|
|
&PyGreedyRewriteConfig::setStrictness,
|
|
"Rewrite strictness level")
|
|
.def_prop_rw("region_simplification_level",
|
|
&PyGreedyRewriteConfig::getRegionSimplificationLevel,
|
|
&PyGreedyRewriteConfig::setRegionSimplificationLevel,
|
|
"Region simplification level")
|
|
.def_prop_rw("enable_constant_cse",
|
|
&PyGreedyRewriteConfig::isConstantCSEEnabled,
|
|
&PyGreedyRewriteConfig::enableConstantCSE,
|
|
"Enable or disable constant CSE");
|
|
|
|
nb::class_<PyConversionConfig>(m, "ConversionConfig")
|
|
.def(nb::init<>(), "Create a conversion config with defaults")
|
|
.def_prop_rw("folding_mode", &PyConversionConfig::getFoldingMode,
|
|
&PyConversionConfig::setFoldingMode,
|
|
"folding behavior during dialect conversion")
|
|
.def_prop_rw("build_materializations",
|
|
&PyConversionConfig::isBuildMaterializationsEnabled,
|
|
&PyConversionConfig::enableBuildMaterializations,
|
|
"Whether the dialect conversion attempts to build "
|
|
"source/target materializations");
|
|
|
|
nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
|
|
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
|
|
&PyFrozenRewritePatternSet::getCapsule)
|
|
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
|
|
&PyFrozenRewritePatternSet::createFromCapsule);
|
|
m.def(
|
|
"apply_patterns_and_fold_greedily",
|
|
[](PyModule &module, PyFrozenRewritePatternSet &set,
|
|
std::optional<PyGreedyRewriteConfig> config) {
|
|
MlirLogicalResult status = mlirApplyPatternsAndFoldGreedily(
|
|
module.get(), set.get(),
|
|
config.has_value() ? config->get()
|
|
: mlirGreedyRewriteDriverConfigCreate());
|
|
if (mlirLogicalResultIsFailure(status))
|
|
throw std::runtime_error("pattern application failed to converge");
|
|
},
|
|
"module"_a, "set"_a, "config"_a = nb::none(),
|
|
"Applys the given patterns to the given module greedily while folding "
|
|
"results.")
|
|
.def(
|
|
"apply_patterns_and_fold_greedily",
|
|
[](PyOperationBase &op, PyFrozenRewritePatternSet &set,
|
|
std::optional<PyGreedyRewriteConfig> config) {
|
|
MlirLogicalResult status = mlirApplyPatternsAndFoldGreedilyWithOp(
|
|
op.getOperation(), set.get(),
|
|
config.has_value() ? config->get()
|
|
: mlirGreedyRewriteDriverConfigCreate());
|
|
if (mlirLogicalResultIsFailure(status))
|
|
throw std::runtime_error(
|
|
"pattern application failed to converge");
|
|
},
|
|
"op"_a, "set"_a, "config"_a = nb::none(),
|
|
"Applys the given patterns to the given op greedily while folding "
|
|
"results.")
|
|
.def(
|
|
"walk_and_apply_patterns",
|
|
[](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
|
|
mlirWalkAndApplyPatterns(op.getOperation(), set.get());
|
|
},
|
|
"op"_a, "set"_a,
|
|
"Applies the given patterns to the given op by a fast walk-based "
|
|
"driver.")
|
|
.def(
|
|
"apply_partial_conversion",
|
|
[](PyOperationBase &op, PyConversionTarget &target,
|
|
PyFrozenRewritePatternSet &set,
|
|
std::optional<PyConversionConfig> config) {
|
|
if (!config)
|
|
config.emplace(PyConversionConfig());
|
|
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
|
|
MlirLogicalResult status = mlirApplyPartialConversion(
|
|
op.getOperation(), target.get(), set.get(), config->get());
|
|
if (mlirLogicalResultIsFailure(status))
|
|
throw MLIRError("partial conversion failed", errors.take());
|
|
},
|
|
"op"_a, "target"_a, "set"_a, "config"_a = nb::none(),
|
|
"Applies a partial conversion on the given operation.")
|
|
.def(
|
|
"apply_full_conversion",
|
|
[](PyOperationBase &op, PyConversionTarget &target,
|
|
PyFrozenRewritePatternSet &set,
|
|
std::optional<PyConversionConfig> config) {
|
|
if (!config)
|
|
config.emplace(PyConversionConfig());
|
|
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
|
|
MlirLogicalResult status = mlirApplyFullConversion(
|
|
op.getOperation(), target.get(), set.get(), config->get());
|
|
if (mlirLogicalResultIsFailure(status))
|
|
throw MLIRError("full conversion failed", errors.take());
|
|
},
|
|
"op"_a, "target"_a, "set"_a, "config"_a = nb::none(),
|
|
"Applies a full conversion on the given operation.");
|
|
}
|
|
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
|
|
} // namespace python
|
|
} // namespace mlir
|