Files
llvm-project/mlir/lib/CAPI/Dialect/Transform.cpp
Rolf Morel a1d7cda1d7 [MLIR][Python] Impl XOpInterface(s) from Python, with X=Transform and X=MemoryEffects (#176920)
Provides the infrastructure for implementing and late-binding
OpInterfaces from Python.

* On the mlir-c API declaration side, each `XOpInterface` has a callback
struct, with a callback for each method and a userdata member (provided
as an arg to each method), and a
`mlirXOpInterfaceAttachFallbackModel(ctx, op_name, callbacks)` func.
* This CAPI is implemented by defining a subclass of
`XOpInterface::FallbackModel` that holds the callback struct and has
each method call the corresponding callback (with userdata as an arg).
Given a callback struct, a new `FallbackModel` is created and attached,
i.e. late bound, to the named op. (MLIR's interface infrastructure is
such that the thus registered `FallbackModel` will be returned in case
the op gets cast to the `XOpInterface`.)
* On the Python side, we expose a stand-in `XOpInterface` base class
which has one (class)method: `XOpInterface.attach(cls, op_name, ctx)`.
Python users subclass this class (`class MyInterfaceImpl(XOpInterface):
...`) and implement the interface's methods (with the right names and
signatures). The user calls `attach` on the subclass
(`MyInterfaceImpl.attach("my_dialect.my_op", ctx)`) which prepares the
callbacks struct _with userdata set to the subclass_ (as we use it to
lookup methods). These callbacks (and userdata) are then registered as
an `XOpInterface::FallbackModel` by
`mlirXOpInterfaceAttachFallbackModel(...)`. From then on the Python
methods will be used to respond to calls to the interface methods
(originating in C++).

This PR enables implementing the TransformOpInterface and the
MemoryEffectsOpInterface, both of which are required for making an op
into a transform op.

Everything besides the above linked code is there to facilitate exposing
the interfaces: the right types for the arguments of the methods are
exposed as are functions/methods for manipulating these arguments (e.g.
specifying side effects on `OpOperand`s and `OpResult`s and being able
to access and set the transform handles associated with args and
results).
2026-02-12 14:07:10 +00:00

341 lines
13 KiB
C++

