Files
llvm-project/mlir/lib/Bindings/Python/DialectTransform.cpp
Sergei Lebedev 2c0e63c9b6 [MLIR] [Python] Relaxed list to Sequence in most parameter types (#188543)
Using `Sequence` frees users from the need to cast to `list` in cases
where the underlying API does not really care about the type of the
container.

Note that accepting an `nb::sequence` is marginally slower than
accepting `nb::list` directly, because `__getitem__`, `__len__` etc need
to go through an extra layer of indirection. However, I expect the
performance difference to be negligible.
2026-03-25 19:15:38 +00:00

578 lines
22 KiB
C++

//===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <string>
#include "Rewrite.h"
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/IRInterfaces.h"
#include "nanobind/nanobind.h"
#include <nanobind/trampoline.h>
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
namespace transform {
//===----------------------------------------------------------------------===//
// TransformRewriter
//===----------------------------------------------------------------------===//
class PyTransformRewriter : public PyRewriterBase<PyTransformRewriter> {
public:
static constexpr const char *pyClassName = "TransformRewriter";
PyTransformRewriter(MlirTransformRewriter rewriter)
: PyRewriterBase(mlirTransformRewriterAsBase(rewriter)) {}
};
//===----------------------------------------------------------------------===//
// TransformResults
//===----------------------------------------------------------------------===//
class PyTransformResults {
public:
PyTransformResults(MlirTransformResults results) : results(results) {}
MlirTransformResults get() const { return results; }
void setOps(PyValue &result,
const nb::typed<nb::sequence, PyOperationBase> &ops) {
std::vector<MlirOperation> opsVec;
opsVec.reserve(nb::len(ops));
for (auto op : ops) {
opsVec.push_back(nb::cast<MlirOperation>(op));
}
mlirTransformResultsSetOps(results, result, opsVec.size(), opsVec.data());
}
void setValues(PyValue &result,
const nb::typed<nb::sequence, PyValue> &values) {
std::vector<MlirValue> valuesVec;
valuesVec.reserve(nb::len(values));
for (auto item : values) {
valuesVec.push_back(nb::cast<MlirValue>(item));
}
mlirTransformResultsSetValues(results, result, valuesVec.size(),
valuesVec.data());
}
void setParams(PyValue &result,
const nb::typed<nb::sequence, PyAttribute> &params) {
std::vector<MlirAttribute> paramsVec;
paramsVec.reserve(nb::len(params));
for (auto item : params) {
paramsVec.push_back(nb::cast<MlirAttribute>(item));
}
mlirTransformResultsSetParams(results, result, paramsVec.size(),
paramsVec.data());
}
static void bind(nanobind::module_ &m) {
nb::class_<PyTransformResults>(m, "TransformResults")
.def(nb::init<MlirTransformResults>())
.def("set_ops", &PyTransformResults::setOps,
"Set the payload operations for a transform result.",
nb::arg("result"), nb::arg("ops"))
.def("set_values", &PyTransformResults::setValues,
"Set the payload values for a transform result.",
nb::arg("result"), nb::arg("values"))
.def("set_params", &PyTransformResults::setParams,
"Set the parameters for a transform result.", nb::arg("result"),
nb::arg("params"));
}
private:
MlirTransformResults results;
};
//===----------------------------------------------------------------------===//
// TransformState
//===----------------------------------------------------------------------===//
class PyTransformState {
public:
PyTransformState(MlirTransformState state) : state(state) {}
MlirTransformState get() const { return state; }
static void bind(nanobind::module_ &m) {
nb::class_<PyTransformState>(m, "TransformState")
.def(nb::init<MlirTransformState>())
.def("get_payload_ops", &PyTransformState::getPayloadOps,
"Get the payload operations associated with a transform IR value.",
nb::arg("operand"))
.def("get_payload_values", &PyTransformState::getPayloadValues,
"Get the payload values associated with a transform IR value.",
nb::arg("operand"))
.def("get_params", &PyTransformState::getParams,
"Get the parameters (attributes) associated with a transform IR "
"value.",
nb::arg("operand"));
}
private:
nanobind::list getPayloadOps(PyValue &value) {
nanobind::list result;
mlirTransformStateForEachPayloadOp(
state, value,
[](MlirOperation op, void *userData) {
PyMlirContextRef context =
PyMlirContext::forContext(mlirOperationGetContext(op));
auto opview = PyOperation::forOperation(context, op)->createOpView();
static_cast<nanobind::list *>(userData)->append(opview);
},
&result);
return result;
}
nanobind::list getPayloadValues(PyValue &value) {
nanobind::list result;
mlirTransformStateForEachPayloadValue(
state, value,
[](MlirValue val, void *userData) {
static_cast<nanobind::list *>(userData)->append(val);
},
&result);
return result;
}
nanobind::list getParams(PyValue &value) {
nanobind::list result;
mlirTransformStateForEachParam(
state, value,
[](MlirAttribute attr, void *userData) {
static_cast<nanobind::list *>(userData)->append(attr);
},
&result);
return result;
}
MlirTransformState state;
};
//===----------------------------------------------------------------------===//
// TransformOpInterface
//===----------------------------------------------------------------------===//
class PyTransformOpInterface
: public PyConcreteOpInterface<PyTransformOpInterface> {
public:
using PyConcreteOpInterface<PyTransformOpInterface>::PyConcreteOpInterface;
constexpr static const char *pyClassName = "TransformOpInterface";
constexpr static GetTypeIDFunctionTy getInterfaceID =
&mlirTransformOpInterfaceTypeID;
/// Attach a new TransformOpInterface FallbackModel to the named operation.
/// The FallbackModel acts as a trampoline for callbacks on the Python class.
static void attach(nb::object &target, const std::string &opName,
DefaultingPyMlirContext ctx) {
// Prepare the callbacks that will be used by the FallbackModel.
MlirTransformOpInterfaceCallbacks callbacks;
// Make the pointer to the Python class available to the callbacks.
callbacks.userData = target.ptr();
nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
// The above ref bump is all we need as initialization, no need to run the
// construct callback.
callbacks.construct = nullptr;
// Upon the FallbackModel's destruction, drop the ref to the Python class.
callbacks.destruct = [](void *userData) {
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
};
// The apply callback which calls into Python.
callbacks.apply = [](MlirOperation op, MlirTransformRewriter rewriter,
MlirTransformResults results, MlirTransformState state,
void *userData) -> MlirDiagnosedSilenceableFailure {
nb::handle pyClass(static_cast<PyObject *>(userData));
auto pyApply = nb::cast<nb::callable>(nb::getattr(pyClass, "apply"));
auto pyRewriter = PyTransformRewriter(rewriter);
auto pyResults = PyTransformResults(results);
auto pyState = PyTransformState(state);
// Invoke `pyClass.apply(opview(op), rewriter, results, state)` as a
// staticmethod.
PyMlirContextRef context =
PyMlirContext::forContext(mlirOperationGetContext(op));
auto opview = PyOperation::forOperation(context, op)->createOpView();
nb::object res = pyApply(opview, pyRewriter, pyResults, pyState);
return nb::cast<MlirDiagnosedSilenceableFailure>(res);
};
// The allows_repeated_handle_operands callback which calls into Python.
callbacks.allowsRepeatedHandleOperands = [](MlirOperation op,
void *userData) -> bool {
nb::handle pyClass(static_cast<PyObject *>(userData));
auto pyAllowRepeatedHandleOperands = nb::cast<nb::callable>(
nb::getattr(pyClass, "allow_repeated_handle_operands"));
// Invoke `pyClass.allow_repeated_handle_operands(opview(op))` as a
// staticmethod.
PyMlirContextRef context =
PyMlirContext::forContext(mlirOperationGetContext(op));
auto opview = PyOperation::forOperation(context, op)->createOpView();
nb::object res = pyAllowRepeatedHandleOperands(opview);
return nb::cast<bool>(res);
};
// Attach a FallbackModel, which calls into Python, to the named operation.
mlirTransformOpInterfaceAttachFallbackModel(
ctx->get(), mlirStringRefCreate(opName.c_str(), opName.size()),
callbacks);
}
static void bindDerived(ClassTy &cls) {
cls.attr("attach") = classmethod(
[](const nb::object &cls, const nb::object &opName, nb::object target,
DefaultingPyMlirContext context) {
if (target.is_none())
target = cls;
return attach(target, nb::cast<std::string>(opName), context);
},
nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
nb::arg("target").none() = nb::none(),
nb::arg("context").none() = nb::none(),
"Attach the interface subclass to the given operation name.");
}
};
//===----------------------------------------------------------------------===//
// PatternDescriptorOpInterface
//===----------------------------------------------------------------------===//
class PyPatternDescriptorOpInterface
: public PyConcreteOpInterface<PyPatternDescriptorOpInterface> {
public:
using PyConcreteOpInterface<
PyPatternDescriptorOpInterface>::PyConcreteOpInterface;
constexpr static const char *pyClassName = "PatternDescriptorOpInterface";
constexpr static GetTypeIDFunctionTy getInterfaceID =
&mlirPatternDescriptorOpInterfaceTypeID;
/// Attach a new PatternDescriptorOpInterface FallbackModel to the named
/// operation. The FallbackModel acts as a trampoline for callbacks on the
/// Python class.
static void attach(nb::object &target, const std::string &opName,
DefaultingPyMlirContext ctx) {
// Prepare the callbacks that will be used by the FallbackModel.
MlirPatternDescriptorOpInterfaceCallbacks callbacks;
// Make the pointer to the Python class available to the callbacks.
callbacks.userData = target.ptr();
nb::handle(static_cast<PyObject *>(callbacks.userData)).inc_ref();
// The above ref bump is all we need as initialization, no need to run the
// construct callback.
callbacks.construct = nullptr;
// Upon the FallbackModel's destruction, drop the ref to the Python class.
callbacks.destruct = [](void *userData) {
nb::handle(static_cast<PyObject *>(userData)).dec_ref();
};
// The populatePatterns callback which calls into Python.
callbacks.populatePatterns =
[](MlirOperation op, MlirRewritePatternSet patterns, void *userData) {
nb::handle pyClass(static_cast<PyObject *>(userData));
auto pyPopulatePatterns =
nb::cast<nb::callable>(nb::getattr(pyClass, "populate_patterns"));
auto pyPatterns = PyRewritePatternSet(patterns);
// Invoke `pyClass.populate_patterns(opview(op), patterns)` as a
// staticmethod.
MlirContext ctx = mlirOperationGetContext(op);
PyMlirContextRef context = PyMlirContext::forContext(ctx);
auto opview = PyOperation::forOperation(context, op)->createOpView();
pyPopulatePatterns(opview, pyPatterns);
};
// The populatePatternsWithState callback which calls into Python.
// Check if the Python class has populate_patterns_with_state method.
if (nb::hasattr(target, "populate_patterns_with_state")) {
callbacks.populatePatternsWithState = [](MlirOperation op,
MlirRewritePatternSet patterns,
MlirTransformState state,
void *userData) {
nb::handle pyClass(static_cast<PyObject *>(userData));
auto pyPopulatePatternsWithState = nb::cast<nb::callable>(
nb::getattr(pyClass, "populate_patterns_with_state"));
auto pyPatterns = PyRewritePatternSet(patterns);
auto pyState = PyTransformState(state);
// Invoke `pyClass.populate_patterns_with_state(opview(op), patterns,
// state)` as a staticmethod.
MlirContext ctx = mlirOperationGetContext(op);
PyMlirContextRef context = PyMlirContext::forContext(ctx);
auto opview = PyOperation::forOperation(context, op)->createOpView();
pyPopulatePatternsWithState(opview, pyPatterns, pyState);
};
} else {
// Use default implementation (will call populatePatterns).
callbacks.populatePatternsWithState = nullptr;
}
// Attach a FallbackModel, which calls into Python, to the named operation.
mlirPatternDescriptorOpInterfaceAttachFallbackModel(
ctx->get(), mlirStringRefCreate(opName.c_str(), opName.size()),
callbacks);
}
static void bindDerived(ClassTy &cls) {
cls.attr("attach") = classmethod(
[](const nb::object &cls, const nb::object &opName, nb::object target,
DefaultingPyMlirContext context) {
if (target.is_none())
target = cls;
return attach(target, nb::cast<std::string>(opName), context);
},
nb::arg("cls"), nb::arg("op_name"), nb::kw_only(),
nb::arg("target").none() = nb::none(),
nb::arg("context").none() = nb::none(),
"Attach the interface subclass to the given operation name.");
}
};
//===-------------------------------------------------------------------===//
// AnyOpType
//===-------------------------------------------------------------------===//
struct AnyOpType : PyConcreteType<AnyOpType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyOpType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirTransformAnyOpTypeGetTypeID;
static constexpr const char *pyClassName = "AnyOpType";
static inline const MlirStringRef name = mlirTransformAnyOpTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](DefaultingPyMlirContext context) {
return AnyOpType(context->getRef(),
mlirTransformAnyOpTypeGet(context.get()->get()));
},
"Get an instance of AnyOpType in the given context.",
nb::arg("context").none() = nb::none());
}
};
//===-------------------------------------------------------------------===//
// AnyParamType
//===-------------------------------------------------------------------===//
struct AnyParamType : PyConcreteType<AnyParamType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyParamType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirTransformAnyParamTypeGetTypeID;
static constexpr const char *pyClassName = "AnyParamType";
static inline const MlirStringRef name = mlirTransformAnyParamTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](DefaultingPyMlirContext context) {
return AnyParamType(context->getRef(), mlirTransformAnyParamTypeGet(
context.get()->get()));
},
"Get an instance of AnyParamType in the given context.",
nb::arg("context").none() = nb::none());
}
};
//===-------------------------------------------------------------------===//
// AnyValueType
//===-------------------------------------------------------------------===//
struct AnyValueType : PyConcreteType<AnyValueType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformAnyValueType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirTransformAnyValueTypeGetTypeID;
static constexpr const char *pyClassName = "AnyValueType";
static inline const MlirStringRef name = mlirTransformAnyValueTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](DefaultingPyMlirContext context) {
return AnyValueType(context->getRef(), mlirTransformAnyValueTypeGet(
context.get()->get()));
},
"Get an instance of AnyValueType in the given context.",
nb::arg("context").none() = nb::none());
}
};
//===-------------------------------------------------------------------===//
// OperationType
//===-------------------------------------------------------------------===//
struct OperationType : PyConcreteType<OperationType> {
static constexpr IsAFunctionTy isaFunction =
mlirTypeIsATransformOperationType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirTransformOperationTypeGetTypeID;
static constexpr const char *pyClassName = "OperationType";
static inline const MlirStringRef name = mlirTransformOperationTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](const std::string &operationName, DefaultingPyMlirContext context) {
MlirStringRef cOperationName =
mlirStringRefCreate(operationName.data(), operationName.size());
return OperationType(context->getRef(),
mlirTransformOperationTypeGet(
context.get()->get(), cOperationName));
},
"Get an instance of OperationType for the given kind in the given "
"context",
nb::arg("operation_name"), nb::arg("context").none() = nb::none());
c.def_prop_ro(
"operation_name",
[](const OperationType &type) {
MlirStringRef operationName =
mlirTransformOperationTypeGetOperationName(type);
return nb::str(operationName.data, operationName.length);
},
"Get the name of the payload operation accepted by the handle.");
}
};
//===-------------------------------------------------------------------===//
// ParamType
//===-------------------------------------------------------------------===//
struct ParamType : PyConcreteType<ParamType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsATransformParamType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirTransformParamTypeGetTypeID;
static constexpr const char *pyClassName = "ParamType";
static inline const MlirStringRef name = mlirTransformParamTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
[](const PyType &type, DefaultingPyMlirContext context) {
return ParamType(context->getRef(), mlirTransformParamTypeGet(
context.get()->get(), type));
},
"Get an instance of ParamType for the given type in the given context.",
nb::arg("type"), nb::arg("context").none() = nb::none());
c.def_prop_ro(
"type",
[](ParamType type) {
return PyType(type.getContext(), mlirTransformParamTypeGetType(type))
.maybeDownCast();
},
"Get the type this ParamType is associated with.");
}
};
//===----------------------------------------------------------------------===//
// MemoryEffectsOpInterface helpers
//===----------------------------------------------------------------------===//
namespace {
void onlyReadsHandle(nb::iterable &operands,
PyMemoryEffectsInstanceList effects) {
std::vector<MlirOpOperand> operandsVec;
for (auto operand : operands)
operandsVec.push_back(nb::cast<PyOpOperand>(operand));
mlirTransformOnlyReadsHandle(operandsVec.data(), operandsVec.size(),
effects.effects);
};
void consumesHandle(nb::iterable &operands,
PyMemoryEffectsInstanceList effects) {
std::vector<MlirOpOperand> operandsVec;
for (auto operand : operands)
operandsVec.push_back(nb::cast<PyOpOperand>(operand));
mlirTransformConsumesHandle(operandsVec.data(), operandsVec.size(),
effects.effects);
};
void producesHandle(nb::iterable &results,
PyMemoryEffectsInstanceList effects) {
std::vector<MlirValue> resultsVec;
for (auto result : results)
resultsVec.push_back(nb::cast<PyOpResult>(result).get());
mlirTransformProducesHandle(resultsVec.data(), resultsVec.size(),
effects.effects);
};
void modifiesPayload(PyMemoryEffectsInstanceList effects) {
mlirTransformModifiesPayload(effects.effects);
}
void onlyReadsPayload(PyMemoryEffectsInstanceList effects) {
mlirTransformOnlyReadsPayload(effects.effects);
}
} // namespace
static void populateDialectTransformSubmodule(nb::module_ &m) {
nb::enum_<MlirDiagnosedSilenceableFailure>(m, "DiagnosedSilenceableFailure")
.value("Success", MlirDiagnosedSilenceableFailureSuccess)
.value("SilenceableFailure",
MlirDiagnosedSilenceableFailureSilenceableFailure)
.value("DefiniteFailure", MlirDiagnosedSilenceableFailureDefiniteFailure);
AnyOpType::bind(m);
AnyParamType::bind(m);
AnyValueType::bind(m);
OperationType::bind(m);
ParamType::bind(m);
PyTransformRewriter::bind(m);
PyTransformResults::bind(m);
PyTransformState::bind(m);
PyTransformOpInterface::bind(m);
PyPatternDescriptorOpInterface::bind(m);
m.def("only_reads_handle", onlyReadsHandle,
"Mark operands as only reading handles.", nb::arg("operands"),
nb::arg("effects"));
m.def("consumes_handle", consumesHandle,
"Mark operands as consuming handles.", nb::arg("operands"),
nb::arg("effects"));
m.def("produces_handle", producesHandle, "Mark results as producing handles.",
nb::arg("results"), nb::arg("effects"));
m.def("modifies_payload", modifiesPayload,
"Mark the transform as modifying the payload.", nb::arg("effects"));
m.def("only_reads_payload", onlyReadsPayload,
"Mark the transform as only reading the payload.", nb::arg("effects"));
}
} // namespace transform
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir
NB_MODULE(_mlirDialectsTransform, m) {
m.doc() = "MLIR Transform dialect.";
mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::transform::
populateDialectTransformSubmodule(m);
}