[MLIR][Python] Impl XOpInterface(s) from Python, with X=Transform and X=MemoryEffects (#176920)

Provides the infrastructure for implementing and late-binding
OpInterfaces from Python.

* On the mlir-c API declaration side, each `XOpInterface` has a callback
struct, with a callback for each method and a userdata member (provided
as an arg to each method), and a
`mlirXOpInterfaceAttachFallbackModel(ctx, op_name, callbacks)` func.
* This CAPI is implemented by defining a subclass of
`XOpInterface::FallbackModel` that holds the callback struct and has
each method call the corresponding callback (with userdata as an arg).
Given a callback struct, a new `FallbackModel` is created and attached,
i.e. late bound, to the named op. (MLIR's interface infrastructure is
such that the thus registered `FallbackModel` will be returned in case
the op gets cast to the `XOpInterface`.)
* On the Python side, we expose a stand-in `XOpInterface` base class
which has one (class)method: `XOpInterface.attach(cls, op_name, ctx)`.
Python users subclass this class (`class MyInterfaceImpl(XOpInterface):
...`) and implement the interface's methods (with the right names and
signatures). The user calls `attach` on the subclass
(`MyInterfaceImpl.attach("my_dialect.my_op", ctx)`) which prepares the
callbacks struct _with userdata set to the subclass_ (as we use it to
lookup methods). These callbacks (and userdata) are then registered as
an `XOpInterface::FallbackModel` by
`mlirXOpInterfaceAttachFallbackModel(...)`. From then on the Python
methods will be used to respond to calls to the interface methods
(originating in C++).

This PR enables implementing the TransformOpInterface and the
MemoryEffectsOpInterface, both of which are required for making an op
into a transform op.

Everything besides the above linked code is there to facilitate exposing
the interfaces: the right types for the arguments of the methods are
exposed as are functions/methods for manipulating these arguments (e.g.
specifying side effects on `OpOperand`s and `OpResult`s and being able
to access and set the transform handles associated with args and
results).
This commit is contained in:
Rolf Morel
2026-02-12 15:07:10 +01:00
committed by GitHub
parent 8e1d5ec534
commit a1d7cda1d7
19 changed files with 1674 additions and 215 deletions

View File

@@ -11,6 +11,8 @@
#define MLIR_C_DIALECT_TRANSFORM_H
#include "mlir-c/IR.h"
#include "mlir-c/Interfaces.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
@@ -19,6 +21,32 @@ extern "C" {
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform);
#define DEFINE_C_API_STRUCT(name, storage) \
struct name { \
storage *ptr; \
}; \
typedef struct name name
DEFINE_C_API_STRUCT(MlirTransformResults, void);
DEFINE_C_API_STRUCT(MlirTransformRewriter, void);
DEFINE_C_API_STRUCT(MlirTransformState, void);
#undef DEFINE_C_API_STRUCT
//===---------------------------------------------------------------------===//
// DiagnosedSilenceableFailure
//===---------------------------------------------------------------------===//
/// Enum representing the result of a transform operation.
typedef enum {
/// The operation succeeded.
MlirDiagnosedSilenceableFailureSuccess,
/// The operation failed in a silenceable way.
MlirDiagnosedSilenceableFailureSilenceableFailure,
/// The operation failed definitively.
MlirDiagnosedSilenceableFailureDefiniteFailure
} MlirDiagnosedSilenceableFailure;
//===---------------------------------------------------------------------===//
// AnyOpType
//===---------------------------------------------------------------------===//
@@ -86,6 +114,126 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirTransformParamTypeGetName(void);
MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type);
//===---------------------------------------------------------------------===//
// TransformRewriter
//===---------------------------------------------------------------------===//
/// Cast the TransformRewriter to a RewriterBase
MLIR_CAPI_EXPORTED MlirRewriterBase
mlirTransformRewriterAsBase(MlirTransformRewriter rewriter);
//===---------------------------------------------------------------------===//
// TransformResults
//===---------------------------------------------------------------------===//
/// Set the payload operations for a transform result by iterating over a list.
MLIR_CAPI_EXPORTED void mlirTransformResultsSetOps(MlirTransformResults results,
MlirValue result,
intptr_t numOps,
MlirOperation *ops);
/// Set the payload values for a transform result by iterating over a list.
MLIR_CAPI_EXPORTED void
mlirTransformResultsSetValues(MlirTransformResults results, MlirValue result,
intptr_t numValues, MlirValue *values);
/// Set the parameters for a transform result by iterating over a list.
MLIR_CAPI_EXPORTED void
mlirTransformResultsSetParams(MlirTransformResults results, MlirValue result,
intptr_t numParams, MlirAttribute *params);
//===---------------------------------------------------------------------===//
// TransformState
//===---------------------------------------------------------------------===//
/// Callback for iterating over payload operations.
typedef void (*MlirOperationCallback)(MlirOperation, void *userData);
/// Iterate over payload operations associated with the transform IR value.
/// Calls the callback for each payload operation.
MLIR_CAPI_EXPORTED void
mlirTransformStateForEachPayloadOp(MlirTransformState state, MlirValue value,
MlirOperationCallback callback,
void *userData);
/// Callback for iterating over payload values.
typedef void (*MlirValueCallback)(MlirValue, void *userData);
/// Iterate over payload values associated with the transform IR value.
/// Calls the callback for each payload value.
MLIR_CAPI_EXPORTED void
mlirTransformStateForEachPayloadValue(MlirTransformState state, MlirValue value,
MlirValueCallback callback,
void *userData);
/// Callback for iterating over parameters.
typedef void (*MlirAttributeCallback)(MlirAttribute, void *userData);
/// Iterate over parameters associated with the transform IR value.
/// Calls the callback for each parameter.
MLIR_CAPI_EXPORTED void
mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
MlirAttributeCallback callback, void *userData);
//===---------------------------------------------------------------------===//
// TransformOpInterface
//===---------------------------------------------------------------------===//
/// Returns the interface TypeID of the TransformOpInterface.
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOpInterfaceTypeID(void);
/// Callbacks for implementing TransformOpInterface from external code.
typedef struct {
/// Optional constructor for the user data.
/// Set to nullptr to disable it.
void (*construct)(void *userData);
/// Optional destructor for the user data.
/// Set to nullptr to disable it.
void (*destruct)(void *userData);
/// Apply callback that implements the transformation.
MlirDiagnosedSilenceableFailure (*apply)(MlirOperation op,
MlirTransformRewriter rewriter,
MlirTransformResults results,
MlirTransformState state,
void *userData);
/// Callback to check if repeated handle operands are allowed.
bool (*allowsRepeatedHandleOperands)(MlirOperation op, void *userData);
void *userData;
} MlirTransformOpInterfaceCallbacks;
/// Attach TransformOpInterface to the operation with the given name using
/// the provided callbacks.
MLIR_CAPI_EXPORTED void mlirTransformOpInterfaceAttachFallbackModel(
MlirContext ctx, MlirStringRef opName,
MlirTransformOpInterfaceCallbacks callbacks);
//===---------------------------------------------------------------------===//
// Transform-specifc MemoryEffectsOpInterface helpers
//===---------------------------------------------------------------------===//
/// Helper to mark operands as only reading handles.
MLIR_CAPI_EXPORTED void
mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
MlirMemoryEffectInstancesList effects);
/// Helper to mark operands as consuming handles.
MLIR_CAPI_EXPORTED void
mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands,
MlirMemoryEffectInstancesList effects);
/// Helper to mark results as producing handles.
MLIR_CAPI_EXPORTED void
mlirTransformProducesHandle(MlirValue *results, intptr_t numResults,
MlirMemoryEffectInstancesList effects);
/// Helper to mark potential modifications to the payload IR.
MLIR_CAPI_EXPORTED void
mlirTransformModifiesPayload(MlirMemoryEffectInstancesList effects);
/// Helper to mark potential reads from the payload IR.
MLIR_CAPI_EXPORTED void
mlirTransformOnlyReadsPayload(MlirMemoryEffectInstancesList effects);
#ifdef __cplusplus
}
#endif

View File

@@ -673,6 +673,10 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op);
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op,
intptr_t pos);
/// Returns `pos`-th OpOperand of the operation.
MLIR_CAPI_EXPORTED MlirOpOperand mlirOperationGetOpOperand(MlirOperation op,
intptr_t pos);
/// Sets the `pos`-th operand of the operation.
MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
MlirValue newValue);

View File

