Using `Sequence` frees users from having to convert to `list` when the underlying API does not care about the concrete container type. Accepting an `nb::sequence` is marginally slower than accepting an `nb::list` directly, because calls such as `__getitem__` and `__len__` must go through an extra layer of indirection. However, I expect the performance difference to be negligible.
578 lines
22 KiB
C++
578 lines
22 KiB
C++
//===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include <string>
|
|
|
|
#include "Rewrite.h"
|
|
#include "mlir-c/Dialect/Transform.h"
|
|
#include "mlir-c/IR.h"
|
|
#include "mlir-c/Rewrite.h"
|
|
#include "mlir-c/Support.h"
|
|
#include "mlir/Bindings/Python/IRCore.h"
|
|
#include "mlir/Bindings/Python/IRInterfaces.h"
|
|
#include "nanobind/nanobind.h"
|
|
#include <nanobind/trampoline.h>
|
|
|
|
namespace nb = nanobind;
|
|
using namespace mlir::python::nanobind_adaptors;
|
|
|
|
namespace mlir {
|
|
namespace python {
|
|
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
|
|
namespace transform {
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TransformRewriter
|
|
//===----------------------------------------------------------------------===//
|
|
/// Python-facing wrapper around an `MlirTransformRewriter`.
///
/// The transform-specific rewriter is converted to the generic rewriter handle
/// with `mlirTransformRewriterAsBase` and stored in the `PyRewriterBase` CRTP
/// base, which supplies the rest of the Python API for this class.
class PyTransformRewriter : public PyRewriterBase<PyTransformRewriter> {
public:
  /// Name under which this class is registered in the Python module.
  static constexpr const char *pyClassName = "TransformRewriter";

