[MLIR][Python][Transform] Expose PatternDescriptorOpInterface to Python (#184331)

Makes it possible to include Python-defined rewrite patterns in
transform-dialect schedules, inside of `transform.apply_patterns`, which
upon execution of the schedule runs the pattern in a greedy rewriter.

With assistance of Claude.
This commit is contained in:
Rolf Morel
2026-03-04 11:19:59 +01:00
committed by GitHub
parent 9cc0df99de
commit 756d068ead
8 changed files with 550 additions and 173 deletions

View File

@@ -207,6 +207,38 @@ MLIR_CAPI_EXPORTED void mlirTransformOpInterfaceAttachFallbackModel(
MlirContext ctx, MlirStringRef opName,
MlirTransformOpInterfaceCallbacks callbacks);
//===---------------------------------------------------------------------===//
// PatternDescriptorOpInterface
//===---------------------------------------------------------------------===//
/// Returns the interface TypeID of the PatternDescriptorOpInterface.
MLIR_CAPI_EXPORTED MlirTypeID mlirPatternDescriptorOpInterfaceTypeID(void);
/// Callbacks for implementing PatternDescriptorOpInterface from external code.
typedef struct {
/// Optional constructor for the user data.
/// Set to nullptr to disable it.
void (*construct)(void *userData);
/// Optional destructor for the user data.
/// Set to nullptr to disable it.
void (*destruct)(void *userData);
/// Callback to populate rewrite patterns into the given pattern set.
void (*populatePatterns)(MlirOperation op, MlirRewritePatternSet patterns,
void *userData);
/// Optional callback to populate rewrite patterns with transform state.
/// Set to nullptr to use the default implementation (calls populatePatterns).
void (*populatePatternsWithState)(MlirOperation op,
MlirRewritePatternSet patterns,
MlirTransformState state, void *userData);
void *userData;
} MlirPatternDescriptorOpInterfaceCallbacks;
/// Attach PatternDescriptorOpInterface to the operation with the given name
/// using the provided callbacks.
MLIR_CAPI_EXPORTED void mlirPatternDescriptorOpInterfaceAttachFallbackModel(
MlirContext ctx, MlirStringRef opName,
MlirPatternDescriptorOpInterfaceCallbacks callbacks);
//===---------------------------------------------------------------------===//
// Transform-specifc MemoryEffectsOpInterface helpers
//===---------------------------------------------------------------------===//

View File

@@ -628,6 +628,10 @@ MLIR_CAPI_EXPORTED MlirRewritePattern mlirOpRewritePatternCreate(
MLIR_CAPI_EXPORTED MlirRewritePatternSet
mlirRewritePatternSetCreate(MlirContext context);
/// Get the context associated with a MlirRewritePatternSet.
MLIR_CAPI_EXPORTED MlirContext
mlirRewritePatternSetGetContext(MlirRewritePatternSet set);
/// Destruct the given MlirRewritePatternSet.
MLIR_CAPI_EXPORTED void mlirRewritePatternSetDestroy(MlirRewritePatternSet set);

View File

@@ -11,6 +11,7 @@
#include "Rewrite.h"
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/IRInterfaces.h"
@@ -246,6 +247,104 @@ public:
}
};
//===----------------------------------------------------------------------===//
// PatternDescriptorOpInterface
//===----------------------------------------------------------------------===//
class PyPatternDescriptorOpInterface
: public PyConcreteOpInterface<PyPatternDescriptorOpInterface> {
public:
using PyConcreteOpInterface<
PyPatternDescriptorOpInterface>::PyConcreteOpInterface;
constexpr static const char *pyClassName = "PatternDescriptorOpInterface";
constexpr static GetTypeIDFunctionTy getInterfaceID =
&mlirPatternDescriptorOpInterfaceTypeID;
/// Attach a new PatternDescriptorOpInterface FallbackModel to the named
/// operation. The FallbackModel acts as a trampoline for callbacks on the
/// Python class.
static void attach(nb::object &target, const std::string &opName,
DefaultingPyMlirContext ctx) {
// Prepare the callbacks that will be used by the FallbackModel.
MlirPatternDescriptorOpInterfaceCallbacks callbacks;
// Make the pointer to the Python class available to the callbacks.
callbacks.userData = target.ptr();
nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
// The above ref bump is all we need as initialization, no need to run the
// construct callback.
callbacks.construct = nullptr;
// Upon the FallbackModel's destruction, drop the ref to the Python class.
callbacks.destruct = [](void *userData) {
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
};
// The populatePatterns callback which calls into Python.
callbacks.populatePatterns =
[](MlirOperation op, MlirRewritePatternSet patterns, void *userData) {
nb::handle pyClass(static_cast<PyObject *>(userData));
auto pyPopulatePatterns =
nb::cast<nb::callable>(nb::getattr(pyClass, "populate_patterns"));
auto pyPatterns = PyRewritePatternSet(patterns);
// Invoke `pyClass.populate_patterns(opview(op), patterns)` as a
// staticmethod.
MlirContext ctx = mlirOperationGetContext(op);
PyMlirContextRef context = PyMlirContext::forContext(ctx);
auto opview = PyOperation::forOperation(context, op)->createOpView();
pyPopulatePatterns(opview, pyPatterns);
};
// The populatePatternsWithState callback which calls into Python.
// Check if the Python class has populate_patterns_with_state method.
if (nb::hasattr(target, "populate_patterns_with_state")) {
callbacks.populatePatternsWithState = [](MlirOperation op,
MlirRewritePatternSet patterns,
MlirTransformState state,
void *userData) {
nb::handle pyClass(static_cast<PyObject *>(userData));
auto pyPopulatePatternsWithState = nb::cast<nb::callable>(
nb::getattr(pyClass, "populate_patterns_with_state"));
auto pyPatterns = PyRewritePatternSet(patterns);
auto pyState = PyTransformState(state);
// Invoke `pyClass.populate_patterns_with_state(opview(op), patterns,
// state)` as a staticmethod.
MlirContext ctx = mlirOperationGetContext(op);
PyMlirContextRef context = PyMlirContext::forContext(ctx);
auto opview = PyOperation::forOperation(context, op)->createOpView();
pyPopulatePatternsWithState(opview, pyPatterns, pyState);
};
} else {
// Use default implementation (will call populatePatterns).
callbacks.populatePatternsWithState = nullptr;
}
// Attach a FallbackModel, which calls into Python, to the named operation.
mlirPatternDescriptorOpInterfaceAttachFallbackModel(
ctx->get(), mlirStringRefCreate(opName.c_str(), opName.size()),
callbacks);
}
static void bindDerived(ClassTy &cls) {
cls.attr("attach") = classmethod(
[](const nb::object &cls, const nb::object &opName, nb::object target,
DefaultingPyMlirContext context) {
if (target.is_none())
target = cls;
return attach(target, nb::cast<std::string>(opName), context);
},
nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
nb::arg("target").none() = nb::none(),
nb::arg("context").none() = nb::none(),
"Attach the interface subclass to the given operation name.");
}
};
//===-------------------------------------------------------------------===//
// AnyOpType
//===-------------------------------------------------------------------===//
@@ -444,6 +543,7 @@ static void populateDialectTransformSubmodule(nb::module_ &m) {
PyTransformResults::bind(m);
PyTransformState::bind(m);
PyTransformOpInterface::bind(m);
PyPatternDescriptorOpInterface::bind(m);
m.def("only_reads_handle", onlyReadsHandle,
"Mark operands as only reading handles.", nb::arg("operands"),

View File

@@ -27,6 +27,18 @@ namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
// Convert the Python object to a boolean.
// If it evaluates to False, treat it as success;
// otherwise, treat it as failure.
// Note that None is considered success.
static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
if (obj.is_none())
return mlirLogicalResultSuccess();
return nb::cast<bool>(obj) ? mlirLogicalResultFailure()
: mlirLogicalResultSuccess();
}
class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
public:
static constexpr const char *pyClassName = "PatternRewriter";
@@ -35,6 +47,70 @@ public:
: PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
};
//===----------------------------------------------------------------------===//
// PyRewritePatternSet
//===----------------------------------------------------------------------===//
PyRewritePatternSet::PyRewritePatternSet(MlirContext ctx)
: patterns(mlirRewritePatternSetCreate(ctx)), owned(true) {}
PyRewritePatternSet::PyRewritePatternSet(MlirRewritePatternSet patterns)
: patterns(patterns), owned(false) {}
PyRewritePatternSet::~PyRewritePatternSet() {
if (owned && patterns.ptr)
mlirRewritePatternSetDestroy(patterns);
}
MlirRewritePatternSet PyRewritePatternSet::get() const { return patterns; }
bool PyRewritePatternSet::isOwned() const { return owned; }
void PyRewritePatternSet::add(nb::handle root,
const nb::callable &matchAndRewrite,
unsigned benefit) {
std::string opName;
if (root.is_type()) {
opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
} else if (nb::isinstance<nb::str>(root)) {
opName = nb::cast<std::string>(root);
} else {
throw nb::type_error("the root argument must be a type or a string");
}
MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size());
MlirRewritePatternCallbacks callbacks;
callbacks.construct = [](void *userData) {
nb::handle(static_cast<PyObject *>(userData)).inc_ref();
};
callbacks.destruct = [](void *userData) {
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
};
callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
MlirPatternRewriter rewriter,
void *userData) -> MlirLogicalResult {
nb::handle f(static_cast<PyObject *>(userData));
PyMlirContextRef context =
PyMlirContext::forContext(mlirOperationGetContext(op));
nb::object opView = PyOperation::forOperation(context, op)->createOpView();
nb::object res = f(opView, PyPatternRewriter(rewriter));
return logicalResultFromObject(res);
};
MlirRewritePattern pattern = mlirOpRewritePatternCreate(
rootName, benefit, mlirRewritePatternSetGetContext(patterns), callbacks,
matchAndRewrite.ptr(),
/* nGeneratedNames */ 0,
/* generatedNames */ nullptr);
mlirRewritePatternSetAdd(patterns, pattern);
}
//===----------------------------------------------------------------------===//
// PyConversionPatternRewriter
//===----------------------------------------------------------------------===//
class PyConversionPatternRewriter : public PyPatternRewriter {
public:
PyConversionPatternRewriter(MlirConversionPatternRewriter rewriter)
@@ -132,6 +208,60 @@ private:
MlirConversionPattern pattern;
};
void PyRewritePatternSet::addConversion(nb::handle root,
const nb::callable &matchAndRewrite,
PyTypeConverter &typeConverter,
unsigned benefit) {
std::string opName;
if (root.is_type()) {
opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
} else if (nb::isinstance<nb::str>(root)) {
opName = nb::cast<std::string>(root);
} else {
throw nb::type_error("the root argument must be a type or a string");
}
MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size());
MlirConversionPatternCallbacks callbacks;
callbacks.construct = [](void *userData) {
nb::handle(static_cast<PyObject *>(userData)).inc_ref();
};
callbacks.destruct = [](void *userData) {
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
};
callbacks.matchAndRewrite =
[](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
MlirValue *operands, MlirConversionPatternRewriter rewriter,
void *userData) -> MlirLogicalResult {
nb::handle f(static_cast<PyObject *>(userData));
PyMlirContextRef ctx =
PyMlirContext::forContext(mlirOperationGetContext(op));
nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
std::vector<MlirValue> operandsVec(operands, operands + nOperands);
nb::object adaptorCls =
PyGlobals::get()
.lookupOpAdaptorClass([&] {
MlirStringRef ref = mlirIdentifierStr(mlirOperationGetName(op));
return std::string_view(ref.data, ref.length);
}())
.value_or(nb::borrow(nb::type<PyOpAdaptor>()));
nb::object res = f(opView, adaptorCls(operandsVec, opView),
PyConversionPattern(pattern).getTypeConverter(),
PyConversionPatternRewriter(rewriter));
return logicalResultFromObject(res);
};
MlirConversionPattern pattern = mlirOpConversionPatternCreate(
rootName, benefit, mlirRewritePatternSetGetContext(patterns),
typeConverter.get(), callbacks, matchAndRewrite.ptr(),
/* nGeneratedNames */ 0,
/* generatedNames */ nullptr);
mlirRewritePatternSetAdd(patterns,
mlirConversionPatternAsRewritePattern(pattern));
}
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
struct PyMlirPDLResultList : MlirPDLResultList {};
@@ -157,18 +287,6 @@ static std::vector<nb::object> objectsFromPDLValues(size_t nValues,
return args;
}
// Convert the Python object to a boolean.
// If it evaluates to False, treat it as success;
// otherwise, treat it as failure.
// Note that None is considered success.
static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
if (obj.is_none())
return mlirLogicalResultSuccess();
return nb::cast<bool>(obj) ? mlirLogicalResultFailure()
: mlirLogicalResultSuccess();
}
/// Owning Wrapper around a PDLPatternModule.
class PyPDLPatternModule {
public:
@@ -249,96 +367,55 @@ private:
MlirFrozenRewritePatternSet set;
};
class PyRewritePatternSet {
public:
PyRewritePatternSet(MlirContext ctx)
: set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
~PyRewritePatternSet() {
if (set.ptr)
mlirRewritePatternSetDestroy(set);
}
void PyRewritePatternSet::bind(nb::module_ &m) {
nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
.def(
"__init__",
[](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
new (&self) PyRewritePatternSet(context.get()->get());
},
"context"_a = nb::none())
.def("add", &PyRewritePatternSet::add, nb::arg("root"), nb::arg("fn"),
nb::arg("benefit") = 1,
R"(Add a new rewrite pattern on the specified root operation, using
the provided callable for matching and rewriting, and assign it
the given benefit.
void add(MlirStringRef rootName, unsigned benefit,
const nb::callable &matchAndRewrite) {
MlirRewritePatternCallbacks callbacks;
callbacks.construct = [](void *userData) {
nb::handle(static_cast<PyObject *>(userData)).inc_ref();
};
callbacks.destruct = [](void *userData) {
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
};
callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
MlirPatternRewriter rewriter,
void *userData) -> MlirLogicalResult {
nb::handle f(static_cast<PyObject *>(userData));
Args:
root: The root operation to which this pattern applies. This may
be either an OpView subclass or an operation name.
fn: The callable to use for matching and rewriting, which takes
an operation and a pattern rewriter. The match is considered
successful iff the callable returns a falsy value.
benefit: The benefit of the pattern, defaulting to 1.)")
.def("add_conversion", &PyRewritePatternSet::addConversion,
nb::arg("root"), nb::arg("fn"), nb::arg("type_converter"),
nb::arg("benefit") = 1,
R"(
Add a new conversion pattern on the specified root operation,
using the provided callable for matching and rewriting,
and assign it the given benefit.
PyMlirContextRef ctx =
PyMlirContext::forContext(mlirOperationGetContext(op));
nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
nb::object res = f(opView, PyPatternRewriter(rewriter));
return logicalResultFromObject(res);
};
MlirRewritePattern pattern = mlirOpRewritePatternCreate(
rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
/* nGeneratedNames */ 0,
/* generatedNames */ nullptr);
mlirRewritePatternSetAdd(set, pattern);
}
void addConversion(MlirStringRef rootName, unsigned benefit,
const nb::callable &matchAndRewrite,
PyTypeConverter &typeConverter) {
MlirConversionPatternCallbacks callbacks;
callbacks.construct = [](void *userData) {
nb::handle(static_cast<PyObject *>(userData)).inc_ref();
};
callbacks.destruct = [](void *userData) {
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
};
callbacks.matchAndRewrite =
[](MlirConversionPattern pattern, MlirOperation op, intptr_t nOperands,
MlirValue *operands, MlirConversionPatternRewriter rewriter,
void *userData) -> MlirLogicalResult {
nb::handle f(static_cast<PyObject *>(userData));
PyMlirContextRef ctx =
PyMlirContext::forContext(mlirOperationGetContext(op));
nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();
std::vector<MlirValue> operandsVec(operands, operands + nOperands);
nb::object adaptorCls =
PyGlobals::get()
.lookupOpAdaptorClass([&] {
MlirStringRef ref = mlirIdentifierStr(mlirOperationGetName(op));
return std::string_view(ref.data, ref.length);
}())
.value_or(nb::borrow(nb::type<PyOpAdaptor>()));
nb::object res = f(opView, adaptorCls(operandsVec, opView),
PyConversionPattern(pattern).getTypeConverter(),
PyConversionPatternRewriter(rewriter));
return logicalResultFromObject(res);
};
MlirConversionPattern pattern = mlirOpConversionPatternCreate(
rootName, benefit, ctx, typeConverter.get(), callbacks,
matchAndRewrite.ptr(),
/* nGeneratedNames */ 0,
/* generatedNames */ nullptr);
mlirRewritePatternSetAdd(set,
mlirConversionPatternAsRewritePattern(pattern));
}
PyFrozenRewritePatternSet freeze() {
MlirRewritePatternSet s = set;
set.ptr = nullptr;
return mlirFreezeRewritePattern(s);
}
private:
MlirRewritePatternSet set;
MlirContext ctx;
};
Args:
root: The root operation to which this pattern applies.
This may be either an OpView subclass or an operation name.
fn: The callable to use for matching and rewriting, which takes an
operation, its adaptor, the type converter and a pattern
rewriter. The match is considered successful iff the callable
returns a falsy value.
type_converter: The type converter to convert types in the IR.
benefit: The benefit of the pattern, defaulting to 1.)")
.def(
"freeze",
[](PyRewritePatternSet &self) {
if (!self.isOwned())
throw std::runtime_error(
"cannot freeze a non-owning pattern set");
MlirRewritePatternSet s = self.get();
return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(s));
},
"Freeze the pattern set into a frozen one.");
}
enum class PyGreedyRewriteStrictness : std::underlying_type_t<
MlirGreedyRewriteStrictness> {
@@ -505,79 +582,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of the RewritePatternSet
//----------------------------------------------------------------------------
nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
.def(
"__init__",
[](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
new (&self) PyRewritePatternSet(context.get()->get());
},
"context"_a = nb::none())
.def(
"add",
[](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
unsigned benefit) {
std::string opName;
if (root.is_type()) {
opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
} else if (nb::isinstance<nb::str>(root)) {
opName = nb::cast<std::string>(root);
} else {
throw nb::type_error(
"the root argument must be a type or a string");
}
self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
fn);
},
"root"_a, "fn"_a, "benefit"_a = 1,
// clang-format off
nb::sig("def add(self, root: type | str, fn: typing.Callable[[" MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", PatternRewriter], typing.Any], benefit: int = 1) -> None"),
// clang-format on
R"(
Add a new rewrite pattern on the specified root operation, using the provided callable
for matching and rewriting, and assign it the given benefit.
Args:
root: The root operation to which this pattern applies.
This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
an operation name string (e.g., ``"arith.addi"``).
fn: The callable to use for matching and rewriting,
which takes an operation and a pattern rewriter as arguments.
The match is considered successful iff the callable returns
a value where ``bool(value)`` is ``False`` (e.g. ``None``).
If possible, the operation is cast to its corresponding OpView subclass
before being passed to the callable.
benefit: The benefit of the pattern, defaulting to 1.)")
.def(
"add_conversion",
[](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
PyTypeConverter &typeConverter, unsigned benefit) {
std::string opName =
nb::cast<std::string>(root.attr("OPERATION_NAME"));
self.addConversion(
mlirStringRefCreate(opName.data(), opName.size()), benefit, fn,
typeConverter);
},
"root"_a, "fn"_a, "type_converter"_a, "benefit"_a = 1,
R"(
Add a new conversion pattern on the specified root operation,
using the provided callable for matching and rewriting,
and assign it the given benefit.
Args:
root: The root operation to which this pattern applies.
This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
an operation name string (e.g., ``"arith.addi"``).
fn: The callable to use for matching and rewriting,
which takes an operation, its adaptor,
the type converter and a pattern rewriter as arguments.
The match is considered successful iff the callable returns
a value where ``bool(value)`` is ``False`` (e.g. ``None``).
If possible, the operation is cast to its corresponding OpView subclass
before being passed to the callable.
type_converter: The type converter to convert types in the IR.
benefit: The benefit of the pattern, defaulting to 1.)")
.def("freeze", &PyRewritePatternSet::freeze,
"Freeze the pattern set into a frozen one.");
PyRewritePatternSet::bind(m);
nb::class_<PyConversionPatternRewriter, PyPatternRewriter>(
m, "ConversionPatternRewriter")