@@ -22,6 +22,16 @@
extern "C" {
#endif
#define DEFINE_C_API_STRUCT(name, storage) \
struct name { \
storage *ptr; \
}; \
typedef struct name name
DEFINE_C_API_STRUCT(MlirMemoryEffectInstancesList, void);
#undef DEFINE_C_API_STRUCT
/// Returns `true` if the given operation implements an interface identified by
/// its TypeID.
MLIR_CAPI_EXPORTED bool
@@ -42,7 +52,7 @@ mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
//===----------------------------------------------------------------------===//
/// Returns the interface TypeID of the InferTypeOpInterface.
MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID();
MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(void);
/// These callbacks are used to return multiple types from functions while
/// transferring ownership to the caller. The first argument is the number of
@@ -65,7 +75,7 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
//===----------------------------------------------------------------------===//
/// Returns the interface TypeID of the InferShapedTypeOpInterface.
MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID();
MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID(void);
/// These callbacks are used to return multiple shaped type components from
/// functions while transferring ownership to the caller. The first argument is
@@ -87,6 +97,31 @@ mlirInferShapedTypeOpInterfaceInferReturnTypes(
void *properties, intptr_t nRegions, MlirRegion *regions,
MlirShapedTypeComponentsCallback callback, void *userData);
//===---------------------------------------------------------------------===//
// MemoryEffectsOpInterface
//===---------------------------------------------------------------------===//
/// Returns the interface TypeID of the MemoryEffectsOpInterface.
MLIR_CAPI_EXPORTED MlirTypeID mlirMemoryEffectsOpInterfaceTypeID(void);
/// Callbacks for implementing MemoryEffectsOpInterface from external code.
typedef struct {
/// Optional constructor for user data. Set to nullptr to disable it.
void (*construct)(void *userData);
/// Optional destructor for user data. Set to nullptr to disable it.
void (*destruct)(void *userData);
/// Get memory effects callback.
void (*getEffects)(MlirOperation op, MlirMemoryEffectInstancesList effects,
void *userData);
void *userData;
} MlirMemoryEffectsOpInterfaceCallbacks;
/// Attach a new FallbackModel for the MemoryEffectsOpInterface to the named
/// operation. The FallbackModel will call the provided callbacks.
MLIR_CAPI_EXPORTED void mlirMemoryEffectsOpInterfaceAttachFallbackModel(
MlirContext ctx, MlirStringRef opName,
MlirMemoryEffectsOpInterfaceCallbacks callbacks);
#ifdef __cplusplus
}
#endif

View File