  /// Wraps `transformRewriter` via its generic rewriter-base view.
  PyTransformRewriter(MlirTransformRewriter transformRewriter)
      : PyRewriterBase<PyTransformRewriter>(
            mlirTransformRewriterAsBase(transformRewriter)) {}
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TransformResults
|
|
//===----------------------------------------------------------------------===//
|
|
class PyTransformResults {
|
|
public:
|
|
PyTransformResults(MlirTransformResults results) : results(results) {}
|
|
|
|
MlirTransformResults get() const { return results; }
|
|
|
|
void setOps(PyValue &result,
|
|
const nb::typed<nb::sequence, PyOperationBase> &ops) {
|
|
std::vector<MlirOperation> opsVec;
|
|
opsVec.reserve(nb::len(ops));
|
|
for (auto op : ops) {
|
|
opsVec.push_back(nb::cast<MlirOperation>(op));
|
|
}
|
|
mlirTransformResultsSetOps(results, result, opsVec.size(), opsVec.data());
|
|
}
|
|
|
|
void setValues(PyValue &result,
|
|
const nb::typed<nb::sequence, PyValue> &values) {
|
|
std::vector<MlirValue> valuesVec;
|
|
valuesVec.reserve(nb::len(values));
|
|
for (auto item : values) {
|
|
valuesVec.push_back(nb::cast<MlirValue>(item));
|
|
}
|
|
mlirTransformResultsSetValues(results, result, valuesVec.size(),
|
|
valuesVec.data());
|
|
}
|
|
|
|
void setParams(PyValue &result,
|
|
const nb::typed<nb::sequence, PyAttribute> ¶ms) {
|
|
std::vector<MlirAttribute> paramsVec;
|
|
paramsVec.reserve(nb::len(params));
|
|
for (auto item : params) {
|
|
paramsVec.push_back(nb::cast<MlirAttribute>(item));
|
|
}
|
|
mlirTransformResultsSetParams(results, result, paramsVec.size(),
|
|
paramsVec.data());
|
|
}
|
|
|
|
static void bind(nanobind::module_ &m) {
|
|
nb::class_<PyTransformResults>(m, "TransformResults")
|
|
.def(nb::init<MlirTransformResults>())
|
|
.def("set_ops", &PyTransformResults::setOps,
|
|
"Set the payload operations for a transform result.",
|
|
nb::arg("result"), nb::arg("ops"))
|
|
.def("set_values", &PyTransformResults::setValues,
|
|
"Set the payload values for a transform result.",
|
|
nb::arg("result"), nb::arg("values"))
|
|
.def("set_params", &PyTransformResults::setParams,
|
|
"Set the parameters for a transform result.", nb::arg("result"),
|
|
nb::arg("params"));
|
|
}
|
|
|
|
private:
|
|
MlirTransformResults results;
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TransformState
|
|
//===----------------------------------------------------------------------===//
|
|
class PyTransformState {
|
|
public:
|
|
PyTransformState(MlirTransformState state) : state(state) {}
|
|
|
|
MlirTransformState get() const { return state; }
|
|
|
|
static void bind(nanobind::module_ &m) {
|
|
nb::class_<PyTransformState>(m, "TransformState")
|
|
.def(nb::init<MlirTransformState>())
|
|
.def("get_payload_ops", &PyTransformState::getPayloadOps,
|
|
"Get the payload operations associated with a transform IR value.",
|
|
nb::arg("operand"))
|
|
.def("get_payload_values", &PyTransformState::getPayloadValues,
|
|
"Get the payload values associated with a transform IR value.",
|
|
nb::arg("operand"))
|
|
.def("get_params", &PyTransformState::getParams,
|
|
"Get the parameters (attributes) associated with a transform IR "
|
|
"value.",
|
|
nb::arg("operand"));
|
|
}
|
|
|
|
private:
|
|
nanobind::list getPayloadOps(PyValue &value) {
|
|
nanobind::list result;
|
|
mlirTransformStateForEachPayloadOp(
|
|
state, value,
|
|
[](MlirOperation op, void *userData) {
|
|
PyMlirContextRef context =
|
|
PyMlirContext::forContext(mlirOperationGetContext(op));
|
|
auto opview = PyOperation::forOperation(context, op)->createOpView();
|
|
static_cast<nanobind::list *>(userData)->append(opview);
|
|
},
|
|
&result);
|
|
return result;
|
|
}
|
|
|
|
nanobind::list getPayloadValues(PyValue &value) {
|
|
nanobind::list result;
|
|
mlirTransformStateForEachPayloadValue(
|
|
state, value,
|
|
[](MlirValue val, void *userData) {
|
|
static_cast<nanobind::list *>(userData)->append(val);
|
|
},
|
|
&result);
|
|
return result;
|
|
}
|
|
|
|
nanobind::list getParams(PyValue &value) {
|
|
nanobind::list result;
|
|
mlirTransformStateForEachParam(
|
|
state, value,
|
|
[](MlirAttribute attr, void *userData) {
|
|
static_cast<nanobind::list *>(userData)->append(attr);
|
|
},
|
|
&result);
|
|
return result;
|
|
}
|
|
|
|
MlirTransformState state;
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TransformOpInterface
|
|
//===----------------------------------------------------------------------===//
|
|
class PyTransformOpInterface
|
|
: public PyConcreteOpInterface<PyTransformOpInterface> {
|
|
public:
|
|
using PyConcreteOpInterface<PyTransformOpInterface>::PyConcreteOpInterface;
|
|
|
|
constexpr static const char *pyClassName = "TransformOpInterface";
|
|
constexpr static GetTypeIDFunctionTy getInterfaceID =
|
|
&mlirTransformOpInterfaceTypeID;
|
|
|
|
/// Attach a new TransformOpInterface FallbackModel to the named operation.
|
|
/// The FallbackModel acts as a trampoline for callbacks on the Python class.
|
|
static void attach(nb::object &target, const std::string &opName,
|
|
DefaultingPyMlirContext ctx) {
|
|
// Prepare the callbacks that will be used by the FallbackModel.
|
|
MlirTransformOpInterfaceCallbacks callbacks;
|
|
// Make the pointer to the Python class available to the callbacks.
|
|
callbacks.userData = target.ptr();
|
|
nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
|
|
|
|
// The above ref bump is all we need as initialization, no need to run the
|
|
// construct callback.
|
|
callbacks.construct = nullptr;
|
|
// Upon the FallbackModel's destruction, drop the ref to the Python class.
|
|
callbacks.destruct = [](void *userData) {
|
|
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
|
|
};
|
|
// The apply callback which calls into Python.
|
|
callbacks.apply = [](MlirOperation op, MlirTransformRewriter rewriter,
|
|
MlirTransformResults results, MlirTransformState state,
|
|
void *userData) -> MlirDiagnosedSilenceableFailure {
|
|
nb::handle pyClass(static_cast<PyObject *>(userData));
|
|
|
|
auto pyApply = nb::cast<nb::callable>(nb::getattr(pyClass, "apply"));
|
|
|
|
auto pyRewriter = PyTransformRewriter(rewriter);
|
|
auto pyResults = PyTransformResults(results);
|
|
auto pyState = PyTransformState(state);
|
|
|
|
// Invoke `pyClass.apply(opview(op), rewriter, results, state)` as a
|
|
// staticmethod.
|
|
PyMlirContextRef context =
|
|
PyMlirContext::forContext(mlirOperationGetContext(op));
|
|
auto opview = PyOperation::forOperation(context, op)->createOpView();
|
|
nb::object res = pyApply(opview, pyRewriter, pyResults, pyState);
|
|
|
|
return nb::cast<MlirDiagnosedSilenceableFailure>(res);
|
|
};
|
|
|
|
// The allows_repeated_handle_operands callback which calls into Python.
|
|
callbacks.allowsRepeatedHandleOperands = [](MlirOperation op,
|
|
void *userData) -> bool {
|
|
nb::handle pyClass(static_cast<PyObject *>(userData));
|
|
|
|
auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
|
|
nb::getattr(pyClass, "allow_repeated_handle_operands"));
|
|
|
|
// Invoke `pyClass.allow_repeated_handle_operands(opview(op))` as a
|
|
// staticmethod.
|
|
PyMlirContextRef context =
|
|
PyMlirContext::forContext(mlirOperationGetContext(op));
|
|
auto opview = PyOperation::forOperation(context, op)->createOpView();
|
|
nb::object res = pyAllowRepeatedHandleOperands(opview);
|
|
|
|
return nb::cast<bool>(res);
|
|
};
|
|
|
|
// Attach a FallbackModel, which calls into Python, to the named operation.
|
|
mlirTransformOpInterfaceAttachFallbackModel(
|
|
ctx->get(), mlirStringRefCreate(opName.c_str(), opName.size()),
|
|
callbacks);
|
|
}
|
|
|
|
static void bindDerived(ClassTy &cls) {
|
|
cls.attr("attach") = classmethod(
|
|
[](const nb::object &cls, const nb::object &opName, nb::object target,
|
|
DefaultingPyMlirContext context) {
|
|
if (target.is_none())
|
|
target = cls;
|
|
return attach(target, nb::cast<std::string>(opName), context);
|
|
},
|
|
nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
|
|
nb::arg("target").none() = nb::none(),
|
|
nb::arg("context").none() = nb::none(),
|
|
"Attach the interface subclass to the given operation name.");
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// PatternDescriptorOpInterface
|
|
//===----------------------------------------------------------------------===//
|
|
class PyPatternDescriptorOpInterface
|
|
: public PyConcreteOpInterface<PyPatternDescriptorOpInterface> {
|
|
public:
|
|
using PyConcreteOpInterface<
|
|
PyPatternDescriptorOpInterface>::PyConcreteOpInterface;
|
|
|
|
constexpr static const char *pyClassName = "PatternDescriptorOpInterface";
|
|
constexpr static GetTypeIDFunctionTy getInterfaceID =
|
|
&mlirPatternDescriptorOpInterfaceTypeID;
|
|
|
|
/// Attach a new PatternDescriptorOpInterface FallbackModel to the named
|
|
/// operation. The FallbackModel acts as a trampoline for callbacks on the
|
|
/// Python class.
|
|
static void attach(nb::object &target, const std::string &opName,
|
|
DefaultingPyMlirContext ctx) {
|
|
// Prepare the callbacks that will be used by the FallbackModel.
|
|
MlirPatternDescriptorOpInterfaceCallbacks callbacks;
|
|
// Make the pointer to the Python class available to the callbacks.
|
|
callbacks.userData = target.ptr();
|
|
nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
|
|
|
|
// The above ref bump is all we need as initialization, no need to run the
|
|
// construct callback.
|
|
callbacks.construct = nullptr;
|
|
// Upon the FallbackModel's destruction, drop the ref to the Python class.
|
|
callbacks.destruct = [](void *userData) {
|
|
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
|
|
};
|
|
|
|
// The populatePatterns callback which calls into Python.
|
|
callbacks.populatePatterns =
|
|
[](MlirOperation op, MlirRewritePatternSet patterns, void *userData) {
|
|
nb::handle pyClass(static_cast<PyObject *>(userData));
|
|
|
|
auto pyPopulatePatterns =
|
|
nb::cast<nb::callable>(nb::getattr(pyClass, "populate_patterns"));
|
|
|
|
auto pyPatterns = PyRewritePatternSet(patterns);
|
|
|
|
// Invoke `pyClass.populate_patterns(opview(op), patterns)` as a
|
|
// staticmethod.
|
|
MlirContext ctx = mlirOperationGetContext(op);
|
|
PyMlirContextRef context = PyMlirContext::forContext(ctx);
|
|
auto opview = PyOperation::forOperation(context, op)->createOpView();
|
|
pyPopulatePatterns(opview, pyPatterns);
|
|
};
|
|
|
|
// The populatePatternsWithState callback which calls into Python.
|
|
// Check if the Python class has populate_patterns_with_state method.
|
|
if (nb::hasattr(target, "populate_patterns_with_state")) {
|
|
callbacks.populatePatternsWithState = [](MlirOperation op,
|
|
MlirRewritePatternSet patterns,
|
|
MlirTransformState state,
|
|
void *userData) {
|
|
nb::handle pyClass(static_cast<PyObject *>(userData));
|
|
|
|
auto pyPopulatePatternsWithState = nb::cast<nb::callable>(
|
|
nb::getattr(pyClass, "populate_patterns_with_state"));
|
|
|
|
auto pyPatterns = PyRewritePatternSet(patterns);
|
|
auto pyState = PyTransformState(state);
|
|
|
|
// Invoke `pyClass.populate_patterns_with_state(opview(op), patterns,
|
|
// state)` as a staticmethod.
|
|
MlirContext ctx = mlirOperationGetContext(op);
|
|
PyMlirContextRef context = PyMlirContext::forContext(ctx);
|
|
auto opview = PyOperation::forOperation(context, op)->createOpView();
|
|
pyPopulatePatternsWithState(opview, pyPatterns, pyState);
|
|
};
|
|
} else {
|
|
// Use default implementation (will call populatePatterns).
|
|
callbacks.populatePatternsWithState = nullptr;
|
|
}
|
|
|
|
// Attach a FallbackModel, which calls into Python, to the named operation.
|
|
mlirPatternDescriptorOpInterfaceAttachFallbackModel(
|
|
ctx->get(), mlirStringRefCreate(opName.c_str(), opName.size()),
|
|
callbacks);
|
|
}
|
|
|
|
static void bindDerived(ClassTy &cls) {
|
|
cls.attr("attach") = classmethod(
|
|
[](const nb::object &cls, const nb::object &opName, nb::object target,
|
|
DefaultingPyMlirContext context) {
|
|
if (target.is_none())
|
|
target = cls;
|
|
return attach(target, nb::cast<std::string>(opName), context);
|
|
},
|
|
nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
|
|
nb::arg("target").none() = nb::none(),
|
|
nb::arg("context").none() = nb::none(),
|
|
"Attach the interface subclass to the given operation name.");
|
|
}
|
|
};
|
|
|
|
//===-------------------------------------------------------------------===//
|
|
// AnyOpType
|
|
//===-------------------------------------------------------------------===//
|
|
|
|
/// Python binding for the transform dialect's AnyOpType.
struct AnyOpType : PyConcreteType<AnyOpType> {
  using Base::Base;

  static constexpr const char *pyClassName = "AnyOpType";
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyOpType;
  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
      mlirTransformAnyOpTypeGetTypeID;
  static inline const MlirStringRef name = mlirTransformAnyOpTypeGetName();

  /// Adds the static `get(context=None)` constructor.
  static void bindDerived(ClassTy &c) {
    auto getter = [](DefaultingPyMlirContext context) {
      MlirType anyOp = mlirTransformAnyOpTypeGet(context.get()->get());
      return AnyOpType(context->getRef(), anyOp);
    };
    c.def_static("get", getter,
                 "Get an instance of AnyOpType in the given context.",
                 nb::arg("context").none() = nb::none());
  }
};
|
|
|
|
//===-------------------------------------------------------------------===//
|
|
// AnyParamType
|
|
//===-------------------------------------------------------------------===//
|
|
|
|
/// Python binding for the transform dialect's AnyParamType.
struct AnyParamType : PyConcreteType<AnyParamType> {
  using Base::Base;

  static constexpr const char *pyClassName = "AnyParamType";
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyParamType;
  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
      mlirTransformAnyParamTypeGetTypeID;
  static inline const MlirStringRef name = mlirTransformAnyParamTypeGetName();

  /// Adds the static `get(context=None)` constructor.
  static void bindDerived(ClassTy &c) {
    auto getter = [](DefaultingPyMlirContext context) {
      MlirType anyParam = mlirTransformAnyParamTypeGet(context.get()->get());
      return AnyParamType(context->getRef(), anyParam);
    };
    c.def_static("get", getter,
                 "Get an instance of AnyParamType in the given context.",
                 nb::arg("context").none() = nb::none());
  }
};
|
|
|
|
//===-------------------------------------------------------------------===//
|
|
// AnyValueType
|
|
//===-------------------------------------------------------------------===//
|
|
|
|
/// Python binding for the transform dialect's AnyValueType.
struct AnyValueType : PyConcreteType<AnyValueType> {
  using Base::Base;

  static constexpr const char *pyClassName = "AnyValueType";
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyValueType;
  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
      mlirTransformAnyValueTypeGetTypeID;
  static inline const MlirStringRef name = mlirTransformAnyValueTypeGetName();

  /// Adds the static `get(context=None)` constructor.
  static void bindDerived(ClassTy &c) {
    auto getter = [](DefaultingPyMlirContext context) {
      MlirType anyValue = mlirTransformAnyValueTypeGet(context.get()->get());
      return AnyValueType(context->getRef(), anyValue);
    };
    c.def_static("get", getter,
                 "Get an instance of AnyValueType in the given context.",
                 nb::arg("context").none() = nb::none());
  }
};
|
|
|
|
//===-------------------------------------------------------------------===//
|
|
// OperationType
|
|
//===-------------------------------------------------------------------===//
|
|
|
|
/// Python binding for the transform dialect's OperationType. It is
/// parameterized by the name of the payload operation that its handles accept.
struct OperationType : PyConcreteType<OperationType> {
  using Base::Base;

  static constexpr const char *pyClassName = "OperationType";
  static constexpr IsAFunctionTy isaFunction =
      mlirTypeIsATransformOperationType;
  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
      mlirTransformOperationTypeGetTypeID;
  static inline const MlirStringRef name = mlirTransformOperationTypeGetName();

  /// Adds the static `get(operation_name, context=None)` constructor and the
  /// read-only `operation_name` property.
  static void bindDerived(ClassTy &c) {
    auto getter = [](const std::string &operationName,
                     DefaultingPyMlirContext context) {
      MlirStringRef nameRef =
          mlirStringRefCreate(operationName.data(), operationName.size());
      MlirType handleType =
          mlirTransformOperationTypeGet(context.get()->get(), nameRef);
      return OperationType(context->getRef(), handleType);
    };
    c.def_static("get", getter,
                 "Get an instance of OperationType for the given kind in the "
                 "given context",
                 nb::arg("operation_name"),
                 nb::arg("context").none() = nb::none());

    auto operationNameOf = [](const OperationType &type) {
      MlirStringRef payloadName =
          mlirTransformOperationTypeGetOperationName(type);
      return nb::str(payloadName.data, payloadName.length);
    };
    c.def_prop_ro(
        "operation_name", operationNameOf,
        "Get the name of the payload operation accepted by the handle.");
  }
};
|
|
|
|
//===-------------------------------------------------------------------===//
|
|
// ParamType
|
|
//===-------------------------------------------------------------------===//
|
|
|
|
/// Python binding for the transform dialect's ParamType. It is parameterized
/// by the type of the parameters it carries.
struct ParamType : PyConcreteType<ParamType> {
  using Base::Base;

  static constexpr const char *pyClassName = "ParamType";
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformParamType;
  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
      mlirTransformParamTypeGetTypeID;
  static inline const MlirStringRef name = mlirTransformParamTypeGetName();

  /// Adds the static `get(type, context=None)` constructor and the read-only
  /// `type` property.
  static void bindDerived(ClassTy &c) {
    auto getter = [](const PyType &type, DefaultingPyMlirContext context) {
      MlirType paramType = mlirTransformParamTypeGet(context.get()->get(), type);
      return ParamType(context->getRef(), paramType);
    };
    c.def_static(
        "get", getter,
        "Get an instance of ParamType for the given type in the given context.",
        nb::arg("type"), nb::arg("context").none() = nb::none());

    // Takes the ParamType by value because `getContext()` is called on it.
    auto associatedType = [](ParamType paramType) {
      MlirType carried = mlirTransformParamTypeGetType(paramType);
      PyType wrapped(paramType.getContext(), carried);
      return wrapped.maybeDownCast();
    };
    c.def_prop_ro("type", associatedType,
                  "Get the type this ParamType is associated with.");
  }
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MemoryEffectsOpInterface helpers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
void onlyReadsHandle(nb::iterable &operands,
|
|
PyMemoryEffectsInstanceList effects) {
|
|
std::vector<MlirOpOperand> operandsVec;
|
|
for (auto operand : operands)
|
|
operandsVec.push_back(nb::cast<PyOpOperand>(operand));
|
|
mlirTransformOnlyReadsHandle(operandsVec.data(), operandsVec.size(),
|
|
effects.effects);
|
|
};
|
|
|
|
void consumesHandle(nb::iterable &operands,
|
|
PyMemoryEffectsInstanceList effects) {
|
|
std::vector<MlirOpOperand> operandsVec;
|
|
for (auto operand : operands)
|
|
operandsVec.push_back(nb::cast<PyOpOperand>(operand));
|
|
mlirTransformConsumesHandle(operandsVec.data(), operandsVec.size(),
|
|
effects.effects);
|
|
};
|
|
|
|
void producesHandle(nb::iterable &results,
|
|
PyMemoryEffectsInstanceList effects) {
|
|
std::vector<MlirValue> resultsVec;
|
|
for (auto result : results)
|
|
resultsVec.push_back(nb::cast<PyOpResult>(result).get());
|
|
mlirTransformProducesHandle(resultsVec.data(), resultsVec.size(),
|
|
effects.effects);
|
|
};
|
|
|
|
void modifiesPayload(PyMemoryEffectsInstanceList effects) {
|
|
mlirTransformModifiesPayload(effects.effects);
|
|
}
|
|
|
|
void onlyReadsPayload(PyMemoryEffectsInstanceList effects) {
|
|
mlirTransformOnlyReadsPayload(effects.effects);
|
|
}
|
|
} // namespace
|
|
|
|
/// Registers the transform dialect's enum, types, wrapper classes, interfaces
/// and memory-effect helper functions on the Python module `m`.
static void populateDialectTransformSubmodule(nb::module_ &m) {
  // Result kind returned by a Python transform op's `apply`.
  nb::enum_<MlirDiagnosedSilenceableFailure> failureEnum(
      m, "DiagnosedSilenceableFailure");
  failureEnum.value("Success", MlirDiagnosedSilenceableFailureSuccess);
  failureEnum.value("SilenceableFailure",
                    MlirDiagnosedSilenceableFailureSilenceableFailure);
  failureEnum.value("DefiniteFailure",
                    MlirDiagnosedSilenceableFailureDefiniteFailure);

  // Transform IR types.
  AnyOpType::bind(m);
  AnyParamType::bind(m);
  AnyValueType::bind(m);
  OperationType::bind(m);
  ParamType::bind(m);

  // Objects passed to Python-implemented transform ops, and the interfaces
  // those ops can be attached through.
  PyTransformRewriter::bind(m);
  PyTransformResults::bind(m);
  PyTransformState::bind(m);
  PyTransformOpInterface::bind(m);
  PyPatternDescriptorOpInterface::bind(m);

  // MemoryEffectsOpInterface helpers.
  m.def("only_reads_handle", onlyReadsHandle,
        "Mark operands as only reading handles.", nb::arg("operands"),
        nb::arg("effects"));
  m.def("consumes_handle", consumesHandle,
        "Mark operands as consuming handles.", nb::arg("operands"),
        nb::arg("effects"));
  m.def("produces_handle", producesHandle,
        "Mark results as producing handles.", nb::arg("results"),
        nb::arg("effects"));
  m.def("modifies_payload", modifiesPayload,
        "Mark the transform as modifying the payload.", nb::arg("effects"));
  m.def("only_reads_payload", onlyReadsPayload,
        "Mark the transform as only reading the payload.", nb::arg("effects"));
}
|
|
} // namespace transform
|
|
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
|
|
} // namespace python
|
|
} // namespace mlir
|
|
|
|
// Entry point of the `_mlirDialectsTransform` Python extension module.
NB_MODULE(_mlirDialectsTransform, m) {
  namespace transform_bindings =
      mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::transform;

  m.doc() = "MLIR Transform dialect.";
  transform_bindings::populateDialectTransformSubmodule(m);
}
|