[MLIR][Python][Transform] Expose PatternDescriptorOpInterface to Python (#184331)
Makes it possible to include Python-defined rewrite patterns in transform-dialect schedules, inside of `transform.apply_patterns`, which upon execution of the schedule runs the pattern in a greedy rewriter. With assistance of Claude.
This commit is contained in:
@@ -207,6 +207,38 @@ MLIR_CAPI_EXPORTED void mlirTransformOpInterfaceAttachFallbackModel(
|
||||
MlirContext ctx, MlirStringRef opName,
|
||||
MlirTransformOpInterfaceCallbacks callbacks);
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// PatternDescriptorOpInterface
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the interface TypeID of the PatternDescriptorOpInterface.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirPatternDescriptorOpInterfaceTypeID(void);
|
||||
|
||||
/// Callbacks for implementing PatternDescriptorOpInterface from external code.
|
||||
typedef struct {
|
||||
/// Optional constructor for the user data.
|
||||
/// Set to nullptr to disable it.
|
||||
void (*construct)(void *userData);
|
||||
/// Optional destructor for the user data.
|
||||
/// Set to nullptr to disable it.
|
||||
void (*destruct)(void *userData);
|
||||
/// Callback to populate rewrite patterns into the given pattern set.
|
||||
void (*populatePatterns)(MlirOperation op, MlirRewritePatternSet patterns,
|
||||
void *userData);
|
||||
/// Optional callback to populate rewrite patterns with transform state.
|
||||
/// Set to nullptr to use the default implementation (calls populatePatterns).
|
||||
void (*populatePatternsWithState)(MlirOperation op,
|
||||
MlirRewritePatternSet patterns,
|
||||
MlirTransformState state, void *userData);
|
||||
void *userData;
|
||||
} MlirPatternDescriptorOpInterfaceCallbacks;
|
||||
|
||||
/// Attach PatternDescriptorOpInterface to the operation with the given name
|
||||
/// using the provided callbacks.
|
||||
MLIR_CAPI_EXPORTED void mlirPatternDescriptorOpInterfaceAttachFallbackModel(
|
||||
MlirContext ctx, MlirStringRef opName,
|
||||
MlirPatternDescriptorOpInterfaceCallbacks callbacks);
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// Transform-specifc MemoryEffectsOpInterface helpers
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
@@ -628,6 +628,10 @@ MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(
|
||||
MLIR_CAPI_EXPORTED MlirRewritePatternSet
|
||||
mlirRewritePatternSetCreate(MlirContext context);
|
||||
|
||||
/// Get the context associated with a MlirRewritePatternSet.
|
||||
MLIR_CAPI_EXPORTED MlirContext
|
||||
mlirRewritePatternSetGetContext(MlirRewritePatternSet set);
|
||||
|
||||
/// Destruct the given MlirRewritePatternSet.
|
||||
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "Rewrite.h"
|
||||
#include "mlir-c/Dialect/Transform.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Rewrite.h"
|
||||
#include "mlir-c/Support.h"
|
||||
#include "mlir/Bindings/Python/IRCore.h"
|
||||
#include "mlir/Bindings/Python/IRInterfaces.h"
|
||||
@@ -246,6 +247,104 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PatternDescriptorOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
class PyPatternDescriptorOpInterface
|
||||
: public PyConcreteOpInterface<PyPatternDescriptorOpInterface> {
|
||||
public:
|
||||
using PyConcreteOpInterface<
|
||||
PyPatternDescriptorOpInterface>::PyConcreteOpInterface;
|
||||
|
||||
constexpr static const char *pyClassName = "PatternDescriptorOpInterface";
|
||||
constexpr static GetTypeIDFunctionTy getInterfaceID =
|
||||
&mlirPatternDescriptorOpInterfaceTypeID;
|
||||
|
||||
/// Attach a new PatternDescriptorOpInterface FallbackModel to the named
|
||||
/// operation. The FallbackModel acts as a trampoline for callbacks on the
|
||||
/// Python class.
|
||||
static void attach(nb::object &target, const std::string &opName,
|
||||
DefaultingPyMlirContext ctx) {
|
||||
// Prepare the callbacks that will be used by the FallbackModel.
|
||||
MlirPatternDescriptorOpInterfaceCallbacks callbacks;
|
||||
// Make the pointer to the Python class available to the callbacks.
|
||||
callbacks.userData = target.ptr();
|
||||
nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
|
||||
|
||||
// The above ref bump is all we need as initialization, no need to run the
|
||||
// construct callback.
|
||||
callbacks.construct = nullptr;
|
||||
// Upon the FallbackModel's destruction, drop the ref to the Python class.
|
||||
callbacks.destruct = [](void *userData) {
|
||||
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
|
||||
};
|
||||
|
||||
// The populatePatterns callback which calls into Python.
|
||||
callbacks.populatePatterns =
|
||||
[](MlirOperation op, MlirRewritePatternSet patterns, void *userData) {
|
||||
nb::handle pyClass(static_cast<PyObject *>(userData));
|
||||
|
||||
auto pyPopulatePatterns =
|
||||
nb::cast<nb::callable>(nb::getattr(pyClass, "populate_patterns"));
|
||||
|
||||
auto pyPatterns = PyRewritePatternSet(patterns);
|
||||
|
||||
// Invoke `pyClass.populate_patterns(opview(op), patterns)` as a
|
||||
// staticmethod.
|
||||
MlirContext ctx = mlirOperationGetContext(op);
|
||||
PyMlirContextRef context = PyMlirContext::forContext(ctx);
|
||||
auto opview = PyOperation::forOperation(context, op)->createOpView();
|
||||
pyPopulatePatterns(opview, pyPatterns);
|
||||
};
|
||||
|
||||
// The populatePatternsWithState callback which calls into Python.
|
||||
// Check if the Python class has populate_patterns_with_state method.
|
||||
if (nb::hasattr(target, "populate_patterns_with_state")) {
|
||||
callbacks.populatePatternsWithState = [](MlirOperation op,
|
||||
MlirRewritePatternSet patterns,
|
||||
MlirTransformState state,
|
||||
void *userData) {
|
||||
nb::handle pyClass(static_cast<PyObject *>(userData));
|
||||
|
||||
auto pyPopulatePatternsWithState = nb::cast<nb::callable>(
|
||||
nb::getattr(pyClass, "populate_patterns_with_state"));
|
||||
|
||||
auto pyPatterns = PyRewritePatternSet(patterns);
|
||||
auto pyState = PyTransformState(state);
|
||||
|
||||
// Invoke `pyClass.populate_patterns_with_state(opview(op), patterns,
|
||||
// state)` as a staticmethod.
|
||||
MlirContext ctx = mlirOperationGetContext(op);
|
||||
PyMlirContextRef context = PyMlirContext::forContext(ctx);
|
||||
auto opview = PyOperation::forOperation(context, op)->createOpView();
|
||||
pyPopulatePatternsWithState(opview, pyPatterns, pyState);
|
||||
};
|
||||
} else {
|
||||
// Use default implementation (will call populatePatterns).
|
||||
callbacks.populatePatternsWithState = nullptr;
|
||||
}
|
||||
|
||||
// Attach a FallbackModel, which calls into Python, to the named operation.
|
||||
mlirPatternDescriptorOpInterfaceAttachFallbackModel(
|
||||
ctx->get(), mlirStringRefCreate(opName.c_str(), opName.size()),
|
||||
callbacks);
|
||||
}
|
||||
|
||||
static void bindDerived(ClassTy &cls) {
|
||||
cls.attr("attach") = classmethod(
|
||||
[](const nb::object &cls, const nb::object &opName, nb::object target,
|
||||
DefaultingPyMlirContext context) {
|
||||
if (target.is_none())
|
||||
target = cls;
|
||||
return attach(target, nb::cast<std::string>(opName), context);
|
||||
},
|
||||
nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
|
||||
nb::arg("target").none() = nb::none(),
|
||||
nb::arg("context").none() = nb::none(),
|
||||
"Attach the interface subclass to the given operation name.");
|
||||
}
|
||||
};
|
||||
|
||||
//===-------------------------------------------------------------------===//
|
||||
// AnyOpType
|
||||
//===-------------------------------------------------------------------===//
|
||||
@@ -444,6 +543,7 @@ static void populateDialectTransformSubmodule(nb::module_ &m) {
|
||||
PyTransformResults::bind(m);
|
||||
PyTransformState::bind(m);
|
||||
PyTransformOpInterface::bind(m);
|
||||
PyPatternDescriptorOpInterface::bind(m);
|
||||
|
||||
m.def("only_reads_handle", onlyReadsHandle,
|
||||
"Mark operands as only reading handles.", nb::arg("operands"),
|
||||
|
||||
@@ -27,6 +27,18 @@ namespace mlir {
|
||||
namespace python {
|
||||
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
|
||||
|
||||
// Convert the Python object to a boolean.
|
||||
// If it evaluates to False, treat it as success;
|
||||
// otherwise, treat it as failure.
|
||||
// Note that None is considered success.
|
||||
static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
|
||||
if (obj.is_none())
|
||||
return mlirLogicalResultSuccess();
|
||||
|
||||
return nb::cast<bool>(obj) ? mlirLogicalResultFailure()
|
||||
: mlirLogicalResultSuccess();
|
||||
}
|
||||
|
||||
class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
|
||||
public:
|
||||
static constexpr const char *pyClassName = "PatternRewriter";
|
||||
@@ -35,6 +47,70 @@ public:
|
||||
: PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PyRewritePatternSet
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PyRewritePatternSet::PyRewritePatternSet(MlirContext ctx)
|
||||
: patterns(mlirRewritePatternSetCreate(ctx)), owned(true) {}
|
||||
|
||||
PyRewritePatternSet::PyRewritePatternSet(MlirRewritePatternSet patterns)
|
||||
: patterns(patterns), owned(false) {}
|
||||
|
||||
PyRewritePatternSet::~PyRewritePatternSet() {
|
||||
if (owned && patterns.ptr)
|
||||
mlirRewritePatternSetDestroy(patterns);
|
||||
}
|
||||
|
||||
MlirRewritePatternSet PyRewritePatternSet::get() const { return patterns; }
|
||||
|
||||
bool PyRewritePatternSet::isOwned() const { return owned; }
|
||||
|
||||
void PyRewritePatternSet::add(nb::handle root,
|
||||
const nb::callable &matchAndRewrite,
|
||||
unsigned benefit) {
|
||||
std::string opName;
|
||||
if (root.is_type()) {
|
||||
opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
|
||||
} else if (nb::isinstance<nb::str>(root)) {
|
||||
opName = nb::cast<std::string>(root);
|
||||
} else {
|
||||
throw nb::type_error("the root argument must be a type or a string");
|
||||
}
|
||||
MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size());
|
||||
|
||||
MlirRewritePatternCallbacks callbacks;
|
||||
callbacks.construct = [](void *userData) {
|
||||
nb::handle(static_cast<PyObject *>(userData)).inc_ref();
|
||||
};
|
||||
callbacks.destruct = [](void *userData) {
|
||||
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
|
||||
};
|
||||
callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
|
||||
MlirPatternRewriter rewriter,
|
||||
void *userData) -> MlirLogicalResult {
|
||||
nb::handle f(static_cast<PyObject *>(userData));
|
||||
|
||||
PyMlirContextRef context =
|
||||
PyMlirContext::forContext(mlirOperationGetContext(op));
|
||||
nb::object opView = PyOperation::forOperation(context, op)->createOpView();
|
||||
|
||||
nb::object res = f(opView, PyPatternRewriter(rewriter));
|
||||
return logicalResultFromObject(res);
|
||||
};
|
||||
|
||||
MlirRewritePattern pattern = mlirOpRewritePatternCreate(
|
||||
rootName, benefit, mlirRewritePatternSetGetContext(patterns), callbacks,
|
||||
matchAndRewrite.ptr(),
|
||||
/* nGeneratedNames */ 0,
|
||||
/* generatedNames */ nullptr);
|
||||
mlirRewritePatternSetAdd(patterns, pattern);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PyConversionPatternRewriter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class PyConversionPatternRewriter : public PyPatternRewriter {
|
||||
public:
|
||||
PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
|
||||
@@ -132,6 +208,60 @@ private:
|
||||
MlirConversionPattern pattern;
|
||||
};
|
||||
|
||||
void PyRewritePatternSet::addConversion(nb::handle root,
|
||||
const nb::callable &matchAndRewrite,
|
||||
PyTypeConverter &typeConverter,
|
||||
unsigned benefit) {
|
||||
std::string opName;
|
||||
if (root.is_type()) {
|
||||
opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
|
||||
} else if (nb::isinstance<nb::str>(root)) {
|
||||
opName = nb::cast<std::string>(root);
|
||||
} else {
|
||||
throw nb::type_error("the root argument must be a type or a string");
|
||||
}
|
||||
MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size());
|
||||
|
||||
MlirConversionPatternCallbacks callbacks;
|
||||
callbacks.construct = [](void *userData) {
|
||||
nb::handle(static_cast<PyObject *>(userData)).inc_ref();
|
||||
};
|
||||
callbacks.destruct = [](void *userData) {
|
||||
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
|
||||
};
|
||||
callbacks.matchAndRewrite =
|
||||
[](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
|
||||
MlirValue *operands, MlirConversionPatternRewriter rewriter,
|
||||
void *userData) -> MlirLogicalResult {
|
||||
nb::handle f(static_cast<PyObject *>(userData));
|
||||
|
||||
PyMlirContextRef ctx =
|
||||
PyMlirContext::forContext(mlirOperationGetContext(op));
|
||||
nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
|
||||
|
||||
std::vector<MlirValue> operandsVec(operands, operands + nOperands);
|
||||
nb::object adaptorCls =
|
||||
PyGlobals::get()
|
||||
.lookupOpAdaptorClass([&] {
|
||||
MlirStringRef ref = mlirIdentifierStr(mlirOperationGetName(op));
|
||||
return std::string_view(ref.data, ref.length);
|
||||
}())
|
||||
.value_or(nb::borrow(nb::type<PyOpAdaptor>()));
|
||||
|
||||
nb::object res = f(opView, adaptorCls(operandsVec, opView),
|
||||
PyConversionPattern(pattern).getTypeConverter(),
|
||||
PyConversionPatternRewriter(rewriter));
|
||||
return logicalResultFromObject(res);
|
||||
};
|
||||
MlirConversionPattern pattern = mlirOpConversionPatternCreate(
|
||||
rootName, benefit, mlirRewritePatternSetGetContext(patterns),
|
||||
typeConverter.get(), callbacks, matchAndRewrite.ptr(),
|
||||
/* nGeneratedNames */ 0,
|
||||
/* generatedNames */ nullptr);
|
||||
mlirRewritePatternSetAdd(patterns,
|
||||
mlirConversionPatternAsRewritePattern(pattern));
|
||||
}
|
||||
|
||||
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
|
||||
struct PyMlirPDLResultList : MlirPDLResultList {};
|
||||
|
||||
@@ -157,18 +287,6 @@ static std::vector<nb::object> objectsFromPDLValues(size_t nValues,
|
||||
return args;
|
||||
}
|
||||
|
||||
// Convert the Python object to a boolean.
|
||||
// If it evaluates to False, treat it as success;
|
||||
// otherwise, treat it as failure.
|
||||
// Note that None is considered success.
|
||||
static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
|
||||
if (obj.is_none())
|
||||
return mlirLogicalResultSuccess();
|
||||
|
||||
return nb::cast<bool>(obj) ? mlirLogicalResultFailure()
|
||||
: mlirLogicalResultSuccess();
|
||||
}
|
||||
|
||||
/// Owning Wrapper around a PDLPatternModule.
|
||||
class PyPDLPatternModule {
|
||||
public:
|
||||
@@ -249,96 +367,55 @@ private:
|
||||
MlirFrozenRewritePatternSet set;
|
||||
};
|
||||
|
||||
class PyRewritePatternSet {
|
||||
public:
|
||||
PyRewritePatternSet(MlirContext ctx)
|
||||
: set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
|
||||
~PyRewritePatternSet() {
|
||||
if (set.ptr)
|
||||
mlirRewritePatternSetDestroy(set);
|
||||
}
|
||||
void PyRewritePatternSet::bind(nb::module_ &m) {
|
||||
nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
|
||||
.def(
|
||||
"__init__",
|
||||
[](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
|
||||
new (&self) PyRewritePatternSet(context.get()->get());
|
||||
},
|
||||
"context"_a = nb::none())
|
||||
.def("add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"),
|
||||
nb::arg("benefit") = 1,
|
||||
R"(Add a new rewrite pattern on the specified root operation, using
|
||||
the provided callable for matching and rewriting, and assign it
|
||||
the given benefit.
|
||||
|
||||
void add(MlirStringRef rootName, unsigned benefit,
|
||||
const nb::callable &matchAndRewrite) {
|
||||
MlirRewritePatternCallbacks callbacks;
|
||||
callbacks.construct = [](void *userData) {
|
||||
nb::handle(static_cast<PyObject *>(userData)).inc_ref();
|
||||
};
|
||||
callbacks.destruct = [](void *userData) {
|
||||
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
|
||||
};
|
||||
callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
|
||||
MlirPatternRewriter rewriter,
|
||||
void *userData) -> MlirLogicalResult {
|
||||
nb::handle f(static_cast<PyObject *>(userData));
|
||||
Args:
|
||||
root: The root operation to which this pattern applies. This may
|
||||
be either an OpView subclass or an operation name.
|
||||
fn: The callable to use for matching and rewriting, which takes
|
||||
an operation and a pattern rewriter. The match is considered
|
||||
successful iff the callable returns a falsy value.
|
||||
benefit: The benefit of the pattern, defaulting to 1.)")
|
||||
.def("add_conversion", &PyRewritePatternSet::addConversion,
|
||||
nb::arg("root"), nb::arg("fn"), nb::arg("type_converter"),
|
||||
nb::arg("benefit") = 1,
|
||||
R"(
|
||||
Add a new conversion pattern on the specified root operation,
|
||||
using the provided callable for matching and rewriting,
|
||||
and assign it the given benefit.
|
||||
|
||||
PyMlirContextRef ctx =
|
||||
PyMlirContext::forContext(mlirOperationGetContext(op));
|
||||
nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
|
||||
|
||||
nb::object res = f(opView, PyPatternRewriter(rewriter));
|
||||
return logicalResultFromObject(res);
|
||||
};
|
||||
MlirRewritePattern pattern = mlirOpRewritePatternCreate(
|
||||
rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
|
||||
/* nGeneratedNames */ 0,
|
||||
/* generatedNames */ nullptr);
|
||||
mlirRewritePatternSetAdd(set, pattern);
|
||||
}
|
||||
|
||||
void addConversion(MlirStringRef rootName, unsigned benefit,
|
||||
const nb::callable &matchAndRewrite,
|
||||
PyTypeConverter &typeConverter) {
|
||||
MlirConversionPatternCallbacks callbacks;
|
||||
callbacks.construct = [](void *userData) {
|
||||
nb::handle(static_cast<PyObject *>(userData)).inc_ref();
|
||||
};
|
||||
callbacks.destruct = [](void *userData) {
|
||||
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
|
||||
};
|
||||
callbacks.matchAndRewrite =
|
||||
[](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
|
||||
MlirValue *operands, MlirConversionPatternRewriter rewriter,
|
||||
void *userData) -> MlirLogicalResult {
|
||||
nb::handle f(static_cast<PyObject *>(userData));
|
||||
|
||||
PyMlirContextRef ctx =
|
||||
PyMlirContext::forContext(mlirOperationGetContext(op));
|
||||
nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
|
||||
|
||||
std::vector<MlirValue> operandsVec(operands, operands + nOperands);
|
||||
nb::object adaptorCls =
|
||||
PyGlobals::get()
|
||||
.lookupOpAdaptorClass([&] {
|
||||
MlirStringRef ref = mlirIdentifierStr(mlirOperationGetName(op));
|
||||
return std::string_view(ref.data, ref.length);
|
||||
}())
|
||||
.value_or(nb::borrow(nb::type<PyOpAdaptor>()));
|
||||
|
||||
nb::object res = f(opView, adaptorCls(operandsVec, opView),
|
||||
PyConversionPattern(pattern).getTypeConverter(),
|
||||
PyConversionPatternRewriter(rewriter));
|
||||
return logicalResultFromObject(res);
|
||||
};
|
||||
MlirConversionPattern pattern = mlirOpConversionPatternCreate(
|
||||
rootName, benefit, ctx, typeConverter.get(), callbacks,
|
||||
matchAndRewrite.ptr(),
|
||||
/* nGeneratedNames */ 0,
|
||||
/* generatedNames */ nullptr);
|
||||
mlirRewritePatternSetAdd(set,
|
||||
mlirConversionPatternAsRewritePattern(pattern));
|
||||
}
|
||||
|
||||
PyFrozenRewritePatternSet freeze() {
|
||||
MlirRewritePatternSet s = set;
|
||||
set.ptr = nullptr;
|
||||
return mlirFreezeRewritePattern(s);
|
||||
}
|
||||
|
||||
private:
|
||||
MlirRewritePatternSet set;
|
||||
MlirContext ctx;
|
||||
};
|
||||
Args:
|
||||
root: The root operation to which this pattern applies.
|
||||
This may be either an OpView subclass or an operation name.
|
||||
fn: The callable to use for matching and rewriting, which takes an
|
||||
operation, its adaptor, the type converter and a pattern
|
||||
rewriter. The match is considered successful iff the callable
|
||||
returns a falsy value.
|
||||
type_converter: The type converter to convert types in the IR.
|
||||
benefit: The benefit of the pattern, defaulting to 1.)")
|
||||
.def(
|
||||
"freeze",
|
||||
[](PyRewritePatternSet &self) {
|
||||
if (!self.isOwned())
|
||||
throw std::runtime_error(
|
||||
"cannot freeze a non-owning pattern set");
|
||||
MlirRewritePatternSet s = self.get();
|
||||
return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(s));
|
||||
},
|
||||
"Freeze the pattern set into a frozen one.");
|
||||
}
|
||||
|
||||
enum class PyGreedyRewriteStrictness : std::underlying_type_t<
|
||||
MlirGreedyRewriteStrictness> {
|
||||
@@ -505,79 +582,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of the RewritePatternSet
|
||||
//----------------------------------------------------------------------------
|
||||
nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
|
||||
.def(
|
||||
"__init__",
|
||||
[](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
|
||||
new (&self) PyRewritePatternSet(context.get()->get());
|
||||
},
|
||||
"context"_a = nb::none())
|
||||
.def(
|
||||
"add",
|
||||
[](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
|
||||
unsigned benefit) {
|
||||
std::string opName;
|
||||
if (root.is_type()) {
|
||||
opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
|
||||
} else if (nb::isinstance<nb::str>(root)) {
|
||||
opName = nb::cast<std::string>(root);
|
||||
} else {
|
||||
throw nb::type_error(
|
||||
"the root argument must be a type or a string");
|
||||
}
|
||||
self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
|
||||
fn);
|
||||
},
|
||||
"root"_a, "fn"_a, "benefit"_a = 1,
|
||||
// clang-format off
|
||||
nb::sig("def add(self, root: type | str, fn: typing.Callable[[" MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", PatternRewriter], typing.Any], benefit: int = 1) -> None"),
|
||||
// clang-format on
|
||||
R"(
|
||||
Add a new rewrite pattern on the specified root operation, using the provided callable
|
||||
for matching and rewriting, and assign it the given benefit.
|
||||
|
||||
Args:
|
||||
root: The root operation to which this pattern applies.
|
||||
This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
|
||||
an operation name string (e.g., ``"arith.addi"``).
|
||||
fn: The callable to use for matching and rewriting,
|
||||
which takes an operation and a pattern rewriter as arguments.
|
||||
The match is considered successful iff the callable returns
|
||||
a value where ``bool(value)`` is ``False`` (e.g. ``None``).
|
||||
If possible, the operation is cast to its corresponding OpView subclass
|
||||
before being passed to the callable.
|
||||
benefit: The benefit of the pattern, defaulting to 1.)")
|
||||
.def(
|
||||
"add_conversion",
|
||||
[](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
|
||||
PyTypeConverter &typeConverter, unsigned benefit) {
|
||||
std::string opName =
|
||||
nb::cast<std::string>(root.attr("OPERATION_NAME"));
|
||||
self.addConversion(
|
||||
mlirStringRefCreate(opName.data(), opName.size()), benefit, fn,
|
||||
typeConverter);
|
||||
},
|
||||
"root"_a, "fn"_a, "type_converter"_a, "benefit"_a = 1,
|
||||
R"(
|
||||
Add a new conversion pattern on the specified root operation,
|
||||
using the provided callable for matching and rewriting,
|
||||
and assign it the given benefit.
|
||||
|
||||
Args:
|
||||
root: The root operation to which this pattern applies.
|
||||
This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
|
||||
an operation name string (e.g., ``"arith.addi"``).
|
||||
fn: The callable to use for matching and rewriting,
|
||||
which takes an operation, its adaptor,
|
||||
the type converter and a pattern rewriter as arguments.
|
||||
The match is considered successful iff the callable returns
|
||||
a value where ``bool(value)`` is ``False`` (e.g. ``None``).
|
||||
If possible, the operation is cast to its corresponding OpView subclass
|
||||
before being passed to the callable.
|
||||
type_converter: The type converter to convert types in the IR.
|
||||
benefit: The benefit of the pattern, defaulting to 1.)")
|
||||
.def("freeze", &PyRewritePatternSet::freeze,
|
||||
"Freeze the pattern set into a frozen one.");
|
||||
PyRewritePatternSet::bind(m);
|
||||
|
||||
nb::class_<PyConversionPatternRewriter, PyPatternRewriter>(
|
||||
m, "ConversionPatternRewriter")
|
||||
|
||||
@@ -75,6 +75,41 @@ private:
|
||||
PyMlirContextRef ctx;
|
||||
};
|
||||
|
||||
/// Wrapper around MlirRewritePatternSet.
|
||||
/// The default constructor creates an owned pattern set that is destroyed
|
||||
/// in the destructor. The constructor taking MlirRewritePatternSet creates
|
||||
/// a non-owning reference.
|
||||
class PyTypeConverter;
|
||||
class MLIR_PYTHON_API_EXPORTED PyRewritePatternSet {
|
||||
public:
|
||||
/// Create an owned pattern set.
|
||||
PyRewritePatternSet(MlirContext ctx);
|
||||
|
||||
/// Create a non-owning reference to an existing pattern set.
|
||||
PyRewritePatternSet(MlirRewritePatternSet patterns);
|
||||
|
||||
~PyRewritePatternSet();
|
||||
|
||||
MlirRewritePatternSet get() const;
|
||||
|
||||
bool isOwned() const;
|
||||
|
||||
/// Add a new rewrite pattern to the pattern set.
|
||||
void add(nanobind::handle root, const nanobind::callable &matchAndRewrite,
|
||||
unsigned benefit);
|
||||
|
||||
/// Add a new conversion pattern to the pattern set.
|
||||
void addConversion(nanobind::handle root,
|
||||
const nanobind::callable &matchAndRewrite,
|
||||
PyTypeConverter &typeConverter, unsigned benefit);
|
||||
|
||||
static void bind(nanobind::module_ &m);
|
||||
|
||||
private:
|
||||
MlirRewritePatternSet patterns;
|
||||
bool owned;
|
||||
};
|
||||
|
||||
void MLIR_PYTHON_API_EXPORTED populateRewriteSubmodule(nanobind::module_ &m);
|
||||
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
|
||||
} // namespace python
|
||||
|
||||
@@ -298,6 +298,89 @@ void mlirTransformOpInterfaceAttachFallbackModel(
|
||||
model->setCallbacks(callbacks);
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// PatternDescriptorOpInterface
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeID mlirPatternDescriptorOpInterfaceTypeID(void) {
|
||||
return wrap(transform::PatternDescriptorOpInterface::getInterfaceID());
|
||||
}
|
||||
|
||||
/// Fallback model for the PatternDescriptorOpInterface that uses C API
|
||||
/// callbacks.
|
||||
class PatternDescriptorOpInterfaceFallbackModel
|
||||
: public mlir::transform::PatternDescriptorOpInterface::FallbackModel<
|
||||
PatternDescriptorOpInterfaceFallbackModel> {
|
||||
public:
|
||||
/// Sets the callbacks that this FallbackModel will use.
|
||||
/// NB: the callbacks can only be set through this method as the
|
||||
/// RegisteredOperationName::attachInterface mechanism default-constructs
|
||||
/// the FallbackModel without being able to provide arguments.
|
||||
void setCallbacks(MlirPatternDescriptorOpInterfaceCallbacks callbacks) {
|
||||
this->callbacks = callbacks;
|
||||
}
|
||||
|
||||
~PatternDescriptorOpInterfaceFallbackModel() {
|
||||
if (callbacks.destruct)
|
||||
callbacks.destruct(callbacks.userData);
|
||||
}
|
||||
|
||||
static TypeID getInterfaceID() {
|
||||
return transform::PatternDescriptorOpInterface::getInterfaceID();
|
||||
}
|
||||
|
||||
static bool
|
||||
classof(const mlir::transform::detail::
|
||||
PatternDescriptorOpInterfaceInterfaceTraits::Concept *op) {
|
||||
// Enable casting back to the FallbackModel from the Interface. This is
|
||||
// necessary as attachInterface(...) default-constructs the FallbackModel
|
||||
// without being able to pass in the callbacks and returns just the Concept.
|
||||
return true;
|
||||
}
|
||||
|
||||
void populatePatterns(Operation *op, RewritePatternSet &patterns) const {
|
||||
assert(callbacks.populatePatterns && "populatePatterns callback not set");
|
||||
callbacks.populatePatterns(wrap(op), wrap(&patterns), callbacks.userData);
|
||||
}
|
||||
|
||||
void populatePatternsWithState(Operation *op, RewritePatternSet &patterns,
|
||||
transform::TransformState &state) const {
|
||||
if (callbacks.populatePatternsWithState) {
|
||||
callbacks.populatePatternsWithState(wrap(op), wrap(&patterns),
|
||||
wrap(&state), callbacks.userData);
|
||||
} else {
|
||||
// Default implementation: call populatePatterns without state.
|
||||
populatePatterns(op, patterns);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
MlirPatternDescriptorOpInterfaceCallbacks callbacks;
|
||||
};
|
||||
|
||||
/// Attach a PatternDescriptorOpInterface FallbackModel to the given named
|
||||
/// operation. The FallbackModel uses the provided callbacks to implement the
|
||||
/// interface.
|
||||
void mlirPatternDescriptorOpInterfaceAttachFallbackModel(
|
||||
MlirContext ctx, MlirStringRef opName,
|
||||
MlirPatternDescriptorOpInterfaceCallbacks callbacks) {
|
||||
// Look up the operation definition in the context.
|
||||
std::optional<RegisteredOperationName> opInfo =
|
||||
RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
|
||||
|
||||
assert(opInfo.has_value() && "operation not found in context");
|
||||
|
||||
// NB: the following default-constructs the FallbackModel _without_ being able
|
||||
// to provide arguments.
|
||||
opInfo->attachInterface<PatternDescriptorOpInterfaceFallbackModel>();
|
||||
// Cast to get the underlying FallbackModel and set the callbacks.
|
||||
auto *model = cast<PatternDescriptorOpInterfaceFallbackModel>(
|
||||
opInfo->getInterface<PatternDescriptorOpInterfaceFallbackModel>());
|
||||
|
||||
assert(model && "Failed to get PatternDescriptorOpInterfaceFallbackModel");
|
||||
model->setCallbacks(callbacks);
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// MemoryEffectsOpInterface helpers
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
@@ -728,6 +728,10 @@ MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
|
||||
return wrap(new mlir::RewritePatternSet(unwrap(context)));
|
||||
}
|
||||
|
||||
MlirContext mlirRewritePatternSetGetContext(MlirRewritePatternSet set) {
|
||||
return wrap(unwrap(set)->getContext());
|
||||
}
|
||||
|
||||
void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
|
||||
delete unwrap(set);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
# RUN: env PYTHONUNBUFFERED=1 %PYTHON %s 2>&1 | FileCheck %s
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from mlir import ir, rewrite
|
||||
from mlir.dialects import transform, func, arith, ext
|
||||
from mlir.dialects.transform import AnyOpType, structured
|
||||
|
||||
|
||||
@ext.register_dialect
|
||||
class MyPatternDescriptors(ext.Dialect, name="my_pattern_descriptors"):
|
||||
pass
|
||||
|
||||
|
||||
def run(emit_schedule):
|
||||
print(f"Test: {emit_schedule.__name__}")
|
||||
with ir.Context(), ir.Location.unknown():
|
||||
payload = emit_payload()
|
||||
|
||||
MyPatternDescriptors.load(register=False, reload=True)
|
||||
|
||||
# NB: Pattern descriptor ops have their interfaces attached
|
||||
# in their respective test functions.
|
||||
schedule = emit_schedule()
|
||||
|
||||
(_named_seq := schedule.body.operations[0]).apply(payload)
|
||||
|
||||
print(payload)
|
||||
|
||||
|
||||
# Payload used by all tests.
|
||||
def emit_payload():
|
||||
payload_module = ir.Module.create()
|
||||
with ir.InsertionPoint(payload_module.body):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
|
||||
@func.FuncOp.from_py_func(i32, i32)
|
||||
def test_func(a, b):
|
||||
c = arith.addi(a, b)
|
||||
d = arith.subi(c, b)
|
||||
return d
|
||||
|
||||
return payload_module
|
||||
|
||||
|
||||
@contextmanager
|
||||
def schedule_boilerplate():
|
||||
schedule = ir.Module.create()
|
||||
schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
|
||||
with ir.InsertionPoint(schedule.body):
|
||||
named_sequence = transform.NamedSequenceOp(
|
||||
"__transform_main",
|
||||
[AnyOpType.get()],
|
||||
[AnyOpType.get()],
|
||||
arg_attrs=[{"transform.consumed": ir.UnitAttr.get()}],
|
||||
)
|
||||
with ir.InsertionPoint(named_sequence.body):
|
||||
yield schedule, named_sequence
|
||||
|
||||
|
||||
@ext.register_operation(MyPatternDescriptors)
|
||||
class SubiAddiRewritePatternOp(MyPatternDescriptors.Operation, name="add_pattern"):
|
||||
@classmethod
|
||||
def attach_interface_impls(cls, ctx=None):
|
||||
cls.PatternDescriptorOpInterfaceFallbackModel.attach(
|
||||
cls.OPERATION_NAME, context=ctx
|
||||
)
|
||||
|
||||
class PatternDescriptorOpInterfaceFallbackModel(
|
||||
transform.PatternDescriptorOpInterface
|
||||
):
|
||||
@staticmethod
|
||||
def populate_patterns(
|
||||
op: "SubiAddiRewritePatternOp",
|
||||
patterns: rewrite.RewritePatternSet,
|
||||
) -> None:
|
||||
# Define a pattern that rewrites subi(addi(a, b), b) -> a
|
||||
def match_and_rewrite(subi, rewriter):
|
||||
if not isinstance(addi := subi.lhs.owner, arith.AddiOp):
|
||||
return True # Failed match, return truthy value
|
||||
if subi.rhs != addi.rhs:
|
||||
return True
|
||||
# Replace subi's result with addi's lhs
|
||||
rewriter.replace_op(subi, [addi.lhs])
|
||||
return None # Success
|
||||
|
||||
# Add the pattern to the pattern set.
|
||||
patterns.add("arith.subi", match_and_rewrite, benefit=1)
|
||||
|
||||
|
||||
# CHECK-LABEL: Test: test_pattern_descriptor_add_pattern
|
||||
@run
|
||||
def test_pattern_descriptor_add_pattern():
|
||||
"""Tests python-defined rewrite pattern via PatternDescriptorOpInterface on AddPatternOp"""
|
||||
|
||||
SubiAddiRewritePatternOp.attach_interface_impls()
|
||||
|
||||
with schedule_boilerplate() as (schedule, named_seq):
|
||||
func_handle = structured.MatchOp.match_op_names(
|
||||
named_seq.bodyTarget, ["func.func"]
|
||||
).result
|
||||
|
||||
# After pattern application, check that subi is removed and func returns
|
||||
# the first argument directly:
|
||||
# CHECK: func.func @test_func(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
|
||||
# CHECK: return %[[ARG0]] : i32
|
||||
apply_patterns_op = transform.ApplyPatternsOp(func_handle)
|
||||
with ir.InsertionPoint(apply_patterns_op.patterns):
|
||||
SubiAddiRewritePatternOp()
|
||||
|
||||
transform.yield_([func_handle])
|
||||
named_seq.verify()
|
||||
|
||||
return schedule
|
||||
Reference in New Issue
Block a user