@@ -1492,6 +1492,7 @@ private:
class MLIR_PYTHON_API_EXPORTED PyOpOperand {
public:
PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
operator MlirOpOperand() const { return opOperand; }
nanobind::typed<nanobind::object, PyOpView> getOwner() const;
@@ -1871,13 +1872,20 @@ public:
MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m);
MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m);
/// Helper for creating an @classmethod.
template <class Func, typename... Args>
inline nanobind::object classmethod(Func f, Args... args) {
nanobind::object cf = nanobind::cpp_function(f, args...);
return nanobind::borrow<nanobind::object>((PyClassMethod_New(cf.ptr())));
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
namespace nanobind {
namespace detail {
template <>
struct type_caster<
mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext>

View File

@@ -0,0 +1,28 @@
//===- Transform.h - C API Utils for Transform dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains declarations of implementation details of the C API for
// the Transform dialect. This file should not be included from C++ code other
// than C API implementation nor from C code.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CAPI_DIALECT_TRANSFORM_H
#define MLIR_CAPI_DIALECT_TRANSFORM_H
#include "mlir-c/Dialect/Transform.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
DEFINE_C_API_PTR_METHODS(MlirTransformRewriter,
mlir::transform::TransformRewriter)
DEFINE_C_API_PTR_METHODS(MlirTransformResults,
mlir::transform::TransformResults)
DEFINE_C_API_PTR_METHODS(MlirTransformState, mlir::transform::TransformState)
#endif // MLIR_CAPI_DIALECT_TRANSFORM_H

View File

@@ -15,4 +15,12 @@
#ifndef MLIR_CAPI_INTERFACES_H
#define MLIR_CAPI_INTERFACES_H
#include "mlir-c/Interfaces.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
DEFINE_C_API_PTR_METHODS(
MlirMemoryEffectInstancesList,
llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance>)
#endif // MLIR_CAPI_INTERFACES_H

View File

@@ -8,12 +8,14 @@
#include <string>
#include "IRInterfaces.h"
#include "Rewrite.h"
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
#include <nanobind/trampoline.h>
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
@@ -22,6 +24,227 @@ namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
namespace transform {
//===----------------------------------------------------------------------===//
// TransformRewriter
//===----------------------------------------------------------------------===//
class PyTransformRewriter : public PyRewriterBase<PyTransformRewriter> {
public:
static constexpr const char *pyClassName = "TransformRewriter";
PyTransformRewriter(MlirTransformRewriter rewriter)
: PyRewriterBase(mlirTransformRewriterAsBase(rewriter)) {}
};
//===----------------------------------------------------------------------===//
// TransformResults
//===----------------------------------------------------------------------===//
class PyTransformResults {
public:
PyTransformResults(MlirTransformResults results) : results(results) {}
MlirTransformResults get() const { return results; }
void setOps(PyValue &result, const nb::list &ops) {
std::vector<MlirOperation> opsVec;
opsVec.reserve(ops.size());
for (auto op : ops) {
opsVec.push_back(nb::cast<MlirOperation>(op));
}
mlirTransformResultsSetOps(results, result, opsVec.size(), opsVec.data());
}
void setValues(PyValue &result, const nb::list &values) {
std::vector<MlirValue> valuesVec;
valuesVec.reserve(values.size());
for (auto item : values) {
valuesVec.push_back(nb::cast<MlirValue>(item));
}
mlirTransformResultsSetValues(results, result, valuesVec.size(),
valuesVec.data());
}
void setParams(PyValue &result, const nb::list &params) {
std::vector<MlirAttribute> paramsVec;
paramsVec.reserve(params.size());
for (auto item : params) {
paramsVec.push_back(nb::cast<MlirAttribute>(item));
}
mlirTransformResultsSetParams(results, result, paramsVec.size(),
paramsVec.data());
}
static void bind(nanobind::module_ &m) {
nb::class_<PyTransformResults>(m, "TransformResults")
.def(nb::init<MlirTransformResults>())
.def("set_ops", &PyTransformResults::setOps,
"Set the payload operations for a transform result.",
nb::arg("result"), nb::arg("ops"))
.def("set_values", &PyTransformResults::setValues,
"Set the payload values for a transform result.",
nb::arg("result"), nb::arg("values"))
.def("set_params", &PyTransformResults::setParams,
"Set the parameters for a transform result.", nb::arg("result"),
nb::arg("params"));
}
private:
MlirTransformResults results;
};
//===----------------------------------------------------------------------===//
// TransformState
//===----------------------------------------------------------------------===//
class PyTransformState {
public:
PyTransformState(MlirTransformState state) : state(state) {}
MlirTransformState get() const { return state; }
static void bind(nanobind::module_ &m) {
nb::class_<PyTransformState>(m, "TransformState")
.def(nb::init<MlirTransformState>())
.def("get_payload_ops", &PyTransformState::getPayloadOps,
"Get the payload operations associated with a transform IR value.",
nb::arg("operand"))
.def("get_payload_values", &PyTransformState::getPayloadValues,
"Get the payload values associated with a transform IR value.",
nb::arg("operand"))
.def("get_params", &PyTransformState::getParams,
"Get the parameters (attributes) associated with a transform IR "
"value.",
nb::arg("operand"));
}
private:
nanobind::list getPayloadOps(PyValue &value) {
nanobind::list result;
mlirTransformStateForEachPayloadOp(
state, value,
[](MlirOperation op, void *userData) {
PyMlirContextRef context =
PyMlirContext::forContext(mlirOperationGetContext(op));
auto opview = PyOperation::forOperation(context, op)->createOpView();
static_cast<nanobind::list *>(userData)->append(opview);
},
&result);
return result;
}
nanobind::list getPayloadValues(PyValue &value) {
nanobind::list result;
mlirTransformStateForEachPayloadValue(
state, value,
[](MlirValue val, void *userData) {
static_cast<nanobind::list *>(userData)->append(val);
},
&result);
return result;
}
nanobind::list getParams(PyValue &value) {
nanobind::list result;
mlirTransformStateForEachParam(
state, value,
[](MlirAttribute attr, void *userData) {
static_cast<nanobind::list *>(userData)->append(attr);
},
&result);
return result;
}
MlirTransformState state;
};
//===----------------------------------------------------------------------===//
// TransformOpInterface
//===----------------------------------------------------------------------===//
class PyTransformOpInterface
: public PyConcreteOpInterface<PyTransformOpInterface> {
public:
using PyConcreteOpInterface<PyTransformOpInterface>::PyConcreteOpInterface;
constexpr static const char *pyClassName = "TransformOpInterface";
constexpr static GetTypeIDFunctionTy getInterfaceID =
&mlirTransformOpInterfaceTypeID;
/// Attach a new TransformOpInterface FallbackModel to the named operation.
/// The FallbackModel acts as a trampoline for callbacks on the Python class.
static void attach(nb::object &target, const std::string &opName,
DefaultingPyMlirContext ctx) {
// Prepare the callbacks that will be used by the FallbackModel.
MlirTransformOpInterfaceCallbacks callbacks;
// Make the pointer to the Python class available to the callbacks.
callbacks.userData = target.ptr();
nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
// The above ref bump is all we need as initialization, no need to run the
// construct callback.
callbacks.construct = nullptr;
// Upon the FallbackModel's destruction, drop the ref to the Python class.
callbacks.destruct = [](void *userData) {
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
};
// The apply callback which calls into Python.
callbacks.apply = [](MlirOperation op, MlirTransformRewriter rewriter,
MlirTransformResults results, MlirTransformState state,
void *userData) -> MlirDiagnosedSilenceableFailure {
nb::handle pyClass(static_cast<PyObject *>(userData));
auto pyApply = nb::cast<nb::callable>(nb::getattr(pyClass, "apply"));
auto pyRewriter = PyTransformRewriter(rewriter);
auto pyResults = PyTransformResults(results);
auto pyState = PyTransformState(state);
// Invoke `pyClass.apply(opview(op), rewriter, results, state)` as a
// staticmethod.
PyMlirContextRef context =
PyMlirContext::forContext(mlirOperationGetContext(op));
auto opview = PyOperation::forOperation(context, op)->createOpView();
nb::object res = pyApply(opview, pyRewriter, pyResults, pyState);
return nb::cast<MlirDiagnosedSilenceableFailure>(res);
};
// The allows_repeated_handle_operands callback which calls into Python.
callbacks.allowsRepeatedHandleOperands = [](MlirOperation op,
void *userData) -> bool {
nb::handle pyClass(static_cast<PyObject *>(userData));
auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
nb::getattr(pyClass, "allow_repeated_handle_operands"));
// Invoke `pyClass.allow_repeated_handle_operands(opview(op))` as a
// staticmethod.
PyMlirContextRef context =
PyMlirContext::forContext(mlirOperationGetContext(op));
auto opview = PyOperation::forOperation(context, op)->createOpView();
nb::object res = pyAllowRepeatedHandleOperands(opview);
return nb::cast<bool>(res);
};
// Attach a FallbackModel, which calls into Python, to the named operation.
mlirTransformOpInterfaceAttachFallbackModel(
ctx->get(), wrap(StringRef(opName.c_str())), callbacks);
}
static void bindDerived(ClassTy &cls) {
cls.attr("attach") = classmethod(
[](const nb::object &cls, const nb::object &opName, nb::object target,
DefaultingPyMlirContext context) {
if (target.is_none())
target = cls;
return attach(target, nb::cast<std::string>(opName), context);
},
nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
nb::arg("target").none() = nb::none(),
nb::arg("context").none() = nb::none(),
"Attach the interface subclass to the given operation name.");
}
};
//===-------------------------------------------------------------------===//
// AnyOpType
//===-------------------------------------------------------------------===//
@@ -162,12 +385,81 @@ struct ParamType : PyConcreteType<ParamType> {
}
};
//===----------------------------------------------------------------------===//
// MemoryEffectsOpInterface helpers
//===----------------------------------------------------------------------===//
namespace {
void onlyReadsHandle(nb::iterable &operands,
PyMemoryEffectsInstanceList effects) {
std::vector<MlirOpOperand> operandsVec;
for (auto operand : operands)
operandsVec.push_back(nb::cast<PyOpOperand>(operand));
mlirTransformOnlyReadsHandle(operandsVec.data(), operandsVec.size(),
effects.effects);
};
void consumesHandle(nb::iterable &operands,
PyMemoryEffectsInstanceList effects) {
std::vector<MlirOpOperand> operandsVec;
for (auto operand : operands)
operandsVec.push_back(nb::cast<PyOpOperand>(operand));
mlirTransformConsumesHandle(operandsVec.data(), operandsVec.size(),
effects.effects);
};
void producesHandle(nb::iterable &results,
PyMemoryEffectsInstanceList effects) {
std::vector<MlirValue> resultsVec;
for (auto result : results)
resultsVec.push_back(nb::cast<PyOpResult>(result).get());
mlirTransformProducesHandle(resultsVec.data(), resultsVec.size(),
effects.effects);
};
void modifiesPayload(PyMemoryEffectsInstanceList effects) {
mlirTransformModifiesPayload(effects.effects);
}
void onlyReadsPayload(PyMemoryEffectsInstanceList effects) {
mlirTransformOnlyReadsPayload(effects.effects);
}
} // namespace
static void populateDialectTransformSubmodule(nb::module_ &m) {
nb::enum_<MlirDiagnosedSilenceableFailure>(m, "DiagnosedSilenceableFailure")
.value("Success", MlirDiagnosedSilenceableFailureSuccess)
.value("SilenceableFailure",
MlirDiagnosedSilenceableFailureSilenceableFailure)
.value("DefiniteFailure", MlirDiagnosedSilenceableFailureDefiniteFailure);
AnyOpType::bind(m);
AnyParamType::bind(m);
AnyValueType::bind(m);
OperationType::bind(m);
ParamType::bind(m);
PyTransformRewriter::bind(m);
PyTransformResults::bind(m);
PyTransformState::bind(m);
PyTransformOpInterface::bind(m);
m.def("only_reads_handle", onlyReadsHandle,
"Mark operands as only reading handles.", nb::arg("operands"),
nb::arg("effects"));
m.def("consumes_handle", consumesHandle,
"Mark operands as consuming handles.", nb::arg("operands"),
nb::arg("effects"));
m.def("produces_handle", producesHandle, "Mark results as producing handles.",
nb::arg("results"), nb::arg("effects"));
m.def("modifies_payload", modifiesPayload,
"Mark the transform as modifying the payload.", nb::arg("effects"));
m.def("only_reads_payload", onlyReadsPayload,
"Mark the transform as only reading the payload.", nb::arg("effects"));
}
} // namespace transform
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN

View File

@@ -10,7 +10,6 @@
#include "mlir/Bindings/Python/Globals.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
// clang-format on
#include "mlir-c/BuiltinAttributes.h"
@@ -57,13 +56,6 @@ static size_t hash(const T &value) {
return std::hash<T>{}(value);
}
/// Helper for creating an @classmethod.
template <class Func, typename... Args>
static nb::object classmethod(Func f, Args... args) {
nb::object cf = nb::cpp_function(f, args...);
return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
}
static nb::object
createCustomDialectWrapper(const std::string &dialectNamespace,
nb::object dialectDescriptor) {
@@ -2289,6 +2281,44 @@ PyOpOperandList PyOpOperandList::slice(intptr_t startIndex, intptr_t length,
return PyOpOperandList(operation, startIndex, length, step);
}
/// A list of OpOperands. Internally, these are stored as consecutive elements,
/// random access is cheap. The (returned) OpOperand list is associated with the
/// operation whose operands these are, and thus extends the lifetime of this
/// operation.
class PyOpOperands : public Sliceable<PyOpOperands, PyOpOperand> {
public:
static constexpr const char *pyClassName = "OpOperands";
using SliceableT = Sliceable<PyOpOperandList, PyOpOperand>;
PyOpOperands(PyOperationRef operation, intptr_t startIndex = 0,
intptr_t length = -1, intptr_t step = 1)
: Sliceable(startIndex,
length == -1 ? mlirOperationGetNumOperands(operation->get())
: length,
step),
operation(operation) {}
private:
/// Give the parent CRTP class access to hook implementations below.
friend class Sliceable<PyOpOperands, PyOpOperand>;
intptr_t getRawNumElements() {
operation->checkValid();
return mlirOperationGetNumOperands(operation->get());
}
PyOpOperand getRawElement(intptr_t pos) {
MlirOpOperand opOperand = mlirOperationGetOpOperand(operation->get(), pos);
return PyOpOperand(opOperand);
}
PyOpOperands slice(intptr_t startIndex, intptr_t length, intptr_t step) {
return PyOpOperands(operation, startIndex, length, step);
}
PyOperationRef operation;
};
PyOpSuccessors::PyOpSuccessors(PyOperationRef operation, intptr_t startIndex,
intptr_t length, intptr_t step)
: Sliceable(startIndex,
@@ -3669,6 +3699,12 @@ void populateIRCore(nb::module_ &m) {
return PyOpOperandList(self.getOperation().getRef());
},
"Returns the list of operation operands.")
.def_prop_ro(
"op_operands",
[](PyOperationBase &self) {
return PyOpOperands(self.getOperation().getRef());
},
"Returns the list of op operands.")
.def_prop_ro(
"regions",
[](PyOperationBase &self) {
@@ -4950,6 +4986,7 @@ void populateIRCore(nb::module_ &m) {
PyOpAttributeMap::bind(m);
PyOpOperandIterator::bind(m);
PyOpOperandList::bind(m);
PyOpOperands::bind(m);
PyOpResultList::bind(m);
PyOpSuccessors::bind(m);
PyRegionIterator::bind(m);

View File

@@ -12,30 +12,18 @@
#include <utility>
#include <vector>
#include "IRInterfaces.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/IR.h"
#include "mlir-c/Interfaces.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/Nanobind.h"
namespace nb = nanobind;
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
constexpr static const char *constructorDoc =
R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
interface.)";
constexpr static const char *operationDoc =
R"(Returns an Operation for which the interface was constructed.)";
constexpr static const char *opviewDoc =
R"(Returns an OpView subclass _instance_ for which the interface was
constructed)";
constexpr static const char *inferReturnTypesDoc =
R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";
@@ -124,119 +112,6 @@ wrapRegions(std::optional<std::vector<PyRegion>> regions) {
} // namespace
/// CRTP base class for Python classes representing MLIR Op interfaces.
/// Interface hierarchies are flat so no base class is expected here. The
/// derived class is expected to define the following static fields:
/// - `const char *pyClassName` - the name of the Python class to create;
/// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
/// of the interface.
/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
/// interface-specific methods.
///
/// An interface class may be constructed from either an Operation/OpView object
/// or from a subclass of OpView. In the latter case, only the static interface
/// methods are available, similarly to calling ConcereteOp::staticMethod on the
/// C++ side. Implementations of concrete interfaces can use the `isStatic`
/// method to check whether the interface object was constructed from a class or
/// an operation/opview instance. The `getOpName` always succeeds and returns a
/// canonical name of the operation suitable for lookups.
template <typename ConcreteIface>
class PyConcreteOpInterface {
protected:
using ClassTy = nb::class_<ConcreteIface>;
using GetTypeIDFunctionTy = MlirTypeID (*)();
public:
/// Constructs an interface instance from an object that is either an
/// operation or a subclass of OpView. In the latter case, only the static
/// methods of the interface are accessible to the caller.
PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
: obj(std::move(object)) {
try {
operation = &nb::cast<PyOperation &>(obj);
} catch (nb::cast_error &) {
// Do nothing.
}
try {
operation = &nb::cast<PyOpView &>(obj).getOperation();
} catch (nb::cast_error &) {
// Do nothing.
}
if (operation != nullptr) {
if (!mlirOperationImplementsInterface(*operation,
ConcreteIface::getInterfaceID())) {
std::string msg = "the operation does not implement ";
throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
}
MlirIdentifier identifier = mlirOperationGetName(*operation);
MlirStringRef stringRef = mlirIdentifierStr(identifier);
opName = std::string(stringRef.data, stringRef.length);
} else {
try {
opName = nb::cast<std::string>(obj.attr("OPERATION_NAME"));
} catch (nb::cast_error &) {
throw nb::type_error(
"Op interface does not refer to an operation or OpView class");
}
if (!mlirOperationImplementsInterfaceStatic(
mlirStringRefCreate(opName.data(), opName.length()),
context.resolve().get(), ConcreteIface::getInterfaceID())) {
std::string msg = "the operation does not implement ";
throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
}
}
}
/// Creates the Python bindings for this class in the given module.
static void bind(nb::module_ &m) {
nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
nb::arg("context") = nb::none(), constructorDoc)
.def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
operationDoc)
.def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
ConcreteIface::bindDerived(cls);
}
/// Hook for derived classes to add class-specific bindings.
static void bindDerived(ClassTy &cls) {}
/// Returns `true` if this object was constructed from a subclass of OpView
/// rather than from an operation instance.
bool isStatic() { return operation == nullptr; }
/// Returns the operation instance from which this object was constructed.
/// Throws a type error if this object was constructed from a subclass of
/// OpView.
nb::typed<nb::object, PyOperation> getOperationObject() {
if (operation == nullptr)
throw nb::type_error("Cannot get an operation from a static interface");
return operation->getRef().releaseObject();
}
/// Returns the opview of the operation instance from which this object was
/// constructed. Throws a type error if this object was constructed form a
/// subclass of OpView.
nb::typed<nb::object, PyOpView> getOpView() {
if (operation == nullptr)
throw nb::type_error("Cannot get an opview from a static interface");
return operation->createOpView();
}
/// Returns the canonical name of the operation this interface is constructed
/// from.
const std::string &getOpName() { return opName; }
private:
PyOperation *operation = nullptr;
std::string opName;
nb::object obj;
};
/// Python wrapper for InferTypeOpInterface. This interface has only static
/// methods.
class PyInferTypeOpInterface
@@ -462,10 +337,74 @@ public:
}
};
/// Wrapper around the MemoryEffectsOpInterface.
class PyMemoryEffectsOpInterface
: public PyConcreteOpInterface<PyMemoryEffectsOpInterface> {
public:
using PyConcreteOpInterface<
PyMemoryEffectsOpInterface>::PyConcreteOpInterface;
constexpr static const char *pyClassName = "MemoryEffectsOpInterface";
constexpr static GetTypeIDFunctionTy getInterfaceID =
&mlirMemoryEffectsOpInterfaceTypeID;
/// Attach a new MemoryEffectsOpInterface FallbackModel to the named
/// operation. The FallbackModel acts as a trampoline for callbacks on the
/// Python class.
static void attach(nb::object &target, const std::string &opName,
DefaultingPyMlirContext ctx) {
MlirMemoryEffectsOpInterfaceCallbacks callbacks;
callbacks.userData = target.ptr();
nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
callbacks.construct = nullptr;
callbacks.destruct = [](void *userData) {
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
};
callbacks.getEffects = [](MlirOperation op,
MlirMemoryEffectInstancesList effects,
void *userData) {
nb::handle pyClass(static_cast<PyObject *>(userData));
// Get the 'get_effects' method from the Python class.
auto pyGetEffects =
nb::cast<nb::callable>(nb::getattr(pyClass, "get_effects"));
PyMemoryEffectsInstanceList effectsWrapper{effects};
PyMlirContextRef context =
PyMlirContext::forContext(mlirOperationGetContext(op));
auto opview = PyOperation::forOperation(context, op)->createOpView();
// Invoke `pyClass.get_effects(op, effects)`.
pyGetEffects(opview, effectsWrapper);
};
mlirMemoryEffectsOpInterfaceAttachFallbackModel(
ctx->get(), wrap(StringRef(opName.c_str())), callbacks);
}
static void bindDerived(ClassTy &cls) {
cls.attr("attach") = classmethod(
[](const nb::object &cls, const nb::object &opName, nb::object target,
DefaultingPyMlirContext context) {
if (target.is_none())
target = cls;
return attach(target, nb::cast<std::string>(opName), context);
},
nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
nb::arg("target").none() = nb::none(),
nb::arg("context").none() = nb::none(),
"Attach the interface subclass to the given operation name.");
}
};
void populateIRInterfaces(nb::module_ &m) {
PyInferTypeOpInterface::bind(m);
PyShapedTypeComponents::bind(m);
nb::class_<PyMemoryEffectsInstanceList>(m, "MemoryEffectInstancesList");
PyInferShapedTypeOpInterface::bind(m);
PyInferTypeOpInterface::bind(m);
PyMemoryEffectsOpInterface::bind(m);
PyShapedTypeComponents::bind(m);
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python

View File

@@ -0,0 +1,152 @@
//===- IRInterfaces.h - IR Interfaces for Python Bindings -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_BINDINGS_PYTHON_IRINTERFACES_H
#define MLIR_BINDINGS_PYTHON_IRINTERFACES_H
#include "mlir-c/IR.h"
#include "mlir-c/Interfaces.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/IRCore.h"
#include <nanobind/nanobind.h>
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
constexpr static const char *constructorDoc =
R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
interface.)";
constexpr static const char *operationDoc =
R"(Returns an Operation for which the interface was constructed.)";
constexpr static const char *opviewDoc =
R"(Returns an OpView subclass _instance_ for which the interface was
constructed)";
/// CRTP base class for Python classes representing MLIR Op interfaces.
/// Interface hierarchies are flat so no base class is expected here. The
/// derived class is expected to define the following static fields:
/// - `const char *pyClassName` - the name of the Python class to create;
/// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
/// of the interface.
/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
/// interface-specific methods.
///
/// An interface class may be constructed from either an Operation/OpView object
/// or from a subclass of OpView. In the latter case, only the static interface
/// methods are available, similarly to calling ConcereteOp::staticMethod on the
/// C++ side. Implementations of concrete interfaces can use the `isStatic`
/// method to check whether the interface object was constructed from a class or
/// an operation/opview instance. The `getOpName` always succeeds and returns a
/// canonical name of the operation suitable for lookups.
template <typename ConcreteIface>
class PyConcreteOpInterface {
protected:
using ClassTy = nanobind::class_<ConcreteIface>;
using GetTypeIDFunctionTy = MlirTypeID (*)();
public:
/// Constructs an interface instance from an object that is either an
/// operation or a subclass of OpView. In the latter case, only the static
/// methods of the interface are accessible to the caller.
PyConcreteOpInterface(nanobind::object object,
DefaultingPyMlirContext context)
: obj(std::move(object)) {
if (!nanobind::try_cast<PyOperation *>(obj, operation)) {
PyOpView *opview;
if (nanobind::try_cast<PyOpView *>(obj, opview)) {
operation = &opview->getOperation();
};
}
if (operation != nullptr) {
if (!mlirOperationImplementsInterface(*operation,
ConcreteIface::getInterfaceID())) {
std::string msg = "the operation does not implement ";
throw nanobind::value_error((msg + ConcreteIface::pyClassName).c_str());
}
MlirIdentifier identifier = mlirOperationGetName(*operation);
MlirStringRef stringRef = mlirIdentifierStr(identifier);
opName = std::string(stringRef.data, stringRef.length);
} else {
if (!nanobind::try_cast<std::string>(obj.attr("OPERATION_NAME"), opName))
throw nanobind::type_error(
"Op interface does not refer to an operation or OpView class");
if (!mlirOperationImplementsInterfaceStatic(
mlirStringRefCreate(opName.data(), opName.length()),
context.resolve().get(), ConcreteIface::getInterfaceID())) {
std::string msg = "the operation does not implement ";
throw nanobind::value_error((msg + ConcreteIface::pyClassName).c_str());
}
}
}
/// Creates the Python bindings for this class in the given module.
static void bind(nanobind::module_ &m) {
nanobind::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
cls.def(nanobind::init<nanobind::object, DefaultingPyMlirContext>(),
nanobind::arg("object"),
nanobind::arg("context") = nanobind::none(), constructorDoc)
.def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
operationDoc)
.def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
ConcreteIface::bindDerived(cls);
}
/// Hook for derived classes to add class-specific bindings.
static void bindDerived(ClassTy &cls) {}
/// Returns `true` if this object was constructed from a subclass of OpView
/// rather than from an operation instance.
bool isStatic() { return operation == nullptr; }
/// Returns the operation instance from which this object was constructed.
/// Throws a type error if this object was constructed from a subclass of
/// OpView.
nanobind::typed<nanobind::object, PyOperation> getOperationObject() {
if (operation == nullptr)
throw nanobind::type_error(
"Cannot get an operation from a static interface");
return operation->getRef().releaseObject();
}
/// Returns the opview of the operation instance from which this object was
/// constructed. Throws a type error if this object was constructed form a
/// subclass of OpView.
nanobind::typed<nanobind::object, PyOpView> getOpView() {
if (operation == nullptr)
throw nanobind::type_error(
"Cannot get an opview from a static interface");
return operation->createOpView();
}
/// Returns the canonical name of the operation this interface is constructed
/// from.
const std::string &getOpName() { return opName; }
private:
PyOperation *operation = nullptr;
std::string opName;
nanobind::object obj;
};
struct PyMemoryEffectsInstanceList {
MlirMemoryEffectInstancesList effects;
};
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
#endif // MLIR_BINDINGS_PYTHON_IRINTERFACES_H

