[MLIR][Python] Impl XOpInterface(s) from Python, with X=Transform and X=MemoryEffects (#176920)
Provides the infrastructure for implementing and late-binding
OpInterfaces from Python.
* On the mlir-c API declaration side, each `XOpInterface` has a callback
struct, with a callback for each method and a userdata member (provided
as an arg to each method), and a
`mlirXOpInterfaceAttachFallbackModel(ctx, op_name, callbacks)` func.
* This CAPI is implemented by defining a subclass of
`XOpInterface::FallbackModel` that holds the callback struct and has
each method call the corresponding callback (with userdata as an arg).
Given a callback struct, a new `FallbackModel` is created and attached,
i.e. late bound, to the named op. (MLIR's interface infrastructure is
such that the thus registered `FallbackModel` will be returned in case
the op gets cast to the `XOpInterface`.)
* On the Python side, we expose a stand-in `XOpInterface` base class
which has one (class)method: `XOpInterface.attach(cls, op_name, ctx)`.
Python users subclass this class (`class MyInterfaceImpl(XOpInterface):
...`) and implement the interface's methods (with the right names and
signatures). The user calls `attach` on the subclass
(`MyInterfaceImpl.attach("my_dialect.my_op", ctx)`) which prepares the
callbacks struct _with userdata set to the subclass_ (as we use it to
lookup methods). These callbacks (and userdata) are then registered as
an `XOpInterface::FallbackModel` by
`mlirXOpInterfaceAttachFallbackModel(...)`. From then on the Python
methods will be used to respond to calls to the interface methods
(originating in C++).
This PR enables implementing the TransformOpInterface and the
MemoryEffectsOpInterface, both of which are required for making an op
into a transform op.
Everything besides the above linked code is there to facilitate exposing
the interfaces: the right types for the arguments of the methods are
exposed as are functions/methods for manipulating these arguments (e.g.
specifying side effects on `OpOperand`s and `OpResult`s and being able
to access and set the transform handles associated with args and
results).
This commit is contained in:
@@ -11,6 +11,8 @@
|
||||
#define MLIR_C_DIALECT_TRANSFORM_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Interfaces.h"
|
||||
#include "mlir-c/Rewrite.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
@@ -19,6 +21,32 @@ extern "C" {
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform);
|
||||
|
||||
#define DEFINE_C_API_STRUCT(name, storage) \
|
||||
struct name { \
|
||||
storage *ptr; \
|
||||
}; \
|
||||
typedef struct name name
|
||||
|
||||
DEFINE_C_API_STRUCT(MlirTransformResults, void);
|
||||
DEFINE_C_API_STRUCT(MlirTransformRewriter, void);
|
||||
DEFINE_C_API_STRUCT(MlirTransformState, void);
|
||||
|
||||
#undef DEFINE_C_API_STRUCT
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// DiagnosedSilenceableFailure
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
/// Enum representing the result of a transform operation.
|
||||
typedef enum {
|
||||
/// The operation succeeded.
|
||||
MlirDiagnosedSilenceableFailureSuccess,
|
||||
/// The operation failed in a silenceable way.
|
||||
MlirDiagnosedSilenceableFailureSilenceableFailure,
|
||||
/// The operation failed definitively.
|
||||
MlirDiagnosedSilenceableFailureDefiniteFailure
|
||||
} MlirDiagnosedSilenceableFailure;
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// AnyOpType
|
||||
//===---------------------------------------------------------------------===//
|
||||
@@ -86,6 +114,126 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirTransformParamTypeGetName(void);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type);
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// TransformRewriter
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
/// Cast the TransformRewriter to a RewriterBase
|
||||
MLIR_CAPI_EXPORTED MlirRewriterBase
|
||||
mlirTransformRewriterAsBase(MlirTransformRewriter rewriter);
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// TransformResults
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
/// Set the payload operations for a transform result by iterating over a list.
|
||||
MLIR_CAPI_EXPORTED void mlirTransformResultsSetOps(MlirTransformResults results,
|
||||
MlirValue result,
|
||||
intptr_t numOps,
|
||||
MlirOperation *ops);
|
||||
|
||||
/// Set the payload values for a transform result by iterating over a list.
|
||||
MLIR_CAPI_EXPORTED void
|
||||
mlirTransformResultsSetValues(MlirTransformResults results, MlirValue result,
|
||||
intptr_t numValues, MlirValue *values);
|
||||
|
||||
/// Set the parameters for a transform result by iterating over a list.
|
||||
MLIR_CAPI_EXPORTED void
|
||||
mlirTransformResultsSetParams(MlirTransformResults results, MlirValue result,
|
||||
intptr_t numParams, MlirAttribute *params);
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// TransformState
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
/// Callback for iterating over payload operations.
|
||||
typedef void (*MlirOperationCallback)(MlirOperation, void *userData);
|
||||
|
||||
/// Iterate over payload operations associated with the transform IR value.
|
||||
/// Calls the callback for each payload operation.
|
||||
MLIR_CAPI_EXPORTED void
|
||||
mlirTransformStateForEachPayloadOp(MlirTransformState state, MlirValue value,
|
||||
MlirOperationCallback callback,
|
||||
void *userData);
|
||||
|
||||
/// Callback for iterating over payload values.
|
||||
typedef void (*MlirValueCallback)(MlirValue, void *userData);
|
||||
|
||||
/// Iterate over payload values associated with the transform IR value.
|
||||
/// Calls the callback for each payload value.
|
||||
MLIR_CAPI_EXPORTED void
|
||||
mlirTransformStateForEachPayloadValue(MlirTransformState state, MlirValue value,
|
||||
MlirValueCallback callback,
|
||||
void *userData);
|
||||
|
||||
/// Callback for iterating over parameters.
|
||||
typedef void (*MlirAttributeCallback)(MlirAttribute, void *userData);
|
||||
|
||||
/// Iterate over parameters associated with the transform IR value.
|
||||
/// Calls the callback for each parameter.
|
||||
MLIR_CAPI_EXPORTED void
|
||||
mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
|
||||
MlirAttributeCallback callback, void *userData);
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// TransformOpInterface
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the interface TypeID of the TransformOpInterface.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOpInterfaceTypeID(void);
|
||||
|
||||
/// Callbacks for implementing TransformOpInterface from external code.
|
||||
typedef struct {
|
||||
/// Optional constructor for the user data.
|
||||
/// Set to nullptr to disable it.
|
||||
void (*construct)(void *userData);
|
||||
/// Optional destructor for the user data.
|
||||
/// Set to nullptr to disable it.
|
||||
void (*destruct)(void *userData);
|
||||
/// Apply callback that implements the transformation.
|
||||
MlirDiagnosedSilenceableFailure (*apply)(MlirOperation op,
|
||||
MlirTransformRewriter rewriter,
|
||||
MlirTransformResults results,
|
||||
MlirTransformState state,
|
||||
void *userData);
|
||||
/// Callback to check if repeated handle operands are allowed.
|
||||
bool (*allowsRepeatedHandleOperands)(MlirOperation op, void *userData);
|
||||
void *userData;
|
||||
} MlirTransformOpInterfaceCallbacks;
|
||||
|
||||
/// Attach TransformOpInterface to the operation with the given name using
|
||||
/// the provided callbacks.
|
||||
MLIR_CAPI_EXPORTED void mlirTransformOpInterfaceAttachFallbackModel(
|
||||
MlirContext ctx, MlirStringRef opName,
|
||||
MlirTransformOpInterfaceCallbacks callbacks);
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// Transform-specifc MemoryEffectsOpInterface helpers
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
/// Helper to mark operands as only reading handles.
|
||||
MLIR_CAPI_EXPORTED void
|
||||
mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
|
||||
MlirMemoryEffectInstancesList effects);
|
||||
|
||||
/// Helper to mark operands as consuming handles.
|
||||
MLIR_CAPI_EXPORTED void
|
||||
mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands,
|
||||
MlirMemoryEffectInstancesList effects);
|
||||
|
||||
/// Helper to mark results as producing handles.
|
||||
MLIR_CAPI_EXPORTED void
|
||||
mlirTransformProducesHandle(MlirValue *results, intptr_t numResults,
|
||||
MlirMemoryEffectInstancesList effects);
|
||||
|
||||
/// Helper to mark potential modifications to the payload IR.
|
||||
MLIR_CAPI_EXPORTED void
|
||||
mlirTransformModifiesPayload(MlirMemoryEffectInstancesList effects);
|
||||
|
||||
/// Helper to mark potential reads from the payload IR.
|
||||
MLIR_CAPI_EXPORTED void
|
||||
mlirTransformOnlyReadsPayload(MlirMemoryEffectInstancesList effects);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -673,6 +673,10 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op);
|
||||
MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op,
|
||||
intptr_t pos);
|
||||
|
||||
/// Returns `pos`-th OpOperand of the operation.
|
||||
MLIR_CAPI_EXPORTED MlirOpOperand mlirOperationGetOpOperand(MlirOperation op,
|
||||
intptr_t pos);
|
||||
|
||||
/// Sets the `pos`-th operand of the operation.
|
||||
MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
|
||||
MlirValue newValue);
|
||||
|
||||
@@ -22,6 +22,16 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define DEFINE_C_API_STRUCT(name, storage) \
|
||||
struct name { \
|
||||
storage *ptr; \
|
||||
}; \
|
||||
typedef struct name name
|
||||
|
||||
DEFINE_C_API_STRUCT(MlirMemoryEffectInstancesList, void);
|
||||
|
||||
#undef DEFINE_C_API_STRUCT
|
||||
|
||||
/// Returns `true` if the given operation implements an interface identified by
|
||||
/// its TypeID.
|
||||
MLIR_CAPI_EXPORTED bool
|
||||
@@ -42,7 +52,7 @@ mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the interface TypeID of the InferTypeOpInterface.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(void);
|
||||
|
||||
/// These callbacks are used to return multiple types from functions while
|
||||
/// transferring ownership to the caller. The first argument is the number of
|
||||
@@ -65,7 +75,7 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the interface TypeID of the InferShapedTypeOpInterface.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID();
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID(void);
|
||||
|
||||
/// These callbacks are used to return multiple shaped type components from
|
||||
/// functions while transferring ownership to the caller. The first argument is
|
||||
@@ -87,6 +97,31 @@ mlirInferShapedTypeOpInterfaceInferReturnTypes(
|
||||
void *properties, intptr_t nRegions, MlirRegion *regions,
|
||||
MlirShapedTypeComponentsCallback callback, void *userData);
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// MemoryEffectsOpInterface
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the interface TypeID of the MemoryEffectsOpInterface.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirMemoryEffectsOpInterfaceTypeID(void);
|
||||
|
||||
/// Callbacks for implementing MemoryEffectsOpInterface from external code.
|
||||
typedef struct {
|
||||
/// Optional constructor for user data. Set to nullptr to disable it.
|
||||
void (*construct)(void *userData);
|
||||
/// Optional destructor for user data. Set to nullptr to disable it.
|
||||
void (*destruct)(void *userData);
|
||||
/// Get memory effects callback.
|
||||
void (*getEffects)(MlirOperation op, MlirMemoryEffectInstancesList effects,
|
||||
void *userData);
|
||||
void *userData;
|
||||
} MlirMemoryEffectsOpInterfaceCallbacks;
|
||||
|
||||
/// Attach a new FallbackModel for the MemoryEffectsOpInterface to the named
|
||||
/// operation. The FallbackModel will call the provided callbacks.
|
||||
MLIR_CAPI_EXPORTED void mlirMemoryEffectsOpInterfaceAttachFallbackModel(
|
||||
MlirContext ctx, MlirStringRef opName,
|
||||
MlirMemoryEffectsOpInterfaceCallbacks callbacks);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -1492,6 +1492,7 @@ private:
|
||||
class MLIR_PYTHON_API_EXPORTED PyOpOperand {
|
||||
public:
|
||||
PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
|
||||
operator MlirOpOperand() const { return opOperand; }
|
||||
|
||||
nanobind::typed<nanobind::object, PyOpView> getOwner() const;
|
||||
|
||||
@@ -1871,13 +1872,20 @@ public:
|
||||
MLIR_PYTHON_API_EXPORTED MlirValue getUniqueResult(MlirOperation operation);
|
||||
MLIR_PYTHON_API_EXPORTED void populateIRCore(nanobind::module_ &m);
|
||||
MLIR_PYTHON_API_EXPORTED void populateRoot(nanobind::module_ &m);
|
||||
|
||||
/// Helper for creating an @classmethod.
|
||||
template <class Func, typename... Args>
|
||||
inline nanobind::object classmethod(Func f, Args... args) {
|
||||
nanobind::object cf = nanobind::cpp_function(f, args...);
|
||||
return nanobind::borrow<nanobind::object>((PyClassMethod_New(cf.ptr())));
|
||||
}
|
||||
|
||||
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
|
||||
} // namespace python
|
||||
} // namespace mlir
|
||||
|
||||
namespace nanobind {
|
||||
namespace detail {
|
||||
|
||||
template <>
|
||||
struct type_caster<
|
||||
mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext>
|
||||
|
||||
28
mlir/include/mlir/CAPI/Dialect/Transform.h
Normal file
28
mlir/include/mlir/CAPI/Dialect/Transform.h
Normal file
@@ -0,0 +1,28 @@
|
||||
//===- Transform.h - C API Utils for Transform dialect ----------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains declarations of implementation details of the C API for
|
||||
// the Transform dialect. This file should not be included from C++ code other
|
||||
// than C API implementation nor from C code.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CAPI_DIALECT_TRANSFORM_H
|
||||
#define MLIR_CAPI_DIALECT_TRANSFORM_H
|
||||
|
||||
#include "mlir-c/Dialect/Transform.h"
|
||||
#include "mlir/CAPI/Wrap.h"
|
||||
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
||||
|
||||
DEFINE_C_API_PTR_METHODS(MlirTransformRewriter,
|
||||
mlir::transform::TransformRewriter)
|
||||
DEFINE_C_API_PTR_METHODS(MlirTransformResults,
|
||||
mlir::transform::TransformResults)
|
||||
DEFINE_C_API_PTR_METHODS(MlirTransformState, mlir::transform::TransformState)
|
||||
|
||||
#endif // MLIR_CAPI_DIALECT_TRANSFORM_H
|
||||
@@ -15,4 +15,12 @@
|
||||
#ifndef MLIR_CAPI_INTERFACES_H
|
||||
#define MLIR_CAPI_INTERFACES_H
|
||||
|
||||
#include "mlir-c/Interfaces.h"
|
||||
#include "mlir/CAPI/Wrap.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
DEFINE_C_API_PTR_METHODS(
|
||||
MlirMemoryEffectInstancesList,
|
||||
llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance>)
|
||||
|
||||
#endif // MLIR_CAPI_INTERFACES_H
|
||||
|
||||
@@ -8,12 +8,14 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "IRInterfaces.h"
|
||||
#include "Rewrite.h"
|
||||
#include "mlir-c/Dialect/Transform.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Support.h"
|
||||
#include "mlir/Bindings/Python/IRCore.h"
|
||||
#include "mlir/Bindings/Python/Nanobind.h"
|
||||
#include "mlir/Bindings/Python/NanobindAdaptors.h"
|
||||
#include "nanobind/nanobind.h"
|
||||
#include <nanobind/trampoline.h>
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mlir::python::nanobind_adaptors;
|
||||
@@ -22,6 +24,227 @@ namespace mlir {
|
||||
namespace python {
|
||||
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
|
||||
namespace transform {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TransformRewriter
|
||||
//===----------------------------------------------------------------------===//
|
||||
class PyTransformRewriter : public PyRewriterBase<PyTransformRewriter> {
|
||||
public:
|
||||
static constexpr const char *pyClassName = "TransformRewriter";
|
||||
|
||||
PyTransformRewriter(MlirTransformRewriter rewriter)
|
||||
: PyRewriterBase(mlirTransformRewriterAsBase(rewriter)) {}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TransformResults
|
||||
//===----------------------------------------------------------------------===//
|
||||
class PyTransformResults {
|
||||
public:
|
||||
PyTransformResults(MlirTransformResults results) : results(results) {}
|
||||
|
||||
MlirTransformResults get() const { return results; }
|
||||
|
||||
void setOps(PyValue &result, const nb::list &ops) {
|
||||
std::vector<MlirOperation> opsVec;
|
||||
opsVec.reserve(ops.size());
|
||||
for (auto op : ops) {
|
||||
opsVec.push_back(nb::cast<MlirOperation>(op));
|
||||
}
|
||||
mlirTransformResultsSetOps(results, result, opsVec.size(), opsVec.data());
|
||||
}
|
||||
|
||||
void setValues(PyValue &result, const nb::list &values) {
|
||||
std::vector<MlirValue> valuesVec;
|
||||
valuesVec.reserve(values.size());
|
||||
for (auto item : values) {
|
||||
valuesVec.push_back(nb::cast<MlirValue>(item));
|
||||
}
|
||||
mlirTransformResultsSetValues(results, result, valuesVec.size(),
|
||||
valuesVec.data());
|
||||
}
|
||||
|
||||
void setParams(PyValue &result, const nb::list ¶ms) {
|
||||
std::vector<MlirAttribute> paramsVec;
|
||||
paramsVec.reserve(params.size());
|
||||
for (auto item : params) {
|
||||
paramsVec.push_back(nb::cast<MlirAttribute>(item));
|
||||
}
|
||||
mlirTransformResultsSetParams(results, result, paramsVec.size(),
|
||||
paramsVec.data());
|
||||
}
|
||||
|
||||
static void bind(nanobind::module_ &m) {
|
||||
nb::class_<PyTransformResults>(m, "TransformResults")
|
||||
.def(nb::init<MlirTransformResults>())
|
||||
.def("set_ops", &PyTransformResults::setOps,
|
||||
"Set the payload operations for a transform result.",
|
||||
nb::arg("result"), nb::arg("ops"))
|
||||
.def("set_values", &PyTransformResults::setValues,
|
||||
"Set the payload values for a transform result.",
|
||||
nb::arg("result"), nb::arg("values"))
|
||||
.def("set_params", &PyTransformResults::setParams,
|
||||
"Set the parameters for a transform result.", nb::arg("result"),
|
||||
nb::arg("params"));
|
||||
}
|
||||
|
||||
private:
|
||||
MlirTransformResults results;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TransformState
|
||||
//===----------------------------------------------------------------------===//
|
||||
class PyTransformState {
|
||||
public:
|
||||
PyTransformState(MlirTransformState state) : state(state) {}
|
||||
|
||||
MlirTransformState get() const { return state; }
|
||||
|
||||
static void bind(nanobind::module_ &m) {
|
||||
nb::class_<PyTransformState>(m, "TransformState")
|
||||
.def(nb::init<MlirTransformState>())
|
||||
.def("get_payload_ops", &PyTransformState::getPayloadOps,
|
||||
"Get the payload operations associated with a transform IR value.",
|
||||
nb::arg("operand"))
|
||||
.def("get_payload_values", &PyTransformState::getPayloadValues,
|
||||
"Get the payload values associated with a transform IR value.",
|
||||
nb::arg("operand"))
|
||||
.def("get_params", &PyTransformState::getParams,
|
||||
"Get the parameters (attributes) associated with a transform IR "
|
||||
"value.",
|
||||
nb::arg("operand"));
|
||||
}
|
||||
|
||||
private:
|
||||
nanobind::list getPayloadOps(PyValue &value) {
|
||||
nanobind::list result;
|
||||
mlirTransformStateForEachPayloadOp(
|
||||
state, value,
|
||||
[](MlirOperation op, void *userData) {
|
||||
PyMlirContextRef context =
|
||||
PyMlirContext::forContext(mlirOperationGetContext(op));
|
||||
auto opview = PyOperation::forOperation(context, op)->createOpView();
|
||||
static_cast<nanobind::list *>(userData)->append(opview);
|
||||
},
|
||||
&result);
|
||||
return result;
|
||||
}
|
||||
|
||||
nanobind::list getPayloadValues(PyValue &value) {
|
||||
nanobind::list result;
|
||||
mlirTransformStateForEachPayloadValue(
|
||||
state, value,
|
||||
[](MlirValue val, void *userData) {
|
||||
static_cast<nanobind::list *>(userData)->append(val);
|
||||
},
|
||||
&result);
|
||||
return result;
|
||||
}
|
||||
|
||||
nanobind::list getParams(PyValue &value) {
|
||||
nanobind::list result;
|
||||
mlirTransformStateForEachParam(
|
||||
state, value,
|
||||
[](MlirAttribute attr, void *userData) {
|
||||
static_cast<nanobind::list *>(userData)->append(attr);
|
||||
},
|
||||
&result);
|
||||
return result;
|
||||
}
|
||||
|
||||
MlirTransformState state;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TransformOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
class PyTransformOpInterface
|
||||
: public PyConcreteOpInterface<PyTransformOpInterface> {
|
||||
public:
|
||||
using PyConcreteOpInterface<PyTransformOpInterface>::PyConcreteOpInterface;
|
||||
|
||||
constexpr static const char *pyClassName = "TransformOpInterface";
|
||||
constexpr static GetTypeIDFunctionTy getInterfaceID =
|
||||
&mlirTransformOpInterfaceTypeID;
|
||||
|
||||
/// Attach a new TransformOpInterface FallbackModel to the named operation.
|
||||
/// The FallbackModel acts as a trampoline for callbacks on the Python class.
|
||||
static void attach(nb::object &target, const std::string &opName,
|
||||
DefaultingPyMlirContext ctx) {
|
||||
// Prepare the callbacks that will be used by the FallbackModel.
|
||||
MlirTransformOpInterfaceCallbacks callbacks;
|
||||
// Make the pointer to the Python class available to the callbacks.
|
||||
callbacks.userData = target.ptr();
|
||||
nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
|
||||
|
||||
// The above ref bump is all we need as initialization, no need to run the
|
||||
// construct callback.
|
||||
callbacks.construct = nullptr;
|
||||
// Upon the FallbackModel's destruction, drop the ref to the Python class.
|
||||
callbacks.destruct = [](void *userData) {
|
||||
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
|
||||
};
|
||||
// The apply callback which calls into Python.
|
||||
callbacks.apply = [](MlirOperation op, MlirTransformRewriter rewriter,
|
||||
MlirTransformResults results, MlirTransformState state,
|
||||
void *userData) -> MlirDiagnosedSilenceableFailure {
|
||||
nb::handle pyClass(static_cast<PyObject *>(userData));
|
||||
|
||||
auto pyApply = nb::cast<nb::callable>(nb::getattr(pyClass, "apply"));
|
||||
|
||||
auto pyRewriter = PyTransformRewriter(rewriter);
|
||||
auto pyResults = PyTransformResults(results);
|
||||
auto pyState = PyTransformState(state);
|
||||
|
||||
// Invoke `pyClass.apply(opview(op), rewriter, results, state)` as a
|
||||
// staticmethod.
|
||||
PyMlirContextRef context =
|
||||
PyMlirContext::forContext(mlirOperationGetContext(op));
|
||||
auto opview = PyOperation::forOperation(context, op)->createOpView();
|
||||
nb::object res = pyApply(opview, pyRewriter, pyResults, pyState);
|
||||
|
||||
return nb::cast<MlirDiagnosedSilenceableFailure>(res);
|
||||
};
|
||||
|
||||
// The allows_repeated_handle_operands callback which calls into Python.
|
||||
callbacks.allowsRepeatedHandleOperands = [](MlirOperation op,
|
||||
void *userData) -> bool {
|
||||
nb::handle pyClass(static_cast<PyObject *>(userData));
|
||||
|
||||
auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
|
||||
nb::getattr(pyClass, "allow_repeated_handle_operands"));
|
||||
|
||||
// Invoke `pyClass.allow_repeated_handle_operands(opview(op))` as a
|
||||
// staticmethod.
|
||||
PyMlirContextRef context =
|
||||
PyMlirContext::forContext(mlirOperationGetContext(op));
|
||||
auto opview = PyOperation::forOperation(context, op)->createOpView();
|
||||
nb::object res = pyAllowRepeatedHandleOperands(opview);
|
||||
|
||||
return nb::cast<bool>(res);
|
||||
};
|
||||
|
||||
// Attach a FallbackModel, which calls into Python, to the named operation.
|
||||
mlirTransformOpInterfaceAttachFallbackModel(
|
||||
ctx->get(), wrap(StringRef(opName.c_str())), callbacks);
|
||||
}
|
||||
|
||||
static void bindDerived(ClassTy &cls) {
|
||||
cls.attr("attach") = classmethod(
|
||||
[](const nb::object &cls, const nb::object &opName, nb::object target,
|
||||
DefaultingPyMlirContext context) {
|
||||
if (target.is_none())
|
||||
target = cls;
|
||||
return attach(target, nb::cast<std::string>(opName), context);
|
||||
},
|
||||
nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
|
||||
nb::arg("target").none() = nb::none(),
|
||||
nb::arg("context").none() = nb::none(),
|
||||
"Attach the interface subclass to the given operation name.");
|
||||
}
|
||||
};
|
||||
|
||||
//===-------------------------------------------------------------------===//
|
||||
// AnyOpType
|
||||
//===-------------------------------------------------------------------===//
|
||||
@@ -162,12 +385,81 @@ struct ParamType : PyConcreteType<ParamType> {
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MemoryEffectsOpInterface helpers
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
void onlyReadsHandle(nb::iterable &operands,
|
||||
PyMemoryEffectsInstanceList effects) {
|
||||
std::vector<MlirOpOperand> operandsVec;
|
||||
for (auto operand : operands)
|
||||
operandsVec.push_back(nb::cast<PyOpOperand>(operand));
|
||||
mlirTransformOnlyReadsHandle(operandsVec.data(), operandsVec.size(),
|
||||
effects.effects);
|
||||
};
|
||||
|
||||
void consumesHandle(nb::iterable &operands,
|
||||
PyMemoryEffectsInstanceList effects) {
|
||||
std::vector<MlirOpOperand> operandsVec;
|
||||
for (auto operand : operands)
|
||||
operandsVec.push_back(nb::cast<PyOpOperand>(operand));
|
||||
mlirTransformConsumesHandle(operandsVec.data(), operandsVec.size(),
|
||||
effects.effects);
|
||||
};
|
||||
|
||||
void producesHandle(nb::iterable &results,
|
||||
PyMemoryEffectsInstanceList effects) {
|
||||
std::vector<MlirValue> resultsVec;
|
||||
for (auto result : results)
|
||||
resultsVec.push_back(nb::cast<PyOpResult>(result).get());
|
||||
mlirTransformProducesHandle(resultsVec.data(), resultsVec.size(),
|
||||
effects.effects);
|
||||
};
|
||||
|
||||
void modifiesPayload(PyMemoryEffectsInstanceList effects) {
|
||||
mlirTransformModifiesPayload(effects.effects);
|
||||
}
|
||||
|
||||
void onlyReadsPayload(PyMemoryEffectsInstanceList effects) {
|
||||
mlirTransformOnlyReadsPayload(effects.effects);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
static void populateDialectTransformSubmodule(nb::module_ &m) {
|
||||
nb::enum_<MlirDiagnosedSilenceableFailure>(m, "DiagnosedSilenceableFailure")
|
||||
.value("Success", MlirDiagnosedSilenceableFailureSuccess)
|
||||
.value("SilenceableFailure",
|
||||
MlirDiagnosedSilenceableFailureSilenceableFailure)
|
||||
.value("DefiniteFailure", MlirDiagnosedSilenceableFailureDefiniteFailure);
|
||||
|
||||
AnyOpType::bind(m);
|
||||
AnyParamType::bind(m);
|
||||
AnyValueType::bind(m);
|
||||
OperationType::bind(m);
|
||||
ParamType::bind(m);
|
||||
|
||||
PyTransformRewriter::bind(m);
|
||||
PyTransformResults::bind(m);
|
||||
PyTransformState::bind(m);
|
||||
PyTransformOpInterface::bind(m);
|
||||
|
||||
m.def("only_reads_handle", onlyReadsHandle,
|
||||
"Mark operands as only reading handles.", nb::arg("operands"),
|
||||
nb::arg("effects"));
|
||||
|
||||
m.def("consumes_handle", consumesHandle,
|
||||
"Mark operands as consuming handles.", nb::arg("operands"),
|
||||
nb::arg("effects"));
|
||||
|
||||
m.def("produces_handle", producesHandle, "Mark results as producing handles.",
|
||||
nb::arg("results"), nb::arg("effects"));
|
||||
|
||||
m.def("modifies_payload", modifiesPayload,
|
||||
"Mark the transform as modifying the payload.", nb::arg("effects"));
|
||||
|
||||
m.def("only_reads_payload", onlyReadsPayload,
|
||||
"Mark the transform as only reading the payload.", nb::arg("effects"));
|
||||
}
|
||||
} // namespace transform
|
||||
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
#include "mlir/Bindings/Python/Globals.h"
|
||||
#include "mlir/Bindings/Python/IRCore.h"
|
||||
#include "mlir/Bindings/Python/NanobindUtils.h"
|
||||
#include "mlir/Bindings/Python/NanobindAdaptors.h"
|
||||
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
|
||||
// clang-format on
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
@@ -57,13 +56,6 @@ static size_t hash(const T &value) {
|
||||
return std::hash<T>{}(value);
|
||||
}
|
||||
|
||||
/// Helper for creating an @classmethod.
|
||||
template <class Func, typename... Args>
|
||||
static nb::object classmethod(Func f, Args... args) {
|
||||
nb::object cf = nb::cpp_function(f, args...);
|
||||
return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
|
||||
}
|
||||
|
||||
static nb::object
|
||||
createCustomDialectWrapper(const std::string &dialectNamespace,
|
||||
nb::object dialectDescriptor) {
|
||||
@@ -2289,6 +2281,44 @@ PyOpOperandList PyOpOperandList::slice(intptr_t startIndex, intptr_t length,
|
||||
return PyOpOperandList(operation, startIndex, length, step);
|
||||
}
|
||||
|
||||
/// A list of OpOperands. Internally, these are stored as consecutive elements,
|
||||
/// random access is cheap. The (returned) OpOperand list is associated with the
|
||||
/// operation whose operands these are, and thus extends the lifetime of this
|
||||
/// operation.
|
||||
class PyOpOperands : public Sliceable<PyOpOperands, PyOpOperand> {
|
||||
public:
|
||||
static constexpr const char *pyClassName = "OpOperands";
|
||||
using SliceableT = Sliceable<PyOpOperandList, PyOpOperand>;
|
||||
|
||||
PyOpOperands(PyOperationRef operation, intptr_t startIndex = 0,
|
||||
intptr_t length = -1, intptr_t step = 1)
|
||||
: Sliceable(startIndex,
|
||||
length == -1 ? mlirOperationGetNumOperands(operation->get())
|
||||
: length,
|
||||
step),
|
||||
operation(operation) {}
|
||||
|
||||
private:
|
||||
/// Give the parent CRTP class access to hook implementations below.
|
||||
friend class Sliceable<PyOpOperands, PyOpOperand>;
|
||||
|
||||
intptr_t getRawNumElements() {
|
||||
operation->checkValid();
|
||||
return mlirOperationGetNumOperands(operation->get());
|
||||
}
|
||||
|
||||
PyOpOperand getRawElement(intptr_t pos) {
|
||||
MlirOpOperand opOperand = mlirOperationGetOpOperand(operation->get(), pos);
|
||||
return PyOpOperand(opOperand);
|
||||
}
|
||||
|
||||
PyOpOperands slice(intptr_t startIndex, intptr_t length, intptr_t step) {
|
||||
return PyOpOperands(operation, startIndex, length, step);
|
||||
}
|
||||
|
||||
PyOperationRef operation;
|
||||
};
|
||||
|
||||
PyOpSuccessors::PyOpSuccessors(PyOperationRef operation, intptr_t startIndex,
|
||||
intptr_t length, intptr_t step)
|
||||
: Sliceable(startIndex,
|
||||
@@ -3669,6 +3699,12 @@ void populateIRCore(nb::module_ &m) {
|
||||
return PyOpOperandList(self.getOperation().getRef());
|
||||
},
|
||||
"Returns the list of operation operands.")
|
||||
.def_prop_ro(
|
||||
"op_operands",
|
||||
[](PyOperationBase &self) {
|
||||
return PyOpOperands(self.getOperation().getRef());
|
||||
},
|
||||
"Returns the list of op operands.")
|
||||
.def_prop_ro(
|
||||
"regions",
|
||||
[](PyOperationBase &self) {
|
||||
@@ -4950,6 +4986,7 @@ void populateIRCore(nb::module_ &m) {
|
||||
PyOpAttributeMap::bind(m);
|
||||
PyOpOperandIterator::bind(m);
|
||||
PyOpOperandList::bind(m);
|
||||
PyOpOperands::bind(m);
|
||||
PyOpResultList::bind(m);
|
||||
PyOpSuccessors::bind(m);
|
||||
PyRegionIterator::bind(m);
|
||||
|
||||
@@ -12,30 +12,18 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "IRInterfaces.h"
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Interfaces.h"
|
||||
#include "mlir-c/Support.h"
|
||||
#include "mlir/Bindings/Python/IRCore.h"
|
||||
#include "mlir/Bindings/Python/Nanobind.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
|
||||
namespace mlir {
|
||||
namespace python {
|
||||
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
|
||||
constexpr static const char *constructorDoc =
|
||||
R"(Creates an interface from a given operation/opview object or from a
|
||||
subclass of OpView. Raises ValueError if the operation does not implement the
|
||||
interface.)";
|
||||
|
||||
constexpr static const char *operationDoc =
|
||||
R"(Returns an Operation for which the interface was constructed.)";
|
||||
|
||||
constexpr static const char *opviewDoc =
|
||||
R"(Returns an OpView subclass _instance_ for which the interface was
|
||||
constructed)";
|
||||
|
||||
constexpr static const char *inferReturnTypesDoc =
|
||||
R"(Given the arguments required to build an operation, attempts to infer
|
||||
its return types. Raises ValueError on failure.)";
|
||||
@@ -124,119 +112,6 @@ wrapRegions(std::optional<std::vector<PyRegion>> regions) {
|
||||
|
||||
} // namespace
|
||||
|
||||
/// CRTP base class for Python classes representing MLIR Op interfaces.
|
||||
/// Interface hierarchies are flat so no base class is expected here. The
|
||||
/// derived class is expected to define the following static fields:
|
||||
/// - `const char *pyClassName` - the name of the Python class to create;
|
||||
/// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
|
||||
/// of the interface.
|
||||
/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
|
||||
/// interface-specific methods.
|
||||
///
|
||||
/// An interface class may be constructed from either an Operation/OpView object
|
||||
/// or from a subclass of OpView. In the latter case, only the static interface
|
||||
/// methods are available, similarly to calling ConcereteOp::staticMethod on the
|
||||
/// C++ side. Implementations of concrete interfaces can use the `isStatic`
|
||||
/// method to check whether the interface object was constructed from a class or
|
||||
/// an operation/opview instance. The `getOpName` always succeeds and returns a
|
||||
/// canonical name of the operation suitable for lookups.
|
||||
template <typename ConcreteIface>
|
||||
class PyConcreteOpInterface {
|
||||
protected:
|
||||
using ClassTy = nb::class_<ConcreteIface>;
|
||||
using GetTypeIDFunctionTy = MlirTypeID (*)();
|
||||
|
||||
public:
|
||||
/// Constructs an interface instance from an object that is either an
|
||||
/// operation or a subclass of OpView. In the latter case, only the static
|
||||
/// methods of the interface are accessible to the caller.
|
||||
PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
|
||||
: obj(std::move(object)) {
|
||||
try {
|
||||
operation = &nb::cast<PyOperation &>(obj);
|
||||
} catch (nb::cast_error &) {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
try {
|
||||
operation = &nb::cast<PyOpView &>(obj).getOperation();
|
||||
} catch (nb::cast_error &) {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
if (operation != nullptr) {
|
||||
if (!mlirOperationImplementsInterface(*operation,
|
||||
ConcreteIface::getInterfaceID())) {
|
||||
std::string msg = "the operation does not implement ";
|
||||
throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
|
||||
}
|
||||
|
||||
MlirIdentifier identifier = mlirOperationGetName(*operation);
|
||||
MlirStringRef stringRef = mlirIdentifierStr(identifier);
|
||||
opName = std::string(stringRef.data, stringRef.length);
|
||||
} else {
|
||||
try {
|
||||
opName = nb::cast<std::string>(obj.attr("OPERATION_NAME"));
|
||||
} catch (nb::cast_error &) {
|
||||
throw nb::type_error(
|
||||
"Op interface does not refer to an operation or OpView class");
|
||||
}
|
||||
|
||||
if (!mlirOperationImplementsInterfaceStatic(
|
||||
mlirStringRefCreate(opName.data(), opName.length()),
|
||||
context.resolve().get(), ConcreteIface::getInterfaceID())) {
|
||||
std::string msg = "the operation does not implement ";
|
||||
throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates the Python bindings for this class in the given module.
|
||||
static void bind(nb::module_ &m) {
|
||||
nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
|
||||
cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
|
||||
nb::arg("context") = nb::none(), constructorDoc)
|
||||
.def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
|
||||
operationDoc)
|
||||
.def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
|
||||
ConcreteIface::bindDerived(cls);
|
||||
}
|
||||
|
||||
/// Hook for derived classes to add class-specific bindings.
|
||||
static void bindDerived(ClassTy &cls) {}
|
||||
|
||||
/// Returns `true` if this object was constructed from a subclass of OpView
|
||||
/// rather than from an operation instance.
|
||||
bool isStatic() { return operation == nullptr; }
|
||||
|
||||
/// Returns the operation instance from which this object was constructed.
|
||||
/// Throws a type error if this object was constructed from a subclass of
|
||||
/// OpView.
|
||||
nb::typed<nb::object, PyOperation> getOperationObject() {
|
||||
if (operation == nullptr)
|
||||
throw nb::type_error("Cannot get an operation from a static interface");
|
||||
return operation->getRef().releaseObject();
|
||||
}
|
||||
|
||||
/// Returns the opview of the operation instance from which this object was
|
||||
/// constructed. Throws a type error if this object was constructed form a
|
||||
/// subclass of OpView.
|
||||
nb::typed<nb::object, PyOpView> getOpView() {
|
||||
if (operation == nullptr)
|
||||
throw nb::type_error("Cannot get an opview from a static interface");
|
||||
return operation->createOpView();
|
||||
}
|
||||
|
||||
/// Returns the canonical name of the operation this interface is constructed
|
||||
/// from.
|
||||
const std::string &getOpName() { return opName; }
|
||||
|
||||
private:
|
||||
PyOperation *operation = nullptr;
|
||||
std::string opName;
|
||||
nb::object obj;
|
||||
};
|
||||
|
||||
/// Python wrapper for InferTypeOpInterface. This interface has only static
|
||||
/// methods.
|
||||
class PyInferTypeOpInterface
|
||||
@@ -462,10 +337,74 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Wrapper around the MemoryEffectsOpInterface.
|
||||
class PyMemoryEffectsOpInterface
|
||||
: public PyConcreteOpInterface<PyMemoryEffectsOpInterface> {
|
||||
public:
|
||||
using PyConcreteOpInterface<
|
||||
PyMemoryEffectsOpInterface>::PyConcreteOpInterface;
|
||||
|
||||
constexpr static const char *pyClassName = "MemoryEffectsOpInterface";
|
||||
constexpr static GetTypeIDFunctionTy getInterfaceID =
|
||||
&mlirMemoryEffectsOpInterfaceTypeID;
|
||||
|
||||
/// Attach a new MemoryEffectsOpInterface FallbackModel to the named
|
||||
/// operation. The FallbackModel acts as a trampoline for callbacks on the
|
||||
/// Python class.
|
||||
static void attach(nb::object &target, const std::string &opName,
|
||||
DefaultingPyMlirContext ctx) {
|
||||
MlirMemoryEffectsOpInterfaceCallbacks callbacks;
|
||||
callbacks.userData = target.ptr();
|
||||
nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
|
||||
callbacks.construct = nullptr;
|
||||
callbacks.destruct = [](void *userData) {
|
||||
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
|
||||
};
|
||||
callbacks.getEffects = [](MlirOperation op,
|
||||
MlirMemoryEffectInstancesList effects,
|
||||
void *userData) {
|
||||
nb::handle pyClass(static_cast<PyObject *>(userData));
|
||||
|
||||
// Get the 'get_effects' method from the Python class.
|
||||
auto pyGetEffects =
|
||||
nb::cast<nb::callable>(nb::getattr(pyClass, "get_effects"));
|
||||
|
||||
PyMemoryEffectsInstanceList effectsWrapper{effects};
|
||||
|
||||
PyMlirContextRef context =
|
||||
PyMlirContext::forContext(mlirOperationGetContext(op));
|
||||
auto opview = PyOperation::forOperation(context, op)->createOpView();
|
||||
|
||||
// Invoke `pyClass.get_effects(op, effects)`.
|
||||
pyGetEffects(opview, effectsWrapper);
|
||||
};
|
||||
|
||||
mlirMemoryEffectsOpInterfaceAttachFallbackModel(
|
||||
ctx->get(), wrap(StringRef(opName.c_str())), callbacks);
|
||||
}
|
||||
|
||||
static void bindDerived(ClassTy &cls) {
|
||||
cls.attr("attach") = classmethod(
|
||||
[](const nb::object &cls, const nb::object &opName, nb::object target,
|
||||
DefaultingPyMlirContext context) {
|
||||
if (target.is_none())
|
||||
target = cls;
|
||||
return attach(target, nb::cast<std::string>(opName), context);
|
||||
},
|
||||
nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
|
||||
nb::arg("target").none() = nb::none(),
|
||||
nb::arg("context").none() = nb::none(),
|
||||
"Attach the interface subclass to the given operation name.");
|
||||
}
|
||||
};
|
||||
|
||||
void populateIRInterfaces(nb::module_ &m) {
|
||||
PyInferTypeOpInterface::bind(m);
|
||||
PyShapedTypeComponents::bind(m);
|
||||
nb::class_<PyMemoryEffectsInstanceList>(m, "MemoryEffectInstancesList");
|
||||
|
||||
PyInferShapedTypeOpInterface::bind(m);
|
||||
PyInferTypeOpInterface::bind(m);
|
||||
PyMemoryEffectsOpInterface::bind(m);
|
||||
PyShapedTypeComponents::bind(m);
|
||||
}
|
||||
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
|
||||
} // namespace python
|
||||
|
||||
152
mlir/lib/Bindings/Python/IRInterfaces.h
Normal file
152
mlir/lib/Bindings/Python/IRInterfaces.h
Normal file
@@ -0,0 +1,152 @@
|
||||
//===- IRInterfaces.h - IR Interfaces for Python Bindings -------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_BINDINGS_PYTHON_IRINTERFACES_H
|
||||
#define MLIR_BINDINGS_PYTHON_IRINTERFACES_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Interfaces.h"
|
||||
#include "mlir-c/Support.h"
|
||||
#include "mlir/Bindings/Python/IRCore.h"
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace python {
|
||||
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
|
||||
|
||||
constexpr static const char *constructorDoc =
|
||||
R"(Creates an interface from a given operation/opview object or from a
|
||||
subclass of OpView. Raises ValueError if the operation does not implement the
|
||||
interface.)";
|
||||
|
||||
constexpr static const char *operationDoc =
|
||||
R"(Returns an Operation for which the interface was constructed.)";
|
||||
|
||||
constexpr static const char *opviewDoc =
|
||||
R"(Returns an OpView subclass _instance_ for which the interface was
|
||||
constructed)";
|
||||
|
||||
/// CRTP base class for Python classes representing MLIR Op interfaces.
|
||||
/// Interface hierarchies are flat so no base class is expected here. The
|
||||
/// derived class is expected to define the following static fields:
|
||||
/// - `const char *pyClassName` - the name of the Python class to create;
|
||||
/// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
|
||||
/// of the interface.
|
||||
/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
|
||||
/// interface-specific methods.
|
||||
///
|
||||
/// An interface class may be constructed from either an Operation/OpView object
|
||||
/// or from a subclass of OpView. In the latter case, only the static interface
|
||||
/// methods are available, similarly to calling ConcereteOp::staticMethod on the
|
||||
/// C++ side. Implementations of concrete interfaces can use the `isStatic`
|
||||
/// method to check whether the interface object was constructed from a class or
|
||||
/// an operation/opview instance. The `getOpName` always succeeds and returns a
|
||||
/// canonical name of the operation suitable for lookups.
|
||||
template <typename ConcreteIface>
|
||||
class PyConcreteOpInterface {
|
||||
protected:
|
||||
using ClassTy = nanobind::class_<ConcreteIface>;
|
||||
using GetTypeIDFunctionTy = MlirTypeID (*)();
|
||||
|
||||
public:
|
||||
/// Constructs an interface instance from an object that is either an
|
||||
/// operation or a subclass of OpView. In the latter case, only the static
|
||||
/// methods of the interface are accessible to the caller.
|
||||
PyConcreteOpInterface(nanobind::object object,
|
||||
DefaultingPyMlirContext context)
|
||||
: obj(std::move(object)) {
|
||||
if (!nanobind::try_cast<PyOperation *>(obj, operation)) {
|
||||
PyOpView *opview;
|
||||
if (nanobind::try_cast<PyOpView *>(obj, opview)) {
|
||||
operation = &opview->getOperation();
|
||||
};
|
||||
}
|
||||
|
||||
if (operation != nullptr) {
|
||||
if (!mlirOperationImplementsInterface(*operation,
|
||||
ConcreteIface::getInterfaceID())) {
|
||||
std::string msg = "the operation does not implement ";
|
||||
throw nanobind::value_error((msg + ConcreteIface::pyClassName).c_str());
|
||||
}
|
||||
|
||||
MlirIdentifier identifier = mlirOperationGetName(*operation);
|
||||
MlirStringRef stringRef = mlirIdentifierStr(identifier);
|
||||
opName = std::string(stringRef.data, stringRef.length);
|
||||
} else {
|
||||
if (!nanobind::try_cast<std::string>(obj.attr("OPERATION_NAME"), opName))
|
||||
throw nanobind::type_error(
|
||||
"Op interface does not refer to an operation or OpView class");
|
||||
|
||||
if (!mlirOperationImplementsInterfaceStatic(
|
||||
mlirStringRefCreate(opName.data(), opName.length()),
|
||||
context.resolve().get(), ConcreteIface::getInterfaceID())) {
|
||||
std::string msg = "the operation does not implement ";
|
||||
throw nanobind::value_error((msg + ConcreteIface::pyClassName).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates the Python bindings for this class in the given module.
|
||||
static void bind(nanobind::module_ &m) {
|
||||
nanobind::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
|
||||
cls.def(nanobind::init<nanobind::object, DefaultingPyMlirContext>(),
|
||||
nanobind::arg("object"),
|
||||
nanobind::arg("context") = nanobind::none(), constructorDoc)
|
||||
.def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
|
||||
operationDoc)
|
||||
.def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
|
||||
ConcreteIface::bindDerived(cls);
|
||||
}
|
||||
|
||||
/// Hook for derived classes to add class-specific bindings.
|
||||
static void bindDerived(ClassTy &cls) {}
|
||||
|
||||
/// Returns `true` if this object was constructed from a subclass of OpView
|
||||
/// rather than from an operation instance.
|
||||
bool isStatic() { return operation == nullptr; }
|
||||
|
||||
/// Returns the operation instance from which this object was constructed.
|
||||
/// Throws a type error if this object was constructed from a subclass of
|
||||
/// OpView.
|
||||
nanobind::typed<nanobind::object, PyOperation> getOperationObject() {
|
||||
if (operation == nullptr)
|
||||
throw nanobind::type_error(
|
||||
"Cannot get an operation from a static interface");
|
||||
return operation->getRef().releaseObject();
|
||||
}
|
||||
|
||||
/// Returns the opview of the operation instance from which this object was
|
||||
/// constructed. Throws a type error if this object was constructed form a
|
||||
/// subclass of OpView.
|
||||
nanobind::typed<nanobind::object, PyOpView> getOpView() {
|
||||
if (operation == nullptr)
|
||||
throw nanobind::type_error(
|
||||
"Cannot get an opview from a static interface");
|
||||
return operation->createOpView();
|
||||
}
|
||||
|
||||
/// Returns the canonical name of the operation this interface is constructed
|
||||
/// from.
|
||||
const std::string &getOpName() { return opName; }
|
||||
|
||||
private:
|
||||
PyOperation *operation = nullptr;
|
||||
std::string opName;
|
||||
nanobind::object obj;
|
||||
};
|
||||
|
||||
struct PyMemoryEffectsInstanceList {
|
||||
MlirMemoryEffectInstancesList effects;
|
||||
};
|
||||
|
||||
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
|
||||
} // namespace python
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_BINDINGS_PYTHON_IRINTERFACES_H
|
||||
@@ -8,15 +8,12 @@
|
||||
|
||||
#include "Rewrite.h"
|
||||
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Rewrite.h"
|
||||
#include "mlir-c/Support.h"
|
||||
#include "mlir/Bindings/Python/Globals.h"
|
||||
#include "mlir/Bindings/Python/IRCore.h"
|
||||
// clang-format off
|
||||
#include "mlir/Bindings/Python/Nanobind.h"
|
||||
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
|
||||
// clang-format on
|
||||
#include "mlir/Config/mlir-config.h"
|
||||
#include "nanobind/nanobind.h"
|
||||
#include <type_traits>
|
||||
@@ -30,38 +27,12 @@ namespace mlir {
|
||||
namespace python {
|
||||
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
|
||||
|
||||
class PyPatternRewriter {
|
||||
class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
|
||||
public:
|
||||
static constexpr const char *pyClassName = "PatternRewriter";
|
||||
|
||||
PyPatternRewriter(MlirPatternRewriter rewriter)
|
||||
: base(mlirPatternRewriterAsBase(rewriter)),
|
||||
ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
|
||||
|
||||
PyInsertionPoint getInsertionPoint() const {
|
||||
MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
|
||||
MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
|
||||
|
||||
if (mlirOperationIsNull(op)) {
|
||||
MlirOperation owner = mlirBlockGetParentOperation(block);
|
||||
auto parent = PyOperation::forOperation(ctx, owner);
|
||||
return PyInsertionPoint(PyBlock(parent, block));
|
||||
}
|
||||
|
||||
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
|
||||
}
|
||||
|
||||
void replaceOp(MlirOperation op, MlirOperation newOp) {
|
||||
mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
|
||||
}
|
||||
|
||||
void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
|
||||
mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
|
||||
}
|
||||
|
||||
void eraseOp(const PyOperation &op) { mlirRewriterBaseEraseOp(base, op); }
|
||||
|
||||
private:
|
||||
MlirRewriterBase base;
|
||||
PyMlirContextRef ctx;
|
||||
: PyRewriterBase(mlirPatternRewriterAsBase(rewriter)) {}
|
||||
};
|
||||
|
||||
class PyConversionPatternRewriter : PyPatternRewriter {
|
||||
@@ -514,29 +485,8 @@ void populateRewriteSubmodule(nb::module_ &m) {
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of the PatternRewriter
|
||||
//----------------------------------------------------------------------------
|
||||
nb::class_<PyPatternRewriter>(m, "PatternRewriter")
|
||||
.def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
|
||||
"The current insertion point of the PatternRewriter.")
|
||||
.def(
|
||||
"replace_op",
|
||||
[](PyPatternRewriter &self, PyOperationBase &op,
|
||||
PyOperationBase &newOp) {
|
||||
self.replaceOp(op.getOperation(), newOp.getOperation());
|
||||
},
|
||||
"Replace an operation with a new operation.", nb::arg("op"),
|
||||
nb::arg("new_op"))
|
||||
.def(
|
||||
"replace_op",
|
||||
[](PyPatternRewriter &self, PyOperationBase &op,
|
||||
const std::vector<PyValue> &values) {
|
||||
std::vector<MlirValue> values_(values.size());
|
||||
std::copy(values.begin(), values.end(), values_.begin());
|
||||
self.replaceOp(op.getOperation(), values_);
|
||||
},
|
||||
"Replace an operation with a list of values.", nb::arg("op"),
|
||||
nb::arg("values"))
|
||||
.def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
|
||||
nb::arg("op"));
|
||||
|
||||
PyPatternRewriter::bind(m);
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of the RewritePatternSet
|
||||
|
||||
@@ -9,13 +9,74 @@
|
||||
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
|
||||
#define MLIR_BINDINGS_PYTHON_REWRITE_H
|
||||
|
||||
#include "mlir/Bindings/Python/NanobindUtils.h"
|
||||
#include "mlir-c/Rewrite.h"
|
||||
#include "mlir/Bindings/Python/IRCore.h"
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace python {
|
||||
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
|
||||
void populateRewriteSubmodule(nanobind::module_ &m);
|
||||
}
|
||||
|
||||
/// CRTP Base class for rewriter wrappers.
|
||||
template <typename DerivedTy>
|
||||
class MLIR_PYTHON_API_EXPORTED PyRewriterBase {
|
||||
public:
|
||||
PyRewriterBase(MlirRewriterBase rewriter)
|
||||
: base(rewriter),
|
||||
ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
|
||||
|
||||
PyInsertionPoint getInsertionPoint() const {
|
||||
MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
|
||||
MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);
|
||||
|
||||
if (mlirOperationIsNull(op)) {
|
||||
MlirOperation owner = mlirBlockGetParentOperation(block);
|
||||
auto parent = PyOperation::forOperation(ctx, owner);
|
||||
return PyInsertionPoint(PyBlock(parent, block));
|
||||
}
|
||||
|
||||
return PyInsertionPoint(PyOperation::forOperation(ctx, op));
|
||||
}
|
||||
|
||||
static void bind(nanobind::module_ &m) {
|
||||
nanobind::class_<DerivedTy>(m, DerivedTy::pyClassName)
|
||||
.def_prop_ro("ip", &PyRewriterBase::getInsertionPoint,
|
||||
"The current insertion point of the PatternRewriter.")
|
||||
.def(
|
||||
"replace_op",
|
||||
[](DerivedTy &self, PyOperationBase &op, PyOperationBase &newOp) {
|
||||
mlirRewriterBaseReplaceOpWithOperation(
|
||||
self.base, op.getOperation(), newOp.getOperation());
|
||||
},
|
||||
"Replace an operation with a new operation.", nanobind::arg("op"),
|
||||
nanobind::arg("new_op"))
|
||||
.def(
|
||||
"replace_op",
|
||||
[](DerivedTy &self, PyOperationBase &op,
|
||||
const std::vector<PyValue> &values) {
|
||||
std::vector<MlirValue> values_(values.size());
|
||||
std::copy(values.begin(), values.end(), values_.begin());
|
||||
mlirRewriterBaseReplaceOpWithValues(
|
||||
self.base, op.getOperation(), values_.size(), values_.data());
|
||||
},
|
||||
"Replace an operation with a list of values.", nanobind::arg("op"),
|
||||
nanobind::arg("values"))
|
||||
.def(
|
||||
"erase_op",
|
||||
[](DerivedTy &self, PyOperationBase &op) {
|
||||
mlirRewriterBaseEraseOp(self.base, op.getOperation());
|
||||
},
|
||||
"Erase an operation.", nanobind::arg("op"));
|
||||
}
|
||||
|
||||
private:
|
||||
MlirRewriterBase base;
|
||||
PyMlirContextRef ctx;
|
||||
};
|
||||
|
||||
void MLIR_PYTHON_API_EXPORTED populateRewriteSubmodule(nanobind::module_ &m);
|
||||
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
|
||||
} // namespace python
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -8,9 +8,14 @@
|
||||
|
||||
#include "mlir-c/Dialect/Transform.h"
|
||||
#include "mlir-c/Support.h"
|
||||
#include "mlir/CAPI/Dialect/Transform.h"
|
||||
#include "mlir/CAPI/Interfaces.h"
|
||||
#include "mlir/CAPI/Registration.h"
|
||||
#include "mlir/CAPI/Rewrite.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
|
||||
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -126,3 +131,210 @@ MlirStringRef mlirTransformParamTypeGetName(void) {
|
||||
MlirType mlirTransformParamTypeGetType(MlirType type) {
|
||||
return wrap(cast<transform::ParamType>(unwrap(type)).getType());
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// TransformRewriter
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
/// Casts a `MlirTransformRewriter` to a `MlirRewriterBase`.
|
||||
MlirRewriterBase mlirTransformRewriterAsBase(MlirTransformRewriter rewriter) {
|
||||
mlir::transform::TransformRewriter *t = unwrap(rewriter);
|
||||
mlir::RewriterBase *base = static_cast<mlir::RewriterBase *>(t);
|
||||
return wrap(base);
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// TransformResults
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
void mlirTransformResultsSetOps(MlirTransformResults results, MlirValue result,
|
||||
intptr_t numOps, MlirOperation *ops) {
|
||||
SmallVector<Operation *> opsVec;
|
||||
opsVec.reserve(numOps);
|
||||
for (intptr_t i = 0; i < numOps; ++i)
|
||||
opsVec.push_back(unwrap(ops[i]));
|
||||
unwrap(results)->set(cast<OpResult>(unwrap(result)), opsVec);
|
||||
}
|
||||
|
||||
void mlirTransformResultsSetValues(MlirTransformResults results,
|
||||
MlirValue result, intptr_t numValues,
|
||||
MlirValue *values) {
|
||||
SmallVector<Value> valuesVec;
|
||||
valuesVec.reserve(numValues);
|
||||
for (intptr_t i = 0; i < numValues; ++i)
|
||||
valuesVec.push_back(unwrap(values[i]));
|
||||
unwrap(results)->setValues(cast<OpResult>(unwrap(result)), valuesVec);
|
||||
}
|
||||
|
||||
void mlirTransformResultsSetParams(MlirTransformResults results,
|
||||
MlirValue result, intptr_t numParams,
|
||||
MlirAttribute *params) {
|
||||
SmallVector<Attribute> paramsVec;
|
||||
paramsVec.reserve(numParams);
|
||||
for (intptr_t i = 0; i < numParams; ++i)
|
||||
paramsVec.push_back(unwrap(params[i]));
|
||||
unwrap(results)->setParams(cast<OpResult>(unwrap(result)), paramsVec);
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// TransformState
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
void mlirTransformStateForEachPayloadOp(MlirTransformState state,
|
||||
MlirValue value,
|
||||
MlirOperationCallback callback,
|
||||
void *userData) {
|
||||
for (Operation *op : unwrap(state)->getPayloadOps(unwrap(value)))
|
||||
callback(wrap(op), userData);
|
||||
}
|
||||
|
||||
void mlirTransformStateForEachPayloadValue(MlirTransformState state,
|
||||
MlirValue value,
|
||||
MlirValueCallback callback,
|
||||
void *userData) {
|
||||
for (Value val : unwrap(state)->getPayloadValues(unwrap(value)))
|
||||
callback(wrap(val), userData);
|
||||
}
|
||||
|
||||
void mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
|
||||
MlirAttributeCallback callback,
|
||||
void *userData) {
|
||||
for (Attribute attr : unwrap(state)->getParams(unwrap(value)))
|
||||
callback(wrap(attr), userData);
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// TransformOpInterface
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeID mlirTransformOpInterfaceTypeID(void) {
|
||||
return wrap(transform::TransformOpInterface::getInterfaceID());
|
||||
}
|
||||
|
||||
/// Fallback model for the TransformOpInterface that uses C API callbacks.
|
||||
class TransformOpInterfaceFallbackModel
|
||||
: public mlir::transform::TransformOpInterface::FallbackModel<
|
||||
TransformOpInterfaceFallbackModel> {
|
||||
public:
|
||||
/// Sets the callbacks that this FallbackModel will use.
|
||||
/// NB: the callbacks can only be set through this method as the
|
||||
/// RegisteredOperationName::attachInterface mechanism default-constructs
|
||||
/// the FallbackModel without being able to provide arguments.
|
||||
void setCallbacks(MlirTransformOpInterfaceCallbacks callbacks) {
|
||||
this->callbacks = callbacks;
|
||||
}
|
||||
|
||||
~TransformOpInterfaceFallbackModel() {
|
||||
if (callbacks.destruct)
|
||||
callbacks.destruct(callbacks.userData);
|
||||
}
|
||||
|
||||
static TypeID getInterfaceID() {
|
||||
return transform::TransformOpInterface::getInterfaceID();
|
||||
}
|
||||
|
||||
static bool classof(const mlir::transform::detail::
|
||||
TransformOpInterfaceInterfaceTraits::Concept *op) {
|
||||
// Enable casting back to the FallbackModel from the Interface. This is
|
||||
// necessary as attachInterface(...) default-constructs the FallbackModel
|
||||
// without being able to pass in the callbacks and returns just the Concept.
|
||||
return true;
|
||||
}
|
||||
|
||||
::mlir::DiagnosedSilenceableFailure
|
||||
apply(Operation *op, ::mlir::transform::TransformRewriter &rewriter,
|
||||
::mlir::transform::TransformResults &transformResults,
|
||||
::mlir::transform::TransformState &state) const {
|
||||
assert(callbacks.apply && "apply callback not set");
|
||||
|
||||
MlirDiagnosedSilenceableFailure status =
|
||||
callbacks.apply(wrap(op), wrap(&rewriter), wrap(&transformResults),
|
||||
wrap(&state), callbacks.userData);
|
||||
|
||||
switch (status) {
|
||||
case MlirDiagnosedSilenceableFailureSuccess:
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
case MlirDiagnosedSilenceableFailureSilenceableFailure:
|
||||
// TODO: enable passing diagnostic info from C API to C++ API.
|
||||
return DiagnosedSilenceableFailure::silenceableFailure(std::move(
|
||||
*(op->emitError()
|
||||
<< "TransformOpInterfaceFallbackModel: silenceable failure")
|
||||
.getUnderlyingDiagnostic()));
|
||||
case MlirDiagnosedSilenceableFailureDefiniteFailure:
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
}
|
||||
llvm_unreachable("unknown transform status");
|
||||
}
|
||||
|
||||
bool allowsRepeatedHandleOperands(Operation *op) const {
|
||||
assert(callbacks.allowsRepeatedHandleOperands &&
|
||||
"allowsRepeatedHandleOperands callback not set");
|
||||
return callbacks.allowsRepeatedHandleOperands(wrap(op), callbacks.userData);
|
||||
}
|
||||
|
||||
private:
|
||||
MlirTransformOpInterfaceCallbacks callbacks;
|
||||
};
|
||||
|
||||
/// Attach a TransformOpInterface FallbackModel to the given named operation.
|
||||
/// The FallbackModel uses the provided callbacks to implement the interface.
|
||||
void mlirTransformOpInterfaceAttachFallbackModel(
|
||||
MlirContext ctx, MlirStringRef opName,
|
||||
MlirTransformOpInterfaceCallbacks callbacks) {
|
||||
// Look up the operation definition in the context.
|
||||
std::optional<RegisteredOperationName> opInfo =
|
||||
RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
|
||||
|
||||
assert(opInfo.has_value() && "operation not found in context");
|
||||
|
||||
// NB: the following default-constructs the FallbackModel _without_ being able
|
||||
// to provide arguments.
|
||||
opInfo->attachInterface<TransformOpInterfaceFallbackModel>();
|
||||
// Cast to get the underlying FallbackModel and set the callbacks.
|
||||
auto *model = cast<TransformOpInterfaceFallbackModel>(
|
||||
opInfo->getInterface<TransformOpInterfaceFallbackModel>());
|
||||
|
||||
assert(model && "Failed to get TransformOpInterfaceFallbackModel");
|
||||
model->setCallbacks(callbacks);
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// MemoryEffectsOpInterface helpers
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
/// Set the effect for the operands to only read the transform handles.
|
||||
void mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
|
||||
MlirMemoryEffectInstancesList effects) {
|
||||
MutableArrayRef<OpOperand> operandArray(unwrap(*operands), numOperands);
|
||||
transform::onlyReadsHandle(operandArray, *unwrap(effects));
|
||||
}
|
||||
|
||||
/// Set the effect for the operands to consuming the transform handles.
|
||||
void mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands,
|
||||
MlirMemoryEffectInstancesList effects) {
|
||||
MutableArrayRef<OpOperand> operandArray(unwrap(*operands), numOperands);
|
||||
transform::consumesHandle(operandArray, *unwrap(effects));
|
||||
}
|
||||
|
||||
/// Set the effect for the results to that they produce transform handles.
|
||||
void mlirTransformProducesHandle(MlirValue *results, intptr_t numResults,
|
||||
MlirMemoryEffectInstancesList effects) {
|
||||
// NB: calling `producesHandle()` `numResults` as we cannot cast array of
|
||||
// `OpResult`s to a single `ResultRange` (and neither is `ResultRange` exposed
|
||||
// to Python). `producesHandle` iterates over the given `ResultRange` anyway.
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effectList = *unwrap(effects);
|
||||
for (intptr_t i = 0; i < numResults; ++i) {
|
||||
auto opResult = cast<OpResult>(unwrap(results[i]));
|
||||
transform::producesHandle(ResultRange(opResult), effectList);
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the effect of potentially modifying payload IR.
|
||||
void mlirTransformModifiesPayload(MlirMemoryEffectInstancesList effects) {
|
||||
transform::modifiesPayload(*unwrap(effects));
|
||||
}
|
||||
|
||||
/// Set the effect of potentially reading payload IR.
|
||||
void mlirTransformOnlyReadsPayload(MlirMemoryEffectInstancesList effects) {
|
||||
transform::onlyReadsPayload(*unwrap(effects));
|
||||
}
|
||||
|
||||
@@ -30,7 +30,6 @@
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Parser/Parser.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/ThreadPool.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
@@ -714,6 +713,10 @@ MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
|
||||
return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
|
||||
}
|
||||
|
||||
MlirOpOperand mlirOperationGetOpOperand(MlirOperation op, intptr_t pos) {
|
||||
return wrap(&unwrap(op)->getOpOperand(static_cast<unsigned>(pos)));
|
||||
}
|
||||
|
||||
void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
|
||||
MlirValue newValue) {
|
||||
unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue));
|
||||
|
||||
@@ -167,3 +167,73 @@ MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(
|
||||
}
|
||||
return mlirLogicalResultSuccess();
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// MemoryEffectOpInterface
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeID mlirMemoryEffectsOpInterfaceTypeID() {
|
||||
return wrap(MemoryEffectOpInterface::getInterfaceID());
|
||||
}
|
||||
|
||||
/// Fallback model for the MemoryEffectsOpInterface that uses C API callbacks.
|
||||
class MemoryEffectOpInterfaceFallbackModel
|
||||
: public mlir::MemoryEffectOpInterface::FallbackModel<
|
||||
MemoryEffectOpInterfaceFallbackModel> {
|
||||
public:
|
||||
/// Sets the callbacks that this FallbackModel will use.
|
||||
/// NB: the callbacks can only be set through this method as the
|
||||
/// RegisteredOperationName::attachInterface mechanism default-constructs
|
||||
/// the FallbackModel without being able to provide arguments.
|
||||
void setCallbacks(MlirMemoryEffectsOpInterfaceCallbacks callbacks) {
|
||||
this->callbacks = callbacks;
|
||||
}
|
||||
|
||||
~MemoryEffectOpInterfaceFallbackModel() {
|
||||
if (callbacks.destruct)
|
||||
callbacks.destruct(callbacks.userData);
|
||||
}
|
||||
|
||||
static TypeID getInterfaceID() {
|
||||
return MemoryEffectOpInterface::getInterfaceID();
|
||||
}
|
||||
|
||||
static bool classof(const mlir::MemoryEffectOpInterface::Concept *op) {
|
||||
// Enable casting back to the FallbackModel from the Interface. This is
|
||||
// necessary as attachInterface(...) default-constructs the FallbackModel
|
||||
// without being able to pass in the callbacks and returns just the Concept.
|
||||
return true;
|
||||
}
|
||||
|
||||
void
|
||||
getEffects(Operation *op,
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) const {
|
||||
assert(callbacks.getEffects && "getEffects callback not set");
|
||||
MlirMemoryEffectInstancesList cEffects = wrap(&effects);
|
||||
callbacks.getEffects(wrap(op), cEffects, callbacks.userData);
|
||||
}
|
||||
|
||||
private:
|
||||
MlirMemoryEffectsOpInterfaceCallbacks callbacks;
|
||||
};
|
||||
|
||||
/// Attach a MemoryEffectsOpInterface FallbackModel to the given named op.
|
||||
/// The FallbackModel uses the provided callbacks to implement the interface.
|
||||
void mlirMemoryEffectsOpInterfaceAttachFallbackModel(
|
||||
MlirContext ctx, MlirStringRef opName,
|
||||
MlirMemoryEffectsOpInterfaceCallbacks callbacks) {
|
||||
// Look up the operation definition in the context
|
||||
std::optional<RegisteredOperationName> opInfo =
|
||||
RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
|
||||
|
||||
assert(opInfo.has_value() && "operation not found in context");
|
||||
|
||||
// NB: the following default-constructs the FallbackModel _without_ being able
|
||||
// to provide arguments.
|
||||
opInfo->attachInterface<MemoryEffectOpInterfaceFallbackModel>();
|
||||
// Cast to get the underlying FallbackModel and set the callbacks.
|
||||
auto *model = cast<MemoryEffectOpInterfaceFallbackModel>(
|
||||
opInfo->getInterface<MemoryEffectOpInterfaceFallbackModel>());
|
||||
assert(model && "Failed to get MemoryEffectOpInterfaceFallbackModel");
|
||||
model->setCallbacks(callbacks);
|
||||
}
|
||||
|
||||
@@ -687,8 +687,10 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Nanobind
|
||||
ROOT_DIR "${PYTHON_SOURCE_DIR}"
|
||||
SOURCES
|
||||
DialectTransform.cpp
|
||||
Rewrite.h
|
||||
PRIVATE_LINK_LIBS
|
||||
LLVMSupport
|
||||
MLIRPythonExtension.Core
|
||||
EMBED_CAPI_LINK_LIBS
|
||||
MLIRCAPIIR
|
||||
MLIRCAPITransformDialect
|
||||
|
||||
@@ -242,6 +242,7 @@ def _site_initialize():
|
||||
Sequence.register(ir.BlockPredecessors)
|
||||
Sequence.register(ir.OperationList)
|
||||
Sequence.register(ir.OpOperandList)
|
||||
Sequence.register(ir.OpOperands)
|
||||
Sequence.register(ir.OpResultList)
|
||||
Sequence.register(ir.OpSuccessors)
|
||||
Sequence.register(ir.RegionSequence)
|
||||
|
||||
@@ -29,6 +29,8 @@ __all__ = [
|
||||
"Dialect",
|
||||
"Operand",
|
||||
"Result",
|
||||
"register_dialect",
|
||||
"register_operation",
|
||||
"Region",
|
||||
"Operation",
|
||||
]
|
||||
@@ -36,6 +38,8 @@ __all__ = [
|
||||
Operand = ir.Value
|
||||
Result = ir.OpResult
|
||||
Region = ir.Region
|
||||
register_dialect = _cext.register_dialect
|
||||
register_operation = _cext.register_operation
|
||||
|
||||
|
||||
class ConstraintLoweringContext:
|
||||
@@ -203,6 +207,12 @@ class Operation(ir.OpView):
|
||||
Use `Dialect` and `.Operation` of `Dialect` subclasses instead.
|
||||
"""
|
||||
|
||||
def __init__(*args, **kwargs):
|
||||
raise TypeError(
|
||||
"This class is a template and cannot be instantiated directly. "
|
||||
"Please use a subclass that defines the operation."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __init_subclass__(
|
||||
cls, *, name: str | None = None, traits: list[type] | None = None, **kwargs
|
||||
@@ -507,22 +517,21 @@ class Dialect(ir.Dialect):
|
||||
return m
|
||||
|
||||
@classmethod
|
||||
def load(cls) -> None:
|
||||
if hasattr(cls, "_mlir_module"):
|
||||
raise RuntimeError(f"Dialect {cls.name} is already loaded.")
|
||||
|
||||
mlir_module = cls._emit_module()
|
||||
def load(cls, register=True, reload=False) -> None:
|
||||
if hasattr(cls, "_mlir_module") and not reload:
|
||||
return
|
||||
|
||||
cls._mlir_module = cls._emit_module()
|
||||
pm = PassManager()
|
||||
pm.add("canonicalize, cse")
|
||||
pm.run(mlir_module.operation)
|
||||
pm.run(cls._mlir_module.operation)
|
||||
|
||||
irdl.load_dialects(mlir_module)
|
||||
irdl.load_dialects(cls._mlir_module)
|
||||
|
||||
_cext.register_dialect(cls)
|
||||
if register:
|
||||
register_dialect(cls)
|
||||
|
||||
for op in cls.operations:
|
||||
op._attach_traits()
|
||||
_cext.register_operation(cls)(op)
|
||||
|
||||
cls._mlir_module = mlir_module
|
||||
register_dialect_operation = register_operation(cls)
|
||||
for op in cls.operations:
|
||||
op._attach_traits()
|
||||
register_dialect_operation(op)
|
||||
|
||||
500
mlir/test/python/dialects/transform_op_interface.py
Normal file
500
mlir/test/python/dialects/transform_op_interface.py
Normal file
@@ -0,0 +1,500 @@
|
||||
# RUN: env PYTHONUNBUFFERED=1 %PYTHON %s 2>&1 | FileCheck %s
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from mlir import ir
|
||||
from mlir.dialects import index, transform, func, arith, ext
|
||||
from mlir.dialects.transform import (
|
||||
DiagnosedSilenceableFailure,
|
||||
AnyOpType,
|
||||
AnyValueType,
|
||||
AnyParamType,
|
||||
structured,
|
||||
interpreter,
|
||||
)
|
||||
|
||||
|
||||
@ext.register_dialect
|
||||
class MyTransform(ext.Dialect, name="my_transform"):
|
||||
pass
|
||||
|
||||
|
||||
def run(emit_schedule):
|
||||
print(f"Test: {emit_schedule.__name__}")
|
||||
with ir.Context() as ctx, ir.Location.unknown():
|
||||
payload = emit_payload()
|
||||
|
||||
MyTransform.load(register=False, reload=True)
|
||||
|
||||
GetNamedAttributeOp.attach_interface_impls(ctx)
|
||||
PrintParamOp.attach_interface_impls(ctx)
|
||||
|
||||
# NB: Other newly defined my_transform ops have their interfaces attached
|
||||
# in their respective test functions.
|
||||
schedule = emit_schedule()
|
||||
|
||||
interpreter.apply_named_sequence(
|
||||
payload,
|
||||
_named_seq := schedule.operation.regions[0].blocks[0].operations[0],
|
||||
schedule,
|
||||
)
|
||||
|
||||
|
||||
# Payload used by all tests
|
||||
def emit_payload():
|
||||
payload_module = ir.Module.create()
|
||||
with ir.InsertionPoint(payload_module.body):
|
||||
f32 = ir.F32Type.get()
|
||||
|
||||
@func.FuncOp.from_py_func(f32, f32, results=[f32])
|
||||
def name_of_func(a, b):
|
||||
c = arith.addf(a, b)
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
arith.constant(i32, 42)
|
||||
arith.constant(i32, 24)
|
||||
func.ReturnOp([c])
|
||||
|
||||
return payload_module
|
||||
|
||||
|
||||
@contextmanager
|
||||
def schedule_boilerplate():
|
||||
schedule = ir.Module.create()
|
||||
schedule.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
|
||||
with ir.InsertionPoint(schedule.body):
|
||||
named_sequence = transform.NamedSequenceOp(
|
||||
"__transform_main",
|
||||
[AnyOpType.get()],
|
||||
[AnyOpType.get()],
|
||||
arg_attrs=[{"transform.consumed": ir.UnitAttr.get()}],
|
||||
)
|
||||
with ir.InsertionPoint(named_sequence.body):
|
||||
yield schedule, named_sequence
|
||||
|
||||
|
||||
# MemoryEffectsOpInterface implementation for TransformOpInterface-implementing ops.
|
||||
# Used by most ops defined below.
|
||||
class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
|
||||
@staticmethod
|
||||
def get_effects(op: ir.Operation, effects):
|
||||
transform.only_reads_handle(op.op_operands, effects)
|
||||
transform.produces_handle(op.results, effects)
|
||||
transform.only_reads_payload(effects)
|
||||
|
||||
|
||||
# Demonstration of a TransformOpInterface-implementing op that gets named attributes
|
||||
# from target ops and produces them as param handles.
|
||||
@ext.register_operation(MyTransform)
|
||||
class GetNamedAttributeOp(MyTransform.Operation, name="get_named_attribute"):
|
||||
target: ext.Operand[transform.AnyOpType]
|
||||
attr_name: ir.StringAttr
|
||||
attr_as_param: ext.Result[transform.AnyParamType[()]]
|
||||
|
||||
@classmethod
|
||||
def attach_interface_impls(cls, ctx=None):
|
||||
cls.TransformOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, context=ctx)
|
||||
MemoryEffectsOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, context=ctx)
|
||||
|
||||
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
|
||||
@staticmethod
|
||||
def apply(
|
||||
op: "GetNamedAttributeOp",
|
||||
_rewriter: transform.TransformRewriter,
|
||||
results: transform.TransformResults,
|
||||
state: transform.TransformState,
|
||||
) -> DiagnosedSilenceableFailure:
|
||||
target_ops = state.get_payload_ops(op.target)
|
||||
associated_attrs = []
|
||||
for target_op in target_ops:
|
||||
assoc_attr = target_op.attributes.get(op.attr_name.value)
|
||||
if assoc_attr is None:
|
||||
return DiagnosedSilenceableFailure.RecoverableFailure
|
||||
associated_attrs.append(assoc_attr)
|
||||
results.set_params(op.attr_as_param, associated_attrs)
|
||||
return DiagnosedSilenceableFailure.Success
|
||||
|
||||
@staticmethod
|
||||
def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@ext.register_operation(MyTransform)
|
||||
class PrintParamOp(MyTransform.Operation, name="print_param"):
|
||||
target: ext.Operand[transform.AnyParamType]
|
||||
name: ir.StringAttr
|
||||
|
||||
@classmethod
|
||||
def attach_interface_impls(cls, ctx=None):
|
||||
cls.TransformOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, context=ctx)
|
||||
MemoryEffectsOpInterfaceFallbackModel.attach(cls.OPERATION_NAME, context=ctx)
|
||||
|
||||
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
|
||||
@staticmethod
|
||||
def apply(
|
||||
op: "PrintParamOp",
|
||||
rewriter: transform.TransformRewriter,
|
||||
results: transform.TransformResults,
|
||||
state: transform.TransformState,
|
||||
) -> DiagnosedSilenceableFailure:
|
||||
target_attrs = state.get_params(op.target)
|
||||
print(f"[[[ IR printer: {op.name.value} ]]]")
|
||||
for attr in target_attrs:
|
||||
print(attr)
|
||||
return DiagnosedSilenceableFailure.Success
|
||||
|
||||
@staticmethod
|
||||
def allow_repeated_handle_operands(_op: "GetNamedAttributeOp") -> bool:
|
||||
return False
|
||||
|
||||
|
||||
# Syntax for an op with one op handle operand and one op handle result.
|
||||
@ext.register_operation(MyTransform)
|
||||
class OneOpInOneOpOut(MyTransform.Operation, name="one_op_in_one_op_out"):
|
||||
target: ext.Operand[transform.AnyOpType]
|
||||
res: ext.Result[transform.AnyOpType[()]]
|
||||
|
||||
|
||||
# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterface
|
||||
@run
|
||||
def OneOpInOneOpOutTransformOpInterface():
|
||||
"""Tests a simple passthrough interface implementation.
|
||||
|
||||
Checks that the target ops are correctly identified and passed as results.
|
||||
"""
|
||||
|
||||
# Define a simple passthrough implementation of the TransformOpInterface for OneOpInOneOpOut.
|
||||
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
|
||||
@staticmethod
|
||||
def apply(
|
||||
op: OneOpInOneOpOut,
|
||||
_rewriter: transform.TransformRewriter,
|
||||
results: transform.TransformResults,
|
||||
state: transform.TransformState,
|
||||
) -> DiagnosedSilenceableFailure:
|
||||
target_ops = state.get_payload_ops(op.target)
|
||||
target_names = [t.name.value for t in target_ops]
|
||||
print(f"OneOpInOneOpOutTransformOpInterface: target_names={target_names}")
|
||||
results.set_ops(op.res, target_ops)
|
||||
return DiagnosedSilenceableFailure.Success
|
||||
|
||||
@staticmethod
|
||||
def allow_repeated_handle_operands(_op: OneOpInOneOpOut) -> bool:
|
||||
return False
|
||||
|
||||
# Attach the interface implementation to the op.
|
||||
TransformOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
|
||||
|
||||
# TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface. The above defined fallback model works for this op.
|
||||
MemoryEffectsOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
|
||||
|
||||
with schedule_boilerplate() as (schedule, named_seq):
|
||||
func_handle = structured.MatchOp.match_op_names(
|
||||
named_seq.bodyTarget, ["func.func"]
|
||||
).result
|
||||
# CHECK: OneOpInOneOpOutTransformOpInterface: target_names=['name_of_func']
|
||||
out = OneOpInOneOpOut(func_handle).result
|
||||
# CHECK: Output handle from OneOpInOneOpOut
|
||||
# CHECK-NEXT: func.func @name_of_func
|
||||
transform.PrintOp(target=out, name="Output handle from OneOpInOneOpOut")
|
||||
transform.YieldOp([out])
|
||||
|
||||
return schedule
|
||||
|
||||
|
||||
# CHECK-LABEL: Test: OneOpInOneOpOutTransformOpInterfaceRewriterImpl
|
||||
@run
|
||||
def OneOpInOneOpOutTransformOpInterfaceRewriterImpl():
|
||||
"""Tests an interface implementation using the rewriter to modify the IR.
|
||||
|
||||
Checks that `arith.constant` ops are replaced by `index.constant` ops and
|
||||
that the results are correctly updated.
|
||||
"""
|
||||
|
||||
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
|
||||
@staticmethod
|
||||
def apply(
|
||||
op: OneOpInOneOpOut,
|
||||
rewriter: transform.TransformRewriter,
|
||||
results: transform.TransformResults,
|
||||
state: transform.TransformState,
|
||||
) -> DiagnosedSilenceableFailure:
|
||||
result_ops = []
|
||||
for target_op in state.get_payload_ops(op.target):
|
||||
with ir.InsertionPoint(target_op):
|
||||
index_version = index.constant(target_op.value.value)
|
||||
result_ops.append(index_version.owner)
|
||||
rewriter.replace_op(target_op, [index_version])
|
||||
results.set_ops(op.res, result_ops)
|
||||
return DiagnosedSilenceableFailure.Success
|
||||
|
||||
@staticmethod
|
||||
def allow_repeated_handle_operands(_op: OneOpInOneOpOut) -> bool:
|
||||
return False
|
||||
|
||||
# Attach the interface implementation to the op.
|
||||
TransformOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
|
||||
|
||||
# TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface. The above defined fallback model works for this op.
|
||||
class MemoryEffectsOpInterfaceFallbackModel(ir.MemoryEffectsOpInterface):
|
||||
@staticmethod
|
||||
def get_effects(op: ir.Operation, effects):
|
||||
transform.consumes_handle(op.op_operands, effects)
|
||||
transform.produces_handle(op.results, effects)
|
||||
transform.modifies_payload(effects)
|
||||
|
||||
MemoryEffectsOpInterfaceFallbackModel.attach(OneOpInOneOpOut.OPERATION_NAME)
|
||||
|
||||
with schedule_boilerplate() as (schedule, named_seq):
|
||||
func_handle = structured.MatchOp.match_op_names(
|
||||
named_seq.bodyTarget, ["func.func"]
|
||||
).result
|
||||
csts_handle = structured.MatchOp.match_op_names(
|
||||
named_seq.bodyTarget, ["arith.constant"]
|
||||
).result
|
||||
# CHECK: Before replacement:
|
||||
# CHECK-NOT: index.constant
|
||||
# CHECK-DAG: arith.constant 42 : i32
|
||||
# CHECK-DAG: arith.constant 24 : i32
|
||||
transform.PrintOp(target=func_handle, name="Before replacement:")
|
||||
out = OneOpInOneOpOut(csts_handle).result
|
||||
# CHECK: After replacement:
|
||||
# CHECK-NOT: arith.constant
|
||||
# CHECK-DAG: index.constant 42
|
||||
# CHECK-DAG: index.constant 24
|
||||
transform.PrintOp(target=func_handle, name="After replacement:")
|
||||
# CHECK: Output handle from OneOpInOneOpOut:
|
||||
# CHECK-NEXT: index.constant 42
|
||||
# CHECK-NEXT: index.constant 24
|
||||
transform.PrintOp(target=out, name="Output handle from OneOpInOneOpOut:")
|
||||
transform.YieldOp([out])
|
||||
|
||||
return schedule
|
||||
|
||||
|
||||
@ext.register_operation(MyTransform)
|
||||
class OpValParamInParamOpValOut(
|
||||
MyTransform.Operation, name="op_val_param_in_param_op_val_out"
|
||||
):
|
||||
# operands
|
||||
op_arg: ext.Operand[transform.AnyOpType]
|
||||
val_arg: ext.Operand[transform.AnyValueType]
|
||||
param_arg: ext.Operand[transform.AnyParamType]
|
||||
# results
|
||||
param_res: ext.Result[transform.AnyParamType[()]]
|
||||
op_res: ext.Result[transform.AnyOpType[()]]
|
||||
value_res: ext.Result[transform.AnyValueType[()]]
|
||||
|
||||
|
||||
# CHECK-LABEL: Test: OpValParamInParamOpValOutTransformOpInterface
|
||||
@run
|
||||
def OpValParamInParamOpValOutTransformOpInterface():
|
||||
"""Tests an interface implementation involving Op, Value, and Param types.
|
||||
|
||||
Checks that payload ops, values, and parameters are correctly permuted and
|
||||
propagated and accessible from the (permuted) result handles.
|
||||
"""
|
||||
|
||||
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
|
||||
@staticmethod
|
||||
def apply(
|
||||
op: OpValParamInParamOpValOut,
|
||||
_rewriter: transform.TransformRewriter,
|
||||
results: transform.TransformResults,
|
||||
state: transform.TransformState,
|
||||
) -> DiagnosedSilenceableFailure:
|
||||
ops = state.get_payload_ops(op.op_arg)
|
||||
values = state.get_payload_values(op.val_arg)
|
||||
params = state.get_params(op.param_arg)
|
||||
print(
|
||||
f"OpValParamInParamOpValOutTransformOpInterface: ops={len(ops)}, values={len(values)}, params={len(params)}"
|
||||
)
|
||||
results.set_params(op.param_res, params)
|
||||
results.set_ops(op.op_res, ops)
|
||||
results.set_values(op.value_res, values)
|
||||
return DiagnosedSilenceableFailure.Success
|
||||
|
||||
@staticmethod
|
||||
def allow_repeated_handle_operands(_op: OpValParamInParamOpValOut) -> bool:
|
||||
return False
|
||||
|
||||
TransformOpInterfaceFallbackModel.attach(OpValParamInParamOpValOut.OPERATION_NAME)
|
||||
|
||||
# TransformOpInterface-implementing ops are also required to implement MemoryEffectsOpInterface. The above defined fallback model works for this op.
|
||||
MemoryEffectsOpInterfaceFallbackModel.attach(
|
||||
OpValParamInParamOpValOut.OPERATION_NAME
|
||||
)
|
||||
|
||||
with schedule_boilerplate() as (schedule, named_seq):
|
||||
func_handle = structured.MatchOp.match_op_names(
|
||||
named_seq.bodyTarget, ["func.func"]
|
||||
).result
|
||||
addf_handle = structured.MatchOp.match_op_names(
|
||||
named_seq.bodyTarget, ["arith.addf"]
|
||||
).result
|
||||
func_and_addf = transform.MergeHandlesOp([func_handle, addf_handle])
|
||||
value_handle = transform.GetResultOp(
|
||||
AnyValueType.get(), addf_handle, [0]
|
||||
).result
|
||||
param_handle = transform.ParamConstantOp(
|
||||
AnyParamType.get(), ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
|
||||
).param
|
||||
|
||||
# CHECK: OpValParamInParamOpValOutTransformOpInterface: ops=2, values=1, params=1
|
||||
op_val_param_op = OpValParamInParamOpValOut(
|
||||
func_and_addf, value_handle, param_handle
|
||||
)
|
||||
# CHECK: Ops passed through OpValParamInParamOpValOut:
|
||||
# CHECK-NEXT: func.func
|
||||
# CHECK: arith.addf
|
||||
transform.PrintOp(
|
||||
target=op_val_param_op.op_res,
|
||||
name="Ops passed through OpValParamInParamOpValOut:",
|
||||
)
|
||||
|
||||
# CHECK: Ops defining values passed through OpValParamInParamOpValOut:
|
||||
# CHECK-NEXT: arith.addf
|
||||
addf_as_res = transform.GetDefiningOp(
|
||||
transform.AnyOpType.get(), op_val_param_op.value_res
|
||||
).result
|
||||
transform.PrintOp(
|
||||
target=addf_as_res,
|
||||
name="Ops defining values passed through OpValParamInParamOpValOut:",
|
||||
)
|
||||
|
||||
# CHECK: Parameter passed through OpValParamInParamOpValOut:
|
||||
# CHECK-NEXT: 42 : i32
|
||||
PrintParamOp(
|
||||
op_val_param_op.param_res,
|
||||
name=ir.StringAttr.get(
|
||||
"Parameter passed through OpValParamInParamOpValOut:"
|
||||
),
|
||||
)
|
||||
|
||||
transform.YieldOp([op_val_param_op.op_res])
|
||||
named_seq.verify()
|
||||
|
||||
return schedule
|
||||
|
||||
|
||||
@ext.register_operation(MyTransform)
|
||||
class OpsParamsInValuesParamOut(
|
||||
MyTransform.Operation, name="ops_params_in_values_param_out"
|
||||
):
|
||||
# operands
|
||||
ops: Sequence[ext.Operand[transform.AnyOpType]]
|
||||
params: Sequence[ext.Operand[transform.AnyParamType]]
|
||||
# results
|
||||
values: Sequence[ext.Result[transform.AnyValueType]]
|
||||
param: ext.Result[transform.AnyParamType]
|
||||
|
||||
|
||||
# CHECK-LABEL: Test: OpsParamsInValuesParamOutTransformOpInterface
|
||||
@run
|
||||
def OpsParamsInValuesParamOutTransformOpInterface():
|
||||
"""Tests an interface with variadic Op and Param operands and variadic Value results.
|
||||
|
||||
Checks correct handling of multiple handles, parameter aggregation, and
|
||||
result generation.
|
||||
"""
|
||||
|
||||
class TransformOpInterfaceFallbackModel(transform.TransformOpInterface):
|
||||
@staticmethod
|
||||
def apply(
|
||||
op: OpsParamsInValuesParamOut,
|
||||
_rewriter: transform.TransformRewriter,
|
||||
results: transform.TransformResults,
|
||||
state: transform.TransformState,
|
||||
) -> DiagnosedSilenceableFailure:
|
||||
ops_count = 0
|
||||
value_handles = []
|
||||
for op_handle in op.ops:
|
||||
ops = state.get_payload_ops(op_handle)
|
||||
ops_count += len(ops)
|
||||
value_handles.append([i for op in ops for i in op.results])
|
||||
|
||||
param_count = 0
|
||||
param_sum = 0
|
||||
for param_handle in op.params:
|
||||
params = state.get_params(param_handle)
|
||||
param_count += len(params)
|
||||
param_sum += sum(p.value for p in params)
|
||||
|
||||
print(
|
||||
f"OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel: op_count={ops_count}, param_count={param_count}"
|
||||
)
|
||||
|
||||
assert len(op.values) == len(op.ops)
|
||||
for value_res_handle, value_vector in zip(op.values, value_handles):
|
||||
results.set_values(value_res_handle, value_vector)
|
||||
results.set_params(
|
||||
op.param,
|
||||
[ir.IntegerAttr.get(ir.IntegerType.get_signless(32), param_sum)],
|
||||
)
|
||||
return DiagnosedSilenceableFailure.Success
|
||||
|
||||
@staticmethod
|
||||
def allow_repeated_handle_operands(_op: OpsParamsInValuesParamOut) -> bool:
|
||||
return False
|
||||
|
||||
TransformOpInterfaceFallbackModel.attach(OpsParamsInValuesParamOut.OPERATION_NAME)
|
||||
|
||||
MemoryEffectsOpInterfaceFallbackModel.attach(
|
||||
OpsParamsInValuesParamOut.OPERATION_NAME
|
||||
)
|
||||
|
||||
with schedule_boilerplate() as (schedule, named_seq):
|
||||
func_handle = structured.MatchOp.match_op_names(
|
||||
named_seq.bodyTarget, ["func.func"]
|
||||
).result
|
||||
csts_handle = structured.MatchOp.match_op_names(
|
||||
named_seq.bodyTarget, ["arith.constant"]
|
||||
).result
|
||||
csts_as_param = GetNamedAttributeOp(
|
||||
csts_handle, attr_name=ir.StringAttr.get("value")
|
||||
).attr_as_param
|
||||
|
||||
param_handle = transform.ParamConstantOp(
|
||||
AnyParamType.get(), ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 123)
|
||||
).param
|
||||
|
||||
# CHECK: OpsParamsInValuesParamOutTransformOpInterfaceFallbackModel: op_count=3, param_count=3
|
||||
op = OpsParamsInValuesParamOut(
|
||||
[transform.AnyValueType.get()] * 2,
|
||||
transform.AnyParamType.get(),
|
||||
[func_handle, csts_handle],
|
||||
[csts_as_param, param_handle],
|
||||
)
|
||||
|
||||
empty_handle = transform.GetDefiningOp(transform.AnyOpType.get(), op.values[0])
|
||||
# CHECK: Defining op of value result 0
|
||||
transform.PrintOp(
|
||||
target=empty_handle.result, name="Defining op of value result 0"
|
||||
)
|
||||
# NB: no result on the func.func, so output is expected to be empty
|
||||
cst1_res, cst2_res = transform.SplitHandleOp(
|
||||
[transform.AnyValueType.get()] * 2, op.values[1]
|
||||
).results
|
||||
|
||||
cst1_again = transform.GetDefiningOp(transform.AnyOpType.get(), cst1_res)
|
||||
# CHECK-NEXT: Defining op of first constant
|
||||
# CHECK-NEXT: arith.constant 42 : i32
|
||||
transform.PrintOp(
|
||||
target=cst1_again.result, name="Defining op of first constant"
|
||||
)
|
||||
cst2_again = transform.GetDefiningOp(transform.AnyOpType.get(), cst2_res)
|
||||
# CHECK-NEXT: Defining op of second constant
|
||||
# CHECK-NEXT: arith.constant 24 : i32
|
||||
transform.PrintOp(
|
||||
target=cst2_again.result, name="Defining op of second constant"
|
||||
)
|
||||
|
||||
# CHECK: Sum of params:
|
||||
# CHECK-NEXT: 189 : i32
|
||||
PrintParamOp(op.param, name=ir.StringAttr.get("Sum of params:"))
|
||||
|
||||
transform.YieldOp([func_handle])
|
||||
named_seq.verify()
|
||||
|
||||
return schedule
|
||||
Reference in New Issue
Block a user