View File

@@ -75,6 +75,41 @@ private:
PyMlirContextRef ctx;
};
/// Wrapper around MlirRewritePatternSet.
/// The default constructor creates an owned pattern set that is destroyed
/// in the destructor. The constructor taking MlirRewritePatternSet creates
/// a non-owning reference.
class PyTypeConverter;
class MLIR_PYTHON_API_EXPORTED PyRewritePatternSet {
public:
/// Create an owned pattern set.
PyRewritePatternSet(MlirContext ctx);
/// Create a non-owning reference to an existing pattern set.
PyRewritePatternSet(MlirRewritePatternSet patterns);
~PyRewritePatternSet();
MlirRewritePatternSet get() const;
bool isOwned() const;
/// Add a new rewrite pattern to the pattern set.
void add(nanobind::handle root, const nanobind::callable &matchAndRewrite,
unsigned benefit);
/// Add a new conversion pattern to the pattern set.
void addConversion(nanobind::handle root,
const nanobind::callable &matchAndRewrite,
PyTypeConverter &typeConverter, unsigned benefit);
static void bind(nanobind::module_ &m);
private:
MlirRewritePatternSet patterns;
bool owned;
};
void MLIR_PYTHON_API_EXPORTED populateRewriteSubmodule(nanobind::module_ &m);
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python