View File

@@ -8,15 +8,12 @@
#include "Rewrite.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Globals.h"
#include "mlir/Bindings/Python/IRCore.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
// clang-format on
#include "mlir/Config/mlir-config.h"
#include "nanobind/nanobind.h"
#include <type_traits>
@@ -30,38 +27,12 @@ namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
class PyPatternRewriter {
class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
public:
static constexpr const char *pyClassName = "PatternRewriter";
PyPatternRewriter(MlirPatternRewriter rewriter)
: base(mlirPatternRewriterAsBase(rewriter)),
ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
PyInsertionPoint getInsertionPoint() const {
MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
if (mlirOperationIsNull(op)) {
MlirOperation owner = mlirBlockGetParentOperation(block);
auto parent = PyOperation::forOperation(ctx, owner);
return PyInsertionPoint(PyBlock(parent, block));
}
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
}
void replaceOp(MlirOperation op, MlirOperation newOp) {
mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
}
void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
}
void eraseOp(const PyOperation &op) { mlirRewriterBaseEraseOp(base, op); }
private:
MlirRewriterBase base;
PyMlirContextRef ctx;
: PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
};
class PyConversionPatternRewriter : PyPatternRewriter {
@@ -514,29 +485,8 @@ void populateRewriteSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of the PatternRewriter
//----------------------------------------------------------------------------
nb::class_<PyPatternRewriter>(m, "PatternRewriter")
.def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
"The current insertion point of the PatternRewriter.")
.def(
"replace_op",
[](PyPatternRewriter &self, PyOperationBase &op,
PyOperationBase &newOp) {
self.replaceOp(op.getOperation(), newOp.getOperation());
},
"Replace an operation with a new operation.", nb::arg("op"),
nb::arg("new_op"))
.def(
"replace_op",
[](PyPatternRewriter &self, PyOperationBase &op,
const std::vector<PyValue> &values) {
std::vector<MlirValue> values_(values.size());
std::copy(values.begin(), values.end(), values_.begin());
self.replaceOp(op.getOperation(), values_);
},
"Replace an operation with a list of values.", nb::arg("op"),
nb::arg("values"))
.def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
nb::arg("op"));
PyPatternRewriter::bind(m);
//----------------------------------------------------------------------------
// Mapping of the RewritePatternSet