//===- Transform.cpp - C Interface for Transform dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/Dialect/Transform.h"
#include "mlir/CAPI/Interfaces.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Rewrite.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
// Instantiates the C API dialect-registration boilerplate for the Transform
// dialect: C API name `Transform`, namespace `transform`, and C++ dialect class
// `transform::TransformDialect`.
// NOTE(review): the macro body is not visible in this file; presumably it comes
// from mlir/CAPI/Registration.h and defines the dialect handle accessor --
// confirm against that header.
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Transform, transform,
transform::TransformDialect)
//===---------------------------------------------------------------------===//
// AnyOpType
//===---------------------------------------------------------------------===//
/// Returns true if `type` is a `transform::AnyOpType`.
bool mlirTypeIsATransformAnyOpType(MlirType type) {
  Type cppType = unwrap(type);
  return isa<transform::AnyOpType>(cppType);
}
/// Returns the TypeID of `transform::AnyOpType`, wrapped for the C API.
MlirTypeID mlirTransformAnyOpTypeGetTypeID(void) {
  TypeID typeID = transform::AnyOpType::getTypeID();
  return wrap(typeID);
}
/// Gets the `transform::AnyOpType` instance in the context `ctx`.
MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(transform::AnyOpType::get(context));
}
/// Returns the `name` constant of `transform::AnyOpType` as a C string ref.
MlirStringRef mlirTransformAnyOpTypeGetName(void) {
  StringRef typeName = transform::AnyOpType::name;
  return wrap(typeName);
}
//===---------------------------------------------------------------------===//
// AnyParamType
//===---------------------------------------------------------------------===//
/// Returns true if `type` is a `transform::AnyParamType`.
bool mlirTypeIsATransformAnyParamType(MlirType type) {
  Type cppType = unwrap(type);
  return isa<transform::AnyParamType>(cppType);
}
/// Returns the TypeID of `transform::AnyParamType`, wrapped for the C API.
MlirTypeID mlirTransformAnyParamTypeGetTypeID(void) {
  TypeID typeID = transform::AnyParamType::getTypeID();
  return wrap(typeID);
}
/// Gets the `transform::AnyParamType` instance in the context `ctx`.
MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(transform::AnyParamType::get(context));
}
/// Returns the `name` constant of `transform::AnyParamType` as a C string ref.
MlirStringRef mlirTransformAnyParamTypeGetName(void) {
  StringRef typeName = transform::AnyParamType::name;
  return wrap(typeName);
}
//===---------------------------------------------------------------------===//
// AnyValueType
//===---------------------------------------------------------------------===//
/// Returns true if `type` is a `transform::AnyValueType`.
bool mlirTypeIsATransformAnyValueType(MlirType type) {
  Type cppType = unwrap(type);
  return isa<transform::AnyValueType>(cppType);
}
/// Returns the TypeID of `transform::AnyValueType`, wrapped for the C API.
MlirTypeID mlirTransformAnyValueTypeGetTypeID(void) {
  TypeID typeID = transform::AnyValueType::getTypeID();
  return wrap(typeID);
}
/// Gets the `transform::AnyValueType` instance in the context `ctx`.
MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(transform::AnyValueType::get(context));
}
/// Returns the `name` constant of `transform::AnyValueType` as a C string ref.
MlirStringRef mlirTransformAnyValueTypeGetName(void) {
  StringRef typeName = transform::AnyValueType::name;
  return wrap(typeName);
}
//===---------------------------------------------------------------------===//
// OperationType
//===---------------------------------------------------------------------===//
/// Returns true if `type` is a `transform::OperationType`.
bool mlirTypeIsATransformOperationType(MlirType type) {
  Type cppType = unwrap(type);
  return isa<transform::OperationType>(cppType);
}
/// Returns the TypeID of `transform::OperationType`, wrapped for the C API.
MlirTypeID mlirTransformOperationTypeGetTypeID(void) {
  TypeID typeID = transform::OperationType::getTypeID();
  return wrap(typeID);
}
/// Gets the `transform::OperationType` in `ctx` parameterized by the operation
/// name `operationName`.
MlirType mlirTransformOperationTypeGet(MlirContext ctx,
                                       MlirStringRef operationName) {
  MLIRContext *context = unwrap(ctx);
  StringRef opName = unwrap(operationName);
  return wrap(transform::OperationType::get(context, opName));
}
/// Returns the `name` constant of `transform::OperationType` as a C string ref.
MlirStringRef mlirTransformOperationTypeGetName(void) {
  StringRef typeName = transform::OperationType::name;
  return wrap(typeName);
}
/// Returns the operation name carried by `type`. `type` must be a
/// `transform::OperationType`; this is enforced via `cast<>`.
MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
  auto operationType = cast<transform::OperationType>(unwrap(type));
  return wrap(operationType.getOperationName());
}
//===---------------------------------------------------------------------===//
// ParamType
//===---------------------------------------------------------------------===//
/// Returns true if `type` is a `transform::ParamType`.
bool mlirTypeIsATransformParamType(MlirType type) {
  Type cppType = unwrap(type);
  return isa<transform::ParamType>(cppType);
}
/// Returns the TypeID of `transform::ParamType`, wrapped for the C API.
MlirTypeID mlirTransformParamTypeGetTypeID(void) {
  TypeID typeID = transform::ParamType::getTypeID();
  return wrap(typeID);
}
/// Gets the `transform::ParamType` in `ctx` parameterized by `type`.
MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
  MLIRContext *context = unwrap(ctx);
  Type innerType = unwrap(type);
  return wrap(transform::ParamType::get(context, innerType));
}
/// Returns the `name` constant of `transform::ParamType` as a C string ref.
MlirStringRef mlirTransformParamTypeGetName(void) {
  StringRef typeName = transform::ParamType::name;
  return wrap(typeName);
}
/// Returns the type that `type` is parameterized by. `type` must be a
/// `transform::ParamType`; this is enforced via `cast<>`.
MlirType mlirTransformParamTypeGetType(MlirType type) {
  auto paramType = cast<transform::ParamType>(unwrap(type));
  return wrap(paramType.getType());
}
//===---------------------------------------------------------------------===//
// TransformRewriter
//===---------------------------------------------------------------------===//
/// Upcasts a `MlirTransformRewriter` to its `MlirRewriterBase` view. The
/// returned handle refers to the same underlying rewriter object.
MlirRewriterBase mlirTransformRewriterAsBase(MlirTransformRewriter rewriter) {
  return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter)));
}
//===---------------------------------------------------------------------===//
// TransformResults
//===---------------------------------------------------------------------===//
/// Unwraps the `numOps` operations in `ops` and passes them to `set` on the
/// unwrapped `results` object for the op result `result`. `result` must be an
/// `OpResult`; this is enforced via `cast<>`.
void mlirTransformResultsSetOps(MlirTransformResults results, MlirValue result,
                                intptr_t numOps, MlirOperation *ops) {
  SmallVector<Operation *> payloadOps;
  payloadOps.reserve(numOps);
  for (intptr_t idx = 0; idx < numOps; ++idx) {
    Operation *payload = unwrap(ops[idx]);
    payloadOps.push_back(payload);
  }
  OpResult handle = cast<OpResult>(unwrap(result));
  unwrap(results)->set(handle, payloadOps);
}
/// Unwraps the `numValues` values in `values` and passes them to `setValues`
/// on the unwrapped `results` object for the op result `result`. `result` must
/// be an `OpResult`; this is enforced via `cast<>`.
void mlirTransformResultsSetValues(MlirTransformResults results,
                                   MlirValue result, intptr_t numValues,
                                   MlirValue *values) {
  SmallVector<Value> payloadValues;
  payloadValues.reserve(numValues);
  for (intptr_t idx = 0; idx < numValues; ++idx) {
    Value payload = unwrap(values[idx]);
    payloadValues.push_back(payload);
  }
  OpResult handle = cast<OpResult>(unwrap(result));
  unwrap(results)->setValues(handle, payloadValues);
}
/// Unwraps the `numParams` attributes in `params` and passes them to
/// `setParams` on the unwrapped `results` object for the op result `result`.
/// `result` must be an `OpResult`; this is enforced via `cast<>`.
void mlirTransformResultsSetParams(MlirTransformResults results,
                                   MlirValue result, intptr_t numParams,
                                   MlirAttribute *params) {
  SmallVector<Attribute> paramAttrs;
  paramAttrs.reserve(numParams);
  for (intptr_t idx = 0; idx < numParams; ++idx) {
    Attribute param = unwrap(params[idx]);
    paramAttrs.push_back(param);
  }
  OpResult handle = cast<OpResult>(unwrap(result));
  unwrap(results)->setParams(handle, paramAttrs);
}
//===---------------------------------------------------------------------===//
// TransformState
//===---------------------------------------------------------------------===//
/// Invokes `callback(op, userData)` once for every payload operation that
/// `getPayloadOps` yields for `value` in `state`, in iteration order.
void mlirTransformStateForEachPayloadOp(MlirTransformState state,
                                        MlirValue value,
                                        MlirOperationCallback callback,
                                        void *userData) {
  auto &&payloadOps = unwrap(state)->getPayloadOps(unwrap(value));
  for (Operation *payload : payloadOps) {
    MlirOperation wrapped = wrap(payload);
    callback(wrapped, userData);
  }
}
/// Invokes `callback(value, userData)` once for every payload value that
/// `getPayloadValues` yields for `value` in `state`, in iteration order.
void mlirTransformStateForEachPayloadValue(MlirTransformState state,
                                           MlirValue value,
                                           MlirValueCallback callback,
                                           void *userData) {
  auto &&payloadValues = unwrap(state)->getPayloadValues(unwrap(value));
  for (Value payload : payloadValues) {
    MlirValue wrapped = wrap(payload);
    callback(wrapped, userData);
  }
}
/// Invokes `callback(attr, userData)` once for every parameter attribute that
/// `getParams` yields for `value` in `state`, in iteration order.
void mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
                                    MlirAttributeCallback callback,
                                    void *userData) {
  auto &&paramAttrs = unwrap(state)->getParams(unwrap(value));
  for (Attribute param : paramAttrs) {
    MlirAttribute wrapped = wrap(param);
    callback(wrapped, userData);
  }
}
//===---------------------------------------------------------------------===//
// TransformOpInterface
//===---------------------------------------------------------------------===//
/// Returns the interface ID of `transform::TransformOpInterface`, wrapped for
/// the C API.
MlirTypeID mlirTransformOpInterfaceTypeID(void) {
  TypeID interfaceID = transform::TransformOpInterface::getInterfaceID();
  return wrap(interfaceID);
}
/// Fallback model for the TransformOpInterface that uses C API callbacks.
///
/// Every interface method forwards to the corresponding function pointer in
/// the stored `MlirTransformOpInterfaceCallbacks`, passing `userData` as the
/// last argument. The callbacks start out zero-initialized and are installed
/// after construction through `setCallbacks`.
class TransformOpInterfaceFallbackModel
    : public mlir::transform::TransformOpInterface::FallbackModel<
          TransformOpInterfaceFallbackModel> {
public:
  /// Sets the callbacks that this FallbackModel will use.
  /// NB: the callbacks can only be set through this method as the
  /// RegisteredOperationName::attachInterface mechanism default-constructs
  /// the FallbackModel without being able to provide arguments.
  void setCallbacks(MlirTransformOpInterfaceCallbacks callbacks) {
    this->callbacks = callbacks;
  }

  /// Gives the callback owner a chance to release `userData`. The `destruct`
  /// callback is optional; nothing happens if it is null (which is also the
  /// state of a model on which `setCallbacks` was never called).
  ~TransformOpInterfaceFallbackModel() {
    if (callbacks.destruct)
      callbacks.destruct(callbacks.userData);
  }

  /// Returns the ID of the interface this model implements.
  static TypeID getInterfaceID() {
    return transform::TransformOpInterface::getInterfaceID();
  }

  static bool classof(const mlir::transform::detail::
                          TransformOpInterfaceInterfaceTraits::Concept *op) {
    // Enable casting back to the FallbackModel from the Interface. This is
    // necessary as attachInterface(...) default-constructs the FallbackModel
    // without being able to pass in the callbacks and returns just the Concept.
    return true;
  }

  /// Runs the `apply` callback and translates its C status into a
  /// `DiagnosedSilenceableFailure`. The `apply` callback must be set.
  ::mlir::DiagnosedSilenceableFailure
  apply(Operation *op, ::mlir::transform::TransformRewriter &rewriter,
        ::mlir::transform::TransformResults &transformResults,
        ::mlir::transform::TransformState &state) const {
    assert(callbacks.apply && "apply callback not set");
    MlirDiagnosedSilenceableFailure status =
        callbacks.apply(wrap(op), wrap(&rewriter), wrap(&transformResults),
                        wrap(&state), callbacks.userData);
    switch (status) {
    case MlirDiagnosedSilenceableFailureSuccess:
      return DiagnosedSilenceableFailure::success();
    case MlirDiagnosedSilenceableFailureSilenceableFailure: {
      // TODO: enable passing diagnostic info from C API to C++ API.
      // Build a standalone Diagnostic rather than going through
      // `op->emitError()`: the InFlightDiagnostic returned by the latter
      // reports itself when destroyed, so moving its payload out would still
      // emit a (moved-from) error immediately, and the failure could no longer
      // be silenced by the caller.
      ::mlir::Diagnostic diag(op->getLoc(), ::mlir::DiagnosticSeverity::Error);
      diag << "TransformOpInterfaceFallbackModel: silenceable failure";
      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
    }
    case MlirDiagnosedSilenceableFailureDefiniteFailure:
      return DiagnosedSilenceableFailure::definiteFailure();
    }
    llvm_unreachable("unknown transform status");
  }

  /// Forwards to the `allowsRepeatedHandleOperands` callback, which must be
  /// set.
  bool allowsRepeatedHandleOperands(Operation *op) const {
    assert(callbacks.allowsRepeatedHandleOperands &&
           "allowsRepeatedHandleOperands callback not set");
    return callbacks.allowsRepeatedHandleOperands(wrap(op), callbacks.userData);
  }

private:
  /// Value-initialized so that the destructor and the asserts above observe
  /// null function pointers, never indeterminate ones, before `setCallbacks`
  /// has been called.
  MlirTransformOpInterfaceCallbacks callbacks = {};
};
/// Late-binds a callback-backed TransformOpInterface implementation to the
/// operation named `opName`, which must already be registered in `ctx`.
/// After this call the attached FallbackModel answers interface calls by
/// invoking the function pointers in `callbacks`.
/// NOTE(review): `TransformOpInterfaceFallbackModel::classof` accepts every
/// Concept, so the cast below is unchecked; confirm what `getInterface`
/// returns if the op already had a TransformOpInterface model attached.
void mlirTransformOpInterfaceAttachFallbackModel(
    MlirContext ctx, MlirStringRef opName,
    MlirTransformOpInterfaceCallbacks callbacks) {
  using ModelT = TransformOpInterfaceFallbackModel;

  // Resolve the registered op definition by name.
  std::optional<RegisteredOperationName> registeredOp =
      RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
  assert(registeredOp.has_value() && "operation not found in context");

  // attachInterface default-constructs the model; there is no way to hand it
  // constructor arguments, so the callbacks are installed in a second step.
  registeredOp->attachInterface<ModelT>();

  // Fetch the freshly attached Concept and view it as our model again.
  auto *attached = registeredOp->getInterface<ModelT>();
  auto *fallback = cast<ModelT>(attached);
  assert(fallback && "Failed to get TransformOpInterfaceFallbackModel");
  fallback->setCallbacks(callbacks);
}
//===---------------------------------------------------------------------===//
// MemoryEffectsOpInterface helpers
//===---------------------------------------------------------------------===//
/// Set the effect for the operands to only read the transform handles.
///
/// `operands` points to `numOperands` operand handles; the corresponding
/// effects are appended to `effects` in the order the handles are given.
/// `operands` is not dereferenced when `numOperands` is zero.
void mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
                                  MlirMemoryEffectInstancesList effects) {
  SmallVectorImpl<MemoryEffects::EffectInstance> &effectList = *unwrap(effects);
  // Each handle is unwrapped individually: the entries of `operands` are
  // independent C handles, so the `OpOperand`s they refer to need not be laid
  // out contiguously in memory (nor even belong to the same operation).
  for (intptr_t i = 0; i < numOperands; ++i) {
    MutableArrayRef<OpOperand> singleOperand(unwrap(operands[i]), 1);
    transform::onlyReadsHandle(singleOperand, effectList);
  }
}
/// Set the effect for the operands to consuming the transform handles.
///
/// `operands` points to `numOperands` operand handles; the corresponding
/// effects are appended to `effects` in the order the handles are given.
/// `operands` is not dereferenced when `numOperands` is zero.
void mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands,
                                 MlirMemoryEffectInstancesList effects) {
  SmallVectorImpl<MemoryEffects::EffectInstance> &effectList = *unwrap(effects);
  // Each handle is unwrapped individually: the entries of `operands` are
  // independent C handles, so the `OpOperand`s they refer to need not be laid
  // out contiguously in memory (nor even belong to the same operation).
  for (intptr_t i = 0; i < numOperands; ++i) {
    MutableArrayRef<OpOperand> singleOperand(unwrap(operands[i]), 1);
    transform::consumesHandle(singleOperand, effectList);
  }
}
/// Records on `effects` that each of the `numResults` values in `results`
/// produces a transform handle. Every entry must be an `OpResult`; this is
/// enforced via `cast<>`.
void mlirTransformProducesHandle(MlirValue *results, intptr_t numResults,
                                 MlirMemoryEffectInstancesList effects) {
  // NB: `producesHandle()` is invoked once per result because an array of
  // `OpResult`s cannot be turned into a single `ResultRange` (and
  // `ResultRange` is not exposed to Python either). `producesHandle` walks the
  // `ResultRange` it is given anyway.
  SmallVectorImpl<MemoryEffects::EffectInstance> &sink = *unwrap(effects);
  for (intptr_t idx = 0; idx < numResults; ++idx) {
    OpResult produced = cast<OpResult>(unwrap(results[idx]));
    ResultRange singleResult(produced);
    transform::producesHandle(singleResult, sink);
  }
}
/// Appends to `effects` the effect of potentially modifying payload IR, by
/// forwarding to `transform::modifiesPayload`.
void mlirTransformModifiesPayload(MlirMemoryEffectInstancesList effects) {
  SmallVectorImpl<MemoryEffects::EffectInstance> &sink = *unwrap(effects);
  transform::modifiesPayload(sink);
}
/// Appends to `effects` the effect of potentially reading payload IR, by
/// forwarding to `transform::onlyReadsPayload`.
void mlirTransformOnlyReadsPayload(MlirMemoryEffectInstancesList effects) {
  SmallVectorImpl<MemoryEffects::EffectInstance> &sink = *unwrap(effects);
  transform::onlyReadsPayload(sink);
}