View File

@@ -298,6 +298,89 @@ void mlirTransformOpInterfaceAttachFallbackModel(
model->setCallbacks(callbacks);
}
//===---------------------------------------------------------------------===//
// PatternDescriptorOpInterface
//===---------------------------------------------------------------------===//
MlirTypeID mlirPatternDescriptorOpInterfaceTypeID(void) {
return wrap(transform::PatternDescriptorOpInterface::getInterfaceID());
}
/// Fallback model for the PatternDescriptorOpInterface that uses C API
/// callbacks.
class PatternDescriptorOpInterfaceFallbackModel
: public mlir::transform::PatternDescriptorOpInterface::FallbackModel<
PatternDescriptorOpInterfaceFallbackModel> {
public:
/// Sets the callbacks that this FallbackModel will use.
/// NB: the callbacks can only be set through this method as the
/// RegisteredOperationName::attachInterface mechanism default-constructs
/// the FallbackModel without being able to provide arguments.
void setCallbacks(MlirPatternDescriptorOpInterfaceCallbacks callbacks) {
this->callbacks = callbacks;
}
~PatternDescriptorOpInterfaceFallbackModel() {
if (callbacks.destruct)
callbacks.destruct(callbacks.userData);
}
static TypeID getInterfaceID() {
return transform::PatternDescriptorOpInterface::getInterfaceID();
}
static bool
classof(const mlir::transform::detail::
PatternDescriptorOpInterfaceInterfaceTraits::Concept *op) {
// Enable casting back to the FallbackModel from the Interface. This is
// necessary as attachInterface(...) default-constructs the FallbackModel
// without being able to pass in the callbacks and returns just the Concept.
return true;
}
void populatePatterns(Operation *op, RewritePatternSet &patterns) const {
assert(callbacks.populatePatterns && "populatePatterns callback not set");
callbacks.populatePatterns(wrap(op), wrap(&patterns), callbacks.userData);
}
void populatePatternsWithState(Operation *op, RewritePatternSet &patterns,
transform::TransformState &state) const {
if (callbacks.populatePatternsWithState) {
callbacks.populatePatternsWithState(wrap(op), wrap(&patterns),
wrap(&state), callbacks.userData);
} else {
// Default implementation: call populatePatterns without state.
populatePatterns(op, patterns);
}
}
private:
MlirPatternDescriptorOpInterfaceCallbacks callbacks;
};
/// Attach a PatternDescriptorOpInterface FallbackModel to the given named
/// operation. The FallbackModel uses the provided callbacks to implement the
/// interface.
void mlirPatternDescriptorOpInterfaceAttachFallbackModel(
MlirContext ctx, MlirStringRef opName,
MlirPatternDescriptorOpInterfaceCallbacks callbacks) {
// Look up the operation definition in the context.
std::optional<RegisteredOperationName> opInfo =
RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
assert(opInfo.has_value() && "operation not found in context");
// NB: the following default-constructs the FallbackModel _without_ being able
// to provide arguments.
opInfo->attachInterface<PatternDescriptorOpInterfaceFallbackModel>();
// Cast to get the underlying FallbackModel and set the callbacks.
auto *model = cast<PatternDescriptorOpInterfaceFallbackModel>(
opInfo->getInterface<PatternDescriptorOpInterfaceFallbackModel>());
assert(model && "Failed to get PatternDescriptorOpInterfaceFallbackModel");
model->setCallbacks(callbacks);
}
//===---------------------------------------------------------------------===//
// MemoryEffectsOpInterface helpers
//===---------------------------------------------------------------------===//