View File

@@ -9,13 +9,74 @@
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
#define MLIR_BINDINGS_PYTHON_REWRITE_H
#include "mlir/Bindings/Python/NanobindUtils.h"
#include "mlir-c/Rewrite.h"
#include "mlir/Bindings/Python/IRCore.h"
#include <nanobind/nanobind.h>
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
void populateRewriteSubmodule(nanobind::module_ &m);
}
/// CRTP Base class for rewriter wrappers.
template <typename DerivedTy>
class MLIR_PYTHON_API_EXPORTED PyRewriterBase {
public:
PyRewriterBase(MlirRewriterBase rewriter)
: base(rewriter),
ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
PyInsertionPoint getInsertionPoint() const {
MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
if (mlirOperationIsNull(op)) {
MlirOperation owner = mlirBlockGetParentOperation(block);
auto parent = PyOperation::forOperation(ctx, owner);
return PyInsertionPoint(PyBlock(parent, block));
}
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
}
static void bind(nanobind::module_ &m) {
nanobind::class_<DerivedTy>(m, DerivedTy::pyClassName)
.def_prop_ro("ip", &PyRewriterBase::getInsertionPoint,
"The current insertion point of the PatternRewriter.")
.def(
"replace_op",
[](DerivedTy &self, PyOperationBase &op, PyOperationBase &newOp) {
mlirRewriterBaseReplaceOpWithOperation(
self.base, op.getOperation(), newOp.getOperation());
},
"Replace an operation with a new operation.", nanobind::arg("op"),
nanobind::arg("new_op"))
.def(
"replace_op",
[](DerivedTy &self, PyOperationBase &op,
const std::vector<PyValue> &values) {
std::vector<MlirValue> values_(values.size());
std::copy(values.begin(), values.end(), values_.begin());
mlirRewriterBaseReplaceOpWithValues(
self.base, op.getOperation(), values_.size(), values_.data());
},
"Replace an operation with a list of values.", nanobind::arg("op"),
nanobind::arg("values"))
.def(
"erase_op",
[](DerivedTy &self, PyOperationBase &op) {
mlirRewriterBaseEraseOp(self.base, op.getOperation());
},
"Erase an operation.", nanobind::arg("op"));
}
private:
MlirRewriterBase base;
PyMlirContextRef ctx;
};
void MLIR_PYTHON_API_EXPORTED populateRewriteSubmodule(nanobind::module_ &m);
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir

View File

@@ -8,9 +8,14 @@
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/Dialect/Transform.h"
#include "mlir/CAPI/Interfaces.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Rewrite.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -126,3 +131,210 @@ MlirStringRef mlirTransformParamTypeGetName(void) {
MlirType mlirTransformParamTypeGetType(MlirType type) {
return wrap(cast<transform::ParamType>(unwrap(type)).getType());
}
//===---------------------------------------------------------------------===//
// TransformRewriter
//===---------------------------------------------------------------------===//
/// Casts a `MlirTransformRewriter` to a `MlirRewriterBase`.
MlirRewriterBase mlirTransformRewriterAsBase(MlirTransformRewriter rewriter) {
mlir::transform::TransformRewriter *t = unwrap(rewriter);
mlir::RewriterBase *base = static_cast<mlir::RewriterBase *>(t);
return wrap(base);
}
//===---------------------------------------------------------------------===//
// TransformResults
//===---------------------------------------------------------------------===//
void mlirTransformResultsSetOps(MlirTransformResults results, MlirValue result,
intptr_t numOps, MlirOperation *ops) {
SmallVector<Operation *> opsVec;
opsVec.reserve(numOps);
for (intptr_t i = 0; i < numOps; ++i)
opsVec.push_back(unwrap(ops[i]));
unwrap(results)->set(cast<OpResult>(unwrap(result)), opsVec);
}
void mlirTransformResultsSetValues(MlirTransformResults results,
MlirValue result, intptr_t numValues,
MlirValue *values) {
SmallVector<Value> valuesVec;
valuesVec.reserve(numValues);
for (intptr_t i = 0; i < numValues; ++i)
valuesVec.push_back(unwrap(values[i]));
unwrap(results)->setValues(cast<OpResult>(unwrap(result)), valuesVec);
}
void mlirTransformResultsSetParams(MlirTransformResults results,
MlirValue result, intptr_t numParams,
MlirAttribute *params) {
SmallVector<Attribute> paramsVec;
paramsVec.reserve(numParams);
for (intptr_t i = 0; i < numParams; ++i)
paramsVec.push_back(unwrap(params[i]));
unwrap(results)->setParams(cast<OpResult>(unwrap(result)), paramsVec);
}
//===---------------------------------------------------------------------===//
// TransformState
//===---------------------------------------------------------------------===//
void mlirTransformStateForEachPayloadOp(MlirTransformState state,
MlirValue value,
MlirOperationCallback callback,
void *userData) {
for (Operation *op : unwrap(state)->getPayloadOps(unwrap(value)))
callback(wrap(op), userData);
}
void mlirTransformStateForEachPayloadValue(MlirTransformState state,
MlirValue value,
MlirValueCallback callback,
void *userData) {
for (Value val : unwrap(state)->getPayloadValues(unwrap(value)))
callback(wrap(val), userData);
}
void mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
MlirAttributeCallback callback,
void *userData) {
for (Attribute attr : unwrap(state)->getParams(unwrap(value)))
callback(wrap(attr), userData);
}
//===---------------------------------------------------------------------===//
// TransformOpInterface
//===---------------------------------------------------------------------===//
MlirTypeID mlirTransformOpInterfaceTypeID(void) {
return wrap(transform::TransformOpInterface::getInterfaceID());
}
/// Fallback model for the TransformOpInterface that uses C API callbacks.
class TransformOpInterfaceFallbackModel
: public mlir::transform::TransformOpInterface::FallbackModel<
TransformOpInterfaceFallbackModel> {
public:
/// Sets the callbacks that this FallbackModel will use.
/// NB: the callbacks can only be set through this method as the
/// RegisteredOperationName::attachInterface mechanism default-constructs
/// the FallbackModel without being able to provide arguments.
void setCallbacks(MlirTransformOpInterfaceCallbacks callbacks) {
this->callbacks = callbacks;
}
~TransformOpInterfaceFallbackModel() {
if (callbacks.destruct)
callbacks.destruct(callbacks.userData);
}
static TypeID getInterfaceID() {
return transform::TransformOpInterface::getInterfaceID();
}
static bool classof(const mlir::transform::detail::
TransformOpInterfaceInterfaceTraits::Concept *op) {
// Enable casting back to the FallbackModel from the Interface. This is
// necessary as attachInterface(...) default-constructs the FallbackModel
// without being able to pass in the callbacks and returns just the Concept.
return true;
}
::mlir::DiagnosedSilenceableFailure
apply(Operation *op, ::mlir::transform::TransformRewriter &rewriter,
::mlir::transform::TransformResults &transformResults,
::mlir::transform::TransformState &state) const {
assert(callbacks.apply && "apply callback not set");
MlirDiagnosedSilenceableFailure status =
callbacks.apply(wrap(op), wrap(&rewriter), wrap(&transformResults),
wrap(&state), callbacks.userData);
switch (status) {
case MlirDiagnosedSilenceableFailureSuccess:
return DiagnosedSilenceableFailure::success();
case MlirDiagnosedSilenceableFailureSilenceableFailure:
// TODO: enable passing diagnostic info from C API to C++ API.
return DiagnosedSilenceableFailure::silenceableFailure(std::move(
*(op->emitError()
<< "TransformOpInterfaceFallbackModel: silenceable failure")
.getUnderlyingDiagnostic()));
case MlirDiagnosedSilenceableFailureDefiniteFailure:
return DiagnosedSilenceableFailure::definiteFailure();
}
llvm_unreachable("unknown transform status");
}
bool allowsRepeatedHandleOperands(Operation *op) const {
assert(callbacks.allowsRepeatedHandleOperands &&
"allowsRepeatedHandleOperands callback not set");
return callbacks.allowsRepeatedHandleOperands(wrap(op), callbacks.userData);
}
private:
MlirTransformOpInterfaceCallbacks callbacks;
};
/// Attach a TransformOpInterface FallbackModel to the given named operation.
/// The FallbackModel uses the provided callbacks to implement the interface.
void mlirTransformOpInterfaceAttachFallbackModel(
MlirContext ctx, MlirStringRef opName,
MlirTransformOpInterfaceCallbacks callbacks) {
// Look up the operation definition in the context.
std::optional<RegisteredOperationName> opInfo =
RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
assert(opInfo.has_value() && "operation not found in context");
// NB: the following default-constructs the FallbackModel _without_ being able
// to provide arguments.
opInfo->attachInterface<TransformOpInterfaceFallbackModel>();
// Cast to get the underlying FallbackModel and set the callbacks.
auto *model = cast<TransformOpInterfaceFallbackModel>(
opInfo->getInterface<TransformOpInterfaceFallbackModel>());
assert(model && "Failed to get TransformOpInterfaceFallbackModel");
model->setCallbacks(callbacks);
}
//===---------------------------------------------------------------------===//
// MemoryEffectsOpInterface helpers
//===---------------------------------------------------------------------===//
/// Set the effect for the operands to only read the transform handles.
void mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
MlirMemoryEffectInstancesList effects) {
MutableArrayRef<OpOperand> operandArray(unwrap(*operands), numOperands);
transform::onlyReadsHandle(operandArray, *unwrap(effects));
}
/// Set the effect for the operands to consuming the transform handles.
void mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands,
MlirMemoryEffectInstancesList effects) {
MutableArrayRef<OpOperand> operandArray(unwrap(*operands), numOperands);
transform::consumesHandle(operandArray, *unwrap(effects));
}
/// Set the effect for the results to that they produce transform handles.
void mlirTransformProducesHandle(MlirValue *results, intptr_t numResults,
MlirMemoryEffectInstancesList effects) {
// NB: calling `producesHandle()` `numResults` as we cannot cast array of
// `OpResult`s to a single `ResultRange` (and neither is `ResultRange` exposed
// to Python). `producesHandle` iterates over the given `ResultRange` anyway.
SmallVectorImpl<MemoryEffects::EffectInstance> &effectList = *unwrap(effects);
for (intptr_t i = 0; i < numResults; ++i) {
auto opResult = cast<OpResult>(unwrap(results[i]));
transform::producesHandle(ResultRange(opResult), effectList);
}
}
/// Set the effect of potentially modifying payload IR.
void mlirTransformModifiesPayload(MlirMemoryEffectInstancesList effects) {
transform::modifiesPayload(*unwrap(effects));
}
/// Set the effect of potentially reading payload IR.
void mlirTransformOnlyReadsPayload(MlirMemoryEffectInstancesList effects) {
transform::onlyReadsPayload(*unwrap(effects));
}

View File

@@ -30,7 +30,6 @@
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/ThreadPool.h"
#include <cstddef>
#include <memory>
@@ -714,6 +713,10 @@ MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
}
MlirOpOperand mlirOperationGetOpOperand(MlirOperation op, intptr_t pos) {
return wrap(&unwrap(op)->getOpOperand(static_cast<unsigned>(pos)));
}
void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
MlirValue newValue) {
unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue));

View File

@@ -167,3 +167,73 @@ MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(
}
return mlirLogicalResultSuccess();
}
//===---------------------------------------------------------------------===//
// MemoryEffectOpInterface
//===---------------------------------------------------------------------===//
MlirTypeID mlirMemoryEffectsOpInterfaceTypeID() {
return wrap(MemoryEffectOpInterface::getInterfaceID());
}
/// Fallback model for the MemoryEffectsOpInterface that uses C API callbacks.
class MemoryEffectOpInterfaceFallbackModel
: public mlir::MemoryEffectOpInterface::FallbackModel<
MemoryEffectOpInterfaceFallbackModel> {
public:
/// Sets the callbacks that this FallbackModel will use.
/// NB: the callbacks can only be set through this method as the
/// RegisteredOperationName::attachInterface mechanism default-constructs
/// the FallbackModel without being able to provide arguments.
void setCallbacks(MlirMemoryEffectsOpInterfaceCallbacks callbacks) {
this->callbacks = callbacks;
}
~MemoryEffectOpInterfaceFallbackModel() {
if (callbacks.destruct)
callbacks.destruct(callbacks.userData);
}
static TypeID getInterfaceID() {
return MemoryEffectOpInterface::getInterfaceID();
}
static bool classof(const mlir::MemoryEffectOpInterface::Concept *op) {
// Enable casting back to the FallbackModel from the Interface. This is
// necessary as attachInterface(...) default-constructs the FallbackModel
// without being able to pass in the callbacks and returns just the Concept.
return true;
}
void
getEffects(Operation *op,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) const {
assert(callbacks.getEffects && "getEffects callback not set");
MlirMemoryEffectInstancesList cEffects = wrap(&effects);
callbacks.getEffects(wrap(op), cEffects, callbacks.userData);
}
private:
MlirMemoryEffectsOpInterfaceCallbacks callbacks;
};
/// Attach a MemoryEffectsOpInterface FallbackModel to the given named op.
/// The FallbackModel uses the provided callbacks to implement the interface.
void mlirMemoryEffectsOpInterfaceAttachFallbackModel(
MlirContext ctx, MlirStringRef opName,
MlirMemoryEffectsOpInterfaceCallbacks callbacks) {
// Look up the operation definition in the context
std::optional<RegisteredOperationName> opInfo =
RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
assert(opInfo.has_value() && "operation not found in context");
// NB: the following default-constructs the FallbackModel _without_ being able
// to provide arguments.
opInfo->attachInterface<MemoryEffectOpInterfaceFallbackModel>();
// Cast to get the underlying FallbackModel and set the callbacks.
auto *model = cast<MemoryEffectOpInterfaceFallbackModel>(
opInfo->getInterface<MemoryEffectOpInterfaceFallbackModel>());
assert(model && "Failed to get MemoryEffectOpInterfaceFallbackModel");
model->setCallbacks(callbacks);
}

View File