View File

@@ -728,6 +728,10 @@ MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
return wrap(new mlir::RewritePatternSet(unwrap(context)));
}
MlirContext mlirRewritePatternSetGetContext(MlirRewritePatternSet set) {
return wrap(unwrap(set)->getContext());
}
void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
delete unwrap(set);
}

View File

@@ -0,0 +1,114 @@
# RUN: env PYTHONUNBUFFERED=1 %PYTHON %s 2>&1 | FileCheck %s
from contextlib import contextmanager
from mlir import ir, rewrite
from mlir.dialects import transform, func, arith, ext
from mlir.dialects.transform import AnyOpType, structured
@ext.register_dialect
class MyPatternDescriptors(ext.Dialect, name="my_pattern_descriptors"):
pass
def run(emit_schedule):
print(f"Test: {emit_schedule.__name__}")
with ir.Context(), ir.Location.unknown():
payload = emit_payload()
MyPatternDescriptors.load(register=False, reload=True)
# NB: Pattern descriptor ops have their interfaces attached
# in their respective test functions.
schedule = emit_schedule()
(_named_seq := schedule.body.operations[0]).apply(payload)
print(payload)
# Payload used by all tests.
def emit_payload():
payload_module = ir.Module.create()
with ir.InsertionPoint(payload_module.body):
i32 = ir.IntegerType.get_signless(32)
@func.FuncOp.from_py_func(i32, i32)
def test_func(a, b):
c = arith.addi(a, b)
d = arith.subi(c, b)
return d
return payload_module
@contextmanager
def schedule_boilerplate():
schedule = ir.Module.create()
schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
with ir.InsertionPoint(schedule.body):
named_sequence = transform.NamedSequenceOp(
"__transform_main",
[AnyOpType.get()],
[AnyOpType.get()],
arg_attrs=[{"transform.consumed": ir.UnitAttr.get()}],
)
with ir.InsertionPoint(named_sequence.body):
yield schedule, named_sequence
@ext.register_operation(MyPatternDescriptors)
class SubiAddiRewritePatternOp(MyPatternDescriptors.Operation, name="add_pattern"):
@classmethod
def attach_interface_impls(cls, ctx=None):
cls.PatternDescriptorOpInterfaceFallbackModel.attach(
cls.OPERATION_NAME, context=ctx
)
class PatternDescriptorOpInterfaceFallbackModel(
transform.PatternDescriptorOpInterface
):
@staticmethod
def populate_patterns(
op: "SubiAddiRewritePatternOp",
patterns: rewrite.RewritePatternSet,
) -> None:
# Define a pattern that rewrites subi(addi(a, b), b) -> a
def match_and_rewrite(subi, rewriter):
if not isinstance(addi := subi.lhs.owner, arith.AddiOp):
return True # Failed match, return truthy value
if subi.rhs != addi.rhs:
return True
# Replace subi's result with addi's lhs
rewriter.replace_op(subi, [addi.lhs])
return None # Success
# Add the pattern to the pattern set.
patterns.add("arith.subi", match_and_rewrite, benefit=1)
# CHECK-LABEL: Test: test_pattern_descriptor_add_pattern
@run
def test_pattern_descriptor_add_pattern():
"""Tests python-defined rewrite pattern via PatternDescriptorOpInterface on AddPatternOp"""
SubiAddiRewritePatternOp.attach_interface_impls()
with schedule_boilerplate() as (schedule, named_seq):
func_handle = structured.MatchOp.match_op_names(
named_seq.bodyTarget, ["func.func"]
).result
# After pattern application, check that subi is removed and func returns
# the first argument directly:
# CHECK: func.func @test_func(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
# CHECK: return %[[ARG0]] : i32
apply_patterns_op = transform.ApplyPatternsOp(func_handle)
with ir.InsertionPoint(apply_patterns_op.patterns):
SubiAddiRewritePatternOp()
transform.yield_([func_handle])
named_seq.verify()
return schedule