@@ -687,8 +687,10 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Nanobind
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
DialectTransform.cpp
Rewrite.h
PRIVATE_LINK_LIBS
LLVMSupport
MLIRPythonExtension.Core
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPITransformDialect

View File

@@ -242,6 +242,7 @@ def _site_initialize():
Sequence.register(ir.BlockPredecessors)
Sequence.register(ir.OperationList)
Sequence.register(ir.OpOperandList)
Sequence.register(ir.OpOperands)
Sequence.register(ir.OpResultList)
Sequence.register(ir.OpSuccessors)
Sequence.register(ir.RegionSequence)

View File

@@ -29,6 +29,8 @@ __all__ = [
"Dialect",
"Operand",
"Result",
"register_dialect",
"register_operation",
"Region",
"Operation",
]
@@ -36,6 +38,8 @@ __all__ = [
Operand = ir.Value
Result = ir.OpResult
Region = ir.Region
register_dialect = _cext.register_dialect
register_operation = _cext.register_operation
class ConstraintLoweringContext:
@@ -203,6 +207,12 @@ class Operation(ir.OpView):
Use `Dialect` and `.Operation` of `Dialect` subclasses instead.
"""
def __init__(*args, **kwargs):
raise TypeError(
"This class is a template and cannot be instantiated directly. "
"Please use a subclass that defines the operation."
)
@classmethod
def __init_subclass__(
cls, *, name: str | None = None, traits: list[type] | None = None, **kwargs
@@ -507,22 +517,21 @@ class Dialect(ir.Dialect):
return m
@classmethod
def load(cls) -> None:
if hasattr(cls, "_mlir_module"):
raise RuntimeError(f"Dialect {cls.name} is already loaded.")
mlir_module = cls._emit_module()
def load(cls, register=True, reload=False) -> None:
if hasattr(cls, "_mlir_module") and not reload:
return
cls._mlir_module = cls._emit_module()
pm = PassManager()
pm.add("canonicalize, cse")
pm.run(mlir_module.operation)
pm.run(cls._mlir_module.operation)
irdl.load_dialects(mlir_module)
irdl.load_dialects(cls._mlir_module)
_cext.register_dialect(cls)
if register:
register_dialect(cls)
for op in cls.operations:
op._attach_traits()
_cext.register_operation(cls)(op)
cls._mlir_module = mlir_module
register_dialect_operation = register_operation(cls)
for op in cls.operations:
op._attach_traits()
register_dialect_operation(op)

View File

@@ -0,0 +1,500 @@
# RUN: env PYTHONUNBUFFERED=1 %PYTHON %s 2>&1 | FileCheck %s
from typing import Sequence
from contextlib import contextmanager
from mlir import ir
from mlir.dialects import index, transform, func, arith, ext
from mlir.dialects.transform import (
DiagnosedSilenceableFailure,
AnyOpType,
AnyValueType,
AnyParamType,
structured,
interpreter,
)
@ext.register_dialect
class MyTransform(ext.Dialect, name="my_transform"):
pass
def run(emit_schedule):
print(f"Test: {emit_schedule.__name__}")
with ir.Context() as ctx, ir.Location.unknown():
payload = emit_payload()
MyTransform.load(register=False, reload=True)
GetNamedAttributeOp.attach_interface_impls(ctx)
PrintParamOp.attach_interface_impls(ctx)
# NB: Other newly defined my_transform ops have their interfaces attached
# in their respective test functions.
schedule = emit_schedule()
interpreter.apply_named_sequence(
payload,
_named_seq := schedule.operation.regions[0].blocks[0].operations[0],
schedule,
)
# Payload used by all tests
def emit_payload():
payload_module = ir.Module.create()
with ir.InsertionPoint(payload_module.body):
f32 = ir.F32Type.get()
@func.FuncOp.from_py_func(f32, f32, results=[f32])
def name_of_func(a, b):
c = arith.addf(a, b)
i32 = ir.IntegerType.get_signless(32)
arith.constant(i32, 42)
arith.constant(i32, 24)
func.ReturnOp([c])
return payload_module
@contextmanager
def schedule_boilerplate():
schedule = ir.Module.create()
schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
with ir.InsertionPoint(schedule.body):
named_sequence = transform.NamedSequenceOp(
"__transform_main",
[AnyOpType.get()],
[AnyOpType.get()],
arg_attrs=[{"transform.consumed": ir.UnitAttr.get()}],
)
with ir.InsertionPoint(named_sequence.body):
yield schedule, named_sequence
# MemoryEffectsOpInterface implementation for TransformOpInterface-implementing ops.
# Used by most ops defined below.
class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
@staticmethod
def get_effects(op: ir.Operation, effects):
transform.only_reads_handle(op.op_operands, effects)
transform.produces_handle(op.results, effects)
transform.only_reads_payload(effects)
# Demonstration of a TransformOpInterface-implementing op that gets named attributes
# from target ops and produces them as param handles.
@ext.register_operation(MyTransform)
class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
target: ext.Operand[transform.AnyOpType]
attr_name: ir.StringAttr
attr_as_param: ext.Result[transform.AnyParamType[()]]
@classmethod
def attach_interface_impls(cls, ctx=None):
cls.TransformOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, context=ctx)
MemoryEffectsOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, context=ctx)
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
def apply(
op: "GetNamedAttributeOp",
_rewriter: transform.TransformRewriter,
results: transform.TransformResults,
state: transform.TransformState,
) -> DiagnosedSilenceableFailure:
target_ops = state.get_payload_ops(op.target)
associated_attrs = []
for target_op in target_ops:
assoc_attr = target_op.attributes.get(op.attr_name.value)
if assoc_attr is None:
return DiagnosedSilenceableFailure.RecoverableFailure
associated_attrs.append(assoc_attr)
results.set_params(op.attr_as_param, associated_attrs)
return DiagnosedSilenceableFailure.Success
@staticmethod
def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
return False
@ext.register_operation(MyTransform)
class PrintParamOp(MyTransform.Operation, name="print_param"):
target: ext.Operand[transform.AnyParamType]
name: ir.StringAttr
@classmethod
def attach_interface_impls(cls, ctx=None):
cls.TransformOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, context=ctx)
MemoryEffectsOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, context=ctx)
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
def apply(
op: "PrintParamOp",
rewriter: transform.TransformRewriter,
results: transform.TransformResults,
state: transform.TransformState,
) -> DiagnosedSilenceableFailure:
target_attrs = state.get_params(op.target)
print(f"[[[ IR printer: {op.name.value} ]]]")
for attr in target_attrs:
print(attr)
return DiagnosedSilenceableFailure.Success
@staticmethod
def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
return False
# Syntax for an op with one op handle operand and one op handle result.
@ext.register_operation(MyTransform)
class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
target: ext.Operand[transform.AnyOpType]
res: ext.Result[transform.AnyOpType[()]]
# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
@run
def OneOpInOneOpOutTransformOpInterface():
"""Tests a simple passthrough interface implementation.
Checks that the target ops are correctly identified and passed as results.
"""
# Define a simple passthrough implementation of the TransformOpInterface for OneOpInOneOpOut.
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
def apply(
op: OneOpInOneOpOut,
_rewriter: transform.TransformRewriter,
results: transform.TransformResults,
state: transform.TransformState,
) -> DiagnosedSilenceableFailure:
target_ops = state.get_payload_ops(op.target)
target_names = [t.name.value for t in target_ops]
print(f"OneOpInOneOpOutTransformOpInterface: target_names={target_names}")
results.set_ops(op.res, target_ops)
return DiagnosedSilenceableFailure.Success
@staticmethod
def allow_repeated_handle_operands(_op: OneOpInOneOpOut) -> bool:
return False
# Attach the interface implementation to the op.
TransformOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
# TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface. The above defined fallback model works for this op.
MemoryEffectsOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
with schedule_boilerplate() as (schedule, named_seq):
func_handle = structured.MatchOp.match_op_names(
named_seq.bodyTarget, ["func.func"]
).result
# CHECK: OneOpInOneOpOutTransformOpInterface: target_names=['name_of_func']
out = OneOpInOneOpOut(func_handle).result
# CHECK: Output handle from OneOpInOneOpOut
# CHECK-NEXT: func.func @name_of_func
transform.PrintOp(target=out, name="Output handle from OneOpInOneOpOut")
transform.YieldOp([out])
return schedule
# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterfaceRewriterImpl
@run
def OneOpInOneOpOutTransformOpInterfaceRewriterImpl():
"""Tests an interface implementation using the rewriter to modify the IR.
Checks that `arith.constant` ops are replaced by `index.constant` ops and
that the results are correctly updated.
"""
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
def apply(
op: OneOpInOneOpOut,
rewriter: transform.TransformRewriter,
results: transform.TransformResults,
state: transform.TransformState,
) -> DiagnosedSilenceableFailure:
result_ops = []
for target_op in state.get_payload_ops(op.target):
with ir.InsertionPoint(target_op):
index_version = index.constant(target_op.value.value)
result_ops.append(index_version.owner)
rewriter.replace_op(target_op, [index_version])
results.set_ops(op.res, result_ops)
return DiagnosedSilenceableFailure.Success
@staticmethod
def allow_repeated_handle_operands(_op: OneOpInOneOpOut) -> bool:
return False
# Attach the interface implementation to the op.
TransformOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
# TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface. The above defined fallback model works for this op.
class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
@staticmethod
def get_effects(op: ir.Operation, effects):
transform.consumes_handle(op.op_operands, effects)
transform.produces_handle(op.results, effects)
transform.modifies_payload(effects)
MemoryEffectsOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
with schedule_boilerplate() as (schedule, named_seq):
func_handle = structured.MatchOp.match_op_names(
named_seq.bodyTarget, ["func.func"]
).result
csts_handle = structured.MatchOp.match_op_names(
named_seq.bodyTarget, ["arith.constant"]
).result
# CHECK: Before replacement:
# CHECK-NOT: index.constant
# CHECK-DAG: arith.constant 42 : i32
# CHECK-DAG: arith.constant 24 : i32
transform.PrintOp(target=func_handle, name="Before replacement:")
out = OneOpInOneOpOut(csts_handle).result
# CHECK: After replacement:
# CHECK-NOT: arith.constant
# CHECK-DAG: index.constant 42
# CHECK-DAG: index.constant 24
transform.PrintOp(target=func_handle, name="After replacement:")
# CHECK: Output handle from OneOpInOneOpOut:
# CHECK-NEXT: index.constant 42
# CHECK-NEXT: index.constant 24
transform.PrintOp(target=out, name="Output handle from OneOpInOneOpOut:")
transform.YieldOp([out])
return schedule
@ext.register_operation(MyTransform)
class OpValParamInParamOpValOut(
MyTransform.Operation, name="op_val_param_in_param_op_val_out"
):
# operands
op_arg: ext.Operand[transform.AnyOpType]
val_arg: ext.Operand[transform.AnyValueType]
param_arg: ext.Operand[transform.AnyParamType]
# results
param_res: ext.Result[transform.AnyParamType[()]]
op_res: ext.Result[transform.AnyOpType[()]]
value_res: ext.Result[transform.AnyValueType[()]]
# CHECK-LABEL: Test: OpValParamInParamOpValOutTransformOpInterface
@run
def OpValParamInParamOpValOutTransformOpInterface():
"""Tests an interface implementation involving Op, Value, and Param types.
Checks that payload ops, values, and parameters are correctly permuted and
propagated and accessible from the (permuted) result handles.
"""
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
def apply(
op: OpValParamInParamOpValOut,
_rewriter: transform.TransformRewriter,
results: transform.TransformResults,
state: transform.TransformState,
) -> DiagnosedSilenceableFailure:
ops = state.get_payload_ops(op.op_arg)
values = state.get_payload_values(op.val_arg)
params = state.get_params(op.param_arg)
print(
f"OpValParamInParamOpValOutTransformOpInterface: ops={len(ops)}, values={len(values)}, params={len(params)}"
)
results.set_params(op.param_res, params)
results.set_ops(op.op_res, ops)
results.set_values(op.value_res, values)
return DiagnosedSilenceableFailure.Success
@staticmethod
def allow_repeated_handle_operands(_op: OpValParamInParamOpValOut) -> bool:
return False
TransformOpInterfaceFallbackModel.attach(OpValParamInParamOpValOut.OPERATION_NAME)
# TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface. The above defined fallback model works for this op.
MemoryEffectsOpInterfaceFallbackModel.attach(
OpValParamInParamOpValOut.OPERATION_NAME
)
with schedule_boilerplate() as (schedule, named_seq):
func_handle = structured.MatchOp.match_op_names(
named_seq.bodyTarget, ["func.func"]
).result
addf_handle = structured.MatchOp.match_op_names(
named_seq.bodyTarget, ["arith.addf"]
).result
func_and_addf = transform.MergeHandlesOp([func_handle, addf_handle])
value_handle = transform.GetResultOp(
AnyValueType.get(), addf_handle, [0]
).result
param_handle = transform.ParamConstantOp(
AnyParamType.get(), ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
).param
# CHECK: OpValParamInParamOpValOutTransformOpInterface: ops=2, values=1, params=1
op_val_param_op = OpValParamInParamOpValOut(
func_and_addf, value_handle, param_handle
)
# CHECK: Ops passed through OpValParamInParamOpValOut:
# CHECK-NEXT: func.func
# CHECK: arith.addf
transform.PrintOp(
target=op_val_param_op.op_res,
name="Ops passed through OpValParamInParamOpValOut:",
)
# CHECK: Ops defining values passed through OpValParamInParamOpValOut:
# CHECK-NEXT: arith.addf
addf_as_res = transform.GetDefiningOp(
transform.AnyOpType.get(), op_val_param_op.value_res
).result
transform.PrintOp(
target=addf_as_res,
name="Ops defining values passed through OpValParamInParamOpValOut:",
)
# CHECK: Parameter passed through OpValParamInParamOpValOut:
# CHECK-NEXT: 42 : i32
PrintParamOp(
op_val_param_op.param_res,
name=ir.StringAttr.get(
"Parameter passed through OpValParamInParamOpValOut:"
),
)
transform.YieldOp([op_val_param_op.op_res])
named_seq.verify()
return schedule
@ext.register_operation(MyTransform)
class OpsParamsInValuesParamOut(
MyTransform.Operation, name="ops_params_in_values_param_out"
):
# operands
ops: Sequence[ext.Operand[transform.AnyOpType]]
params: Sequence[ext.Operand[transform.AnyParamType]]
# results
values: Sequence[ext.Result[transform.AnyValueType]]
param: ext.Result[transform.AnyParamType]
# CHECK-LABEL: Test: OpsParamsInValuesParamOutTransformOpInterface
@run
def OpsParamsInValuesParamOutTransformOpInterface():
"""Tests an interface with variadic Op and Param operands and variadic Value results.
Checks correct handling of multiple handles, parameter aggregation, and
result generation.
"""
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
@staticmethod
def apply(
op: OpsParamsInValuesParamOut,
_rewriter: transform.TransformRewriter,
results: transform.TransformResults,
state: transform.TransformState,
) -> DiagnosedSilenceableFailure:
ops_count = 0
value_handles = []
for op_handle in op.ops:
ops = state.get_payload_ops(op_handle)
ops_count += len(ops)
value_handles.append([i for op in ops for i in op.results])
param_count = 0
param_sum = 0
for param_handle in op.params:
params = state.get_params(param_handle)
param_count += len(params)
param_sum += sum(p.value for p in params)
print(
f"OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel: op_count={ops_count}, param_count={param_count}"
)
assert len(op.values) == len(op.ops)
for value_res_handle, value_vector in zip(op.values, value_handles):
results.set_values(value_res_handle, value_vector)
results.set_params(
op.param,
[ir.IntegerAttr.get(ir.IntegerType.get_signless(32), param_sum)],
)
return DiagnosedSilenceableFailure.Success
@staticmethod
def allow_repeated_handle_operands(_op: OpsParamsInValuesParamOut) -> bool:
return False
TransformOpInterfaceFallbackModel.attach(OpsParamsInValuesParamOut.OPERATION_NAME)
MemoryEffectsOpInterfaceFallbackModel.attach(
OpsParamsInValuesParamOut.OPERATION_NAME
)
with schedule_boilerplate() as (schedule, named_seq):
func_handle = structured.MatchOp.match_op_names(
named_seq.bodyTarget, ["func.func"]
).result
csts_handle = structured.MatchOp.match_op_names(
named_seq.bodyTarget, ["arith.constant"]
).result
csts_as_param = GetNamedAttributeOp(
csts_handle, attr_name=ir.StringAttr.get("value")
).attr_as_param
param_handle = transform.ParamConstantOp(
AnyParamType.get(), ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 123)
).param
# CHECK: OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel: op_count=3, param_count=3
op = OpsParamsInValuesParamOut(
[transform.AnyValueType.get()] * 2,
transform.AnyParamType.get(),
[func_handle, csts_handle],
[csts_as_param, param_handle],
)
empty_handle = transform.GetDefiningOp(transform.AnyOpType.get(), op.values[0])
# CHECK: Defining op of value result 0
transform.PrintOp(
target=empty_handle.result, name="Defining op of value result 0"
)
# NB: no result on the func.func, so output is expected to be empty
cst1_res, cst2_res = transform.SplitHandleOp(
[transform.AnyValueType.get()] * 2, op.values[1]
).results
cst1_again = transform.GetDefiningOp(transform.AnyOpType.get(), cst1_res)
# CHECK-NEXT: Defining op of first constant
# CHECK-NEXT: arith.constant 42 : i32
transform.PrintOp(
target=cst1_again.result, name="Defining op of first constant"
)
cst2_again = transform.GetDefiningOp(transform.AnyOpType.get(), cst2_res)
# CHECK-NEXT: Defining op of second constant
# CHECK-NEXT: arith.constant 24 : i32
transform.PrintOp(
target=cst2_again.result, name="Defining op of second constant"
)
# CHECK: Sum of params:
# CHECK-NEXT: 189 : i32
PrintParamOp(op.param, name=ir.StringAttr.get("Sum of params:"))
transform.YieldOp([func_handle])
named_seq.verify()
return schedule