Makes it possible to include Python-defined rewrite patterns in transform-dialect schedules, inside `transform.apply_patterns`, which runs the patterns in a greedy rewriter when the schedule executes. Written with the assistance of Claude.
424 lines
16 KiB
C++
424 lines
16 KiB
C++
//===- Transform.cpp - C Interface for Transform dialect ------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/Dialect/Transform.h"
#include "mlir/CAPI/Interfaces.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Rewrite.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
|
|
|
|
using namespace mlir;
|
|
|
|
// Expands the C API dialect-registration macro for the Transform dialect
// (C API name `Transform`, namespace `transform`, C++ class
// `transform::TransformDialect`).
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Transform, transform,
                                      mlir::transform::TransformDialect)
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// AnyOpType
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
/// Returns true if `type` is a `transform::AnyOpType`.
bool mlirTypeIsATransformAnyOpType(MlirType type) {
  Type cppType = unwrap(type);
  return isa<transform::AnyOpType>(cppType);
}
|
|
|
|
/// Returns the type ID of `transform::AnyOpType`.
MlirTypeID mlirTransformAnyOpTypeGetTypeID(void) {
  TypeID typeID = transform::AnyOpType::getTypeID();
  return wrap(typeID);
}
|
|
|
|
/// Returns the `transform::AnyOpType` obtained from the context `ctx`.
MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(transform::AnyOpType::get(context));
}
|
|
|
|
/// Returns the `name` constant of `transform::AnyOpType` as a C string ref.
MlirStringRef mlirTransformAnyOpTypeGetName(void) {
  StringRef typeName = transform::AnyOpType::name;
  return wrap(typeName);
}
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// AnyParamType
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
/// Returns true if `type` is a `transform::AnyParamType`.
bool mlirTypeIsATransformAnyParamType(MlirType type) {
  Type cppType = unwrap(type);
  return isa<transform::AnyParamType>(cppType);
}
|
|
|
|
/// Returns the type ID of `transform::AnyParamType`.
MlirTypeID mlirTransformAnyParamTypeGetTypeID(void) {
  TypeID typeID = transform::AnyParamType::getTypeID();
  return wrap(typeID);
}
|
|
|
|
/// Returns the `transform::AnyParamType` obtained from the context `ctx`.
MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(transform::AnyParamType::get(context));
}
|
|
|
|
/// Returns the `name` constant of `transform::AnyParamType` as a C string ref.
MlirStringRef mlirTransformAnyParamTypeGetName(void) {
  StringRef typeName = transform::AnyParamType::name;
  return wrap(typeName);
}
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// AnyValueType
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
/// Returns true if `type` is a `transform::AnyValueType`.
bool mlirTypeIsATransformAnyValueType(MlirType type) {
  Type cppType = unwrap(type);
  return isa<transform::AnyValueType>(cppType);
}
|
|
|
|
/// Returns the type ID of `transform::AnyValueType`.
MlirTypeID mlirTransformAnyValueTypeGetTypeID(void) {
  TypeID typeID = transform::AnyValueType::getTypeID();
  return wrap(typeID);
}
|
|
|
|
/// Returns the `transform::AnyValueType` obtained from the context `ctx`.
MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(transform::AnyValueType::get(context));
}
|
|
|
|
/// Returns the `name` constant of `transform::AnyValueType` as a C string ref.
MlirStringRef mlirTransformAnyValueTypeGetName(void) {
  StringRef typeName = transform::AnyValueType::name;
  return wrap(typeName);
}
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// OperationType
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
/// Returns true if `type` is a `transform::OperationType`.
bool mlirTypeIsATransformOperationType(MlirType type) {
  Type cppType = unwrap(type);
  return isa<transform::OperationType>(cppType);
}
|
|
|
|
/// Returns the type ID of `transform::OperationType`.
MlirTypeID mlirTransformOperationTypeGetTypeID(void) {
  TypeID typeID = transform::OperationType::getTypeID();
  return wrap(typeID);
}
|
|
|
|
/// Returns the `transform::OperationType` obtained from the context `ctx` for
/// the operation name `operationName`.
MlirType mlirTransformOperationTypeGet(MlirContext ctx,
                                       MlirStringRef operationName) {
  MLIRContext *context = unwrap(ctx);
  StringRef payloadOpName = unwrap(operationName);
  return wrap(transform::OperationType::get(context, payloadOpName));
}
|
|
|
|
/// Returns the `name` constant of `transform::OperationType` as a C string
/// ref.
MlirStringRef mlirTransformOperationTypeGetName(void) {
  StringRef typeName = transform::OperationType::name;
  return wrap(typeName);
}
|
|
|
|
/// Returns the operation name stored in `type`, which is cast to
/// `transform::OperationType` (the cast is unchecked in release builds).
MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
  auto operationType = cast<transform::OperationType>(unwrap(type));
  return wrap(operationType.getOperationName());
}
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// ParamType
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
/// Returns true if `type` is a `transform::ParamType`.
bool mlirTypeIsATransformParamType(MlirType type) {
  Type cppType = unwrap(type);
  return isa<transform::ParamType>(cppType);
}
|
|
|
|
/// Returns the type ID of `transform::ParamType`.
MlirTypeID mlirTransformParamTypeGetTypeID(void) {
  TypeID typeID = transform::ParamType::getTypeID();
  return wrap(typeID);
}
|
|
|
|
/// Returns the `transform::ParamType` obtained from the context `ctx` for the
/// given `type`.
MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
  MLIRContext *context = unwrap(ctx);
  Type wrappedType = unwrap(type);
  return wrap(transform::ParamType::get(context, wrappedType));
}
|
|
|
|
/// Returns the `name` constant of `transform::ParamType` as a C string ref.
MlirStringRef mlirTransformParamTypeGetName(void) {
  StringRef typeName = transform::ParamType::name;
  return wrap(typeName);
}
|
|
|
|
/// Returns the type stored in `type`, which is cast to `transform::ParamType`
/// (the cast is unchecked in release builds).
MlirType mlirTransformParamTypeGetType(MlirType type) {
  auto paramType = cast<transform::ParamType>(unwrap(type));
  return wrap(paramType.getType());
}
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// TransformRewriter
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
/// Casts a `MlirTransformRewriter` to a `MlirRewriterBase`.
|
|
MlirRewriterBase mlirTransformRewriterAsBase(MlirTransformRewriter rewriter) {
|
|
mlir::transform::TransformRewriter *t = unwrap(rewriter);
|
|
mlir::RewriterBase *base = static_cast<mlir::RewriterBase *>(t);
|
|
return wrap(base);
|
|
}
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// TransformResults
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
/// Unwraps the `numOps` operations in `ops` and passes them to
/// `TransformResults::set` for `result`, which must be an `OpResult`.
void mlirTransformResultsSetOps(MlirTransformResults results, MlirValue result,
                                intptr_t numOps, MlirOperation *ops) {
  ArrayRef<MlirOperation> wrappedOps(ops, numOps);
  SmallVector<Operation *> payloadOps;
  payloadOps.reserve(wrappedOps.size());
  for (MlirOperation wrappedOp : wrappedOps)
    payloadOps.push_back(unwrap(wrappedOp));

  OpResult handle = cast<OpResult>(unwrap(result));
  unwrap(results)->set(handle, payloadOps);
}
|
|
|
|
/// Unwraps the `numValues` values in `values` and passes them to
/// `TransformResults::setValues` for `result`, which must be an `OpResult`.
void mlirTransformResultsSetValues(MlirTransformResults results,
                                   MlirValue result, intptr_t numValues,
                                   MlirValue *values) {
  ArrayRef<MlirValue> wrappedValues(values, numValues);
  SmallVector<Value> payloadValues;
  payloadValues.reserve(wrappedValues.size());
  for (MlirValue wrappedValue : wrappedValues)
    payloadValues.push_back(unwrap(wrappedValue));

  OpResult handle = cast<OpResult>(unwrap(result));
  unwrap(results)->setValues(handle, payloadValues);
}
|
|
|
|
/// Unwraps the `numParams` attributes in `params` and passes them to
/// `TransformResults::setParams` for `result`, which must be an `OpResult`.
void mlirTransformResultsSetParams(MlirTransformResults results,
                                   MlirValue result, intptr_t numParams,
                                   MlirAttribute *params) {
  ArrayRef<MlirAttribute> wrappedParams(params, numParams);
  SmallVector<Attribute> paramAttrs;
  paramAttrs.reserve(wrappedParams.size());
  for (MlirAttribute wrappedParam : wrappedParams)
    paramAttrs.push_back(unwrap(wrappedParam));

  OpResult handle = cast<OpResult>(unwrap(result));
  unwrap(results)->setParams(handle, paramAttrs);
}
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// TransformState
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
/// Invokes `callback(op, userData)` for each operation returned by
/// `TransformState::getPayloadOps` for `value`, in iteration order.
void mlirTransformStateForEachPayloadOp(MlirTransformState state,
                                        MlirValue value,
                                        MlirOperationCallback callback,
                                        void *userData) {
  transform::TransformState *cppState = unwrap(state);
  Value handle = unwrap(value);
  for (Operation *payloadOp : cppState->getPayloadOps(handle))
    callback(wrap(payloadOp), userData);
}
|
|
|
|
/// Invokes `callback(value, userData)` for each value returned by
/// `TransformState::getPayloadValues` for `value`, in iteration order.
void mlirTransformStateForEachPayloadValue(MlirTransformState state,
                                           MlirValue value,
                                           MlirValueCallback callback,
                                           void *userData) {
  transform::TransformState *cppState = unwrap(state);
  Value handle = unwrap(value);
  for (Value payloadValue : cppState->getPayloadValues(handle))
    callback(wrap(payloadValue), userData);
}
|
|
|
|
/// Invokes `callback(attr, userData)` for each attribute returned by
/// `TransformState::getParams` for `value`, in iteration order.
void mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
                                    MlirAttributeCallback callback,
                                    void *userData) {
  transform::TransformState *cppState = unwrap(state);
  Value handle = unwrap(value);
  for (Attribute param : cppState->getParams(handle))
    callback(wrap(param), userData);
}
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// TransformOpInterface
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
/// Returns the interface ID of `transform::TransformOpInterface`.
MlirTypeID mlirTransformOpInterfaceTypeID(void) {
  TypeID interfaceID = transform::TransformOpInterface::getInterfaceID();
  return wrap(interfaceID);
}
|
|
|
|
/// Fallback model for the TransformOpInterface that uses C API callbacks.
|
|
class TransformOpInterfaceFallbackModel
|
|
: public mlir::transform::TransformOpInterface::FallbackModel<
|
|
TransformOpInterfaceFallbackModel> {
|
|
public:
|
|
/// Sets the callbacks that this FallbackModel will use.
|
|
/// NB: the callbacks can only be set through this method as the
|
|
/// RegisteredOperationName::attachInterface mechanism default-constructs
|
|
/// the FallbackModel without being able to provide arguments.
|
|
void setCallbacks(MlirTransformOpInterfaceCallbacks callbacks) {
|
|
this->callbacks = callbacks;
|
|
}
|
|
|
|
~TransformOpInterfaceFallbackModel() {
|
|
if (callbacks.destruct)
|
|
callbacks.destruct(callbacks.userData);
|
|
}
|
|
|
|
static TypeID getInterfaceID() {
|
|
return transform::TransformOpInterface::getInterfaceID();
|
|
}
|
|
|
|
static bool classof(const mlir::transform::detail::
|
|
TransformOpInterfaceInterfaceTraits::Concept *op) {
|
|
// Enable casting back to the FallbackModel from the Interface. This is
|
|
// necessary as attachInterface(...) default-constructs the FallbackModel
|
|
// without being able to pass in the callbacks and returns just the Concept.
|
|
return true;
|
|
}
|
|
|
|
::mlir::DiagnosedSilenceableFailure
|
|
apply(Operation *op, ::mlir::transform::TransformRewriter &rewriter,
|
|
::mlir::transform::TransformResults &transformResults,
|
|
::mlir::transform::TransformState &state) const {
|
|
assert(callbacks.apply && "apply callback not set");
|
|
|
|
MlirDiagnosedSilenceableFailure status =
|
|
callbacks.apply(wrap(op), wrap(&rewriter), wrap(&transformResults),
|
|
wrap(&state), callbacks.userData);
|
|
|
|
switch (status) {
|
|
case MlirDiagnosedSilenceableFailureSuccess:
|
|
return DiagnosedSilenceableFailure::success();
|
|
case MlirDiagnosedSilenceableFailureSilenceableFailure:
|
|
// TODO: enable passing diagnostic info from C API to C++ API.
|
|
return DiagnosedSilenceableFailure::silenceableFailure(std::move(
|
|
*(op->emitError()
|
|
<< "TransformOpInterfaceFallbackModel: silenceable failure")
|
|
.getUnderlyingDiagnostic()));
|
|
case MlirDiagnosedSilenceableFailureDefiniteFailure:
|
|
return DiagnosedSilenceableFailure::definiteFailure();
|
|
}
|
|
llvm_unreachable("unknown transform status");
|
|
}
|
|
|
|
bool allowsRepeatedHandleOperands(Operation *op) const {
|
|
assert(callbacks.allowsRepeatedHandleOperands &&
|
|
"allowsRepeatedHandleOperands callback not set");
|
|
return callbacks.allowsRepeatedHandleOperands(wrap(op), callbacks.userData);
|
|
}
|
|
|
|
private:
|
|
MlirTransformOpInterfaceCallbacks callbacks;
|
|
};
|
|
|
|
/// Attach a TransformOpInterface FallbackModel to the given named operation.
|
|
/// The FallbackModel uses the provided callbacks to implement the interface.
|
|
void mlirTransformOpInterfaceAttachFallbackModel(
|
|
MlirContext ctx, MlirStringRef opName,
|
|
MlirTransformOpInterfaceCallbacks callbacks) {
|
|
// Look up the operation definition in the context.
|
|
std::optional<RegisteredOperationName> opInfo =
|
|
RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
|
|
|
|
assert(opInfo.has_value() && "operation not found in context");
|
|
|
|
// NB: the following default-constructs the FallbackModel _without_ being able
|
|
// to provide arguments.
|
|
opInfo->attachInterface<TransformOpInterfaceFallbackModel>();
|
|
// Cast to get the underlying FallbackModel and set the callbacks.
|
|
auto *model = cast<TransformOpInterfaceFallbackModel>(
|
|
opInfo->getInterface<TransformOpInterfaceFallbackModel>());
|
|
|
|
assert(model && "Failed to get TransformOpInterfaceFallbackModel");
|
|
model->setCallbacks(callbacks);
|
|
}
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// PatternDescriptorOpInterface
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
/// Returns the interface ID of `transform::PatternDescriptorOpInterface`.
MlirTypeID mlirPatternDescriptorOpInterfaceTypeID(void) {
  TypeID interfaceID =
      transform::PatternDescriptorOpInterface::getInterfaceID();
  return wrap(interfaceID);
}
|
|
|
|
/// Fallback model for the PatternDescriptorOpInterface that uses C API
|
|
/// callbacks.
|
|
class PatternDescriptorOpInterfaceFallbackModel
|
|
: public mlir::transform::PatternDescriptorOpInterface::FallbackModel<
|
|
PatternDescriptorOpInterfaceFallbackModel> {
|
|
public:
|
|
/// Sets the callbacks that this FallbackModel will use.
|
|
/// NB: the callbacks can only be set through this method as the
|
|
/// RegisteredOperationName::attachInterface mechanism default-constructs
|
|
/// the FallbackModel without being able to provide arguments.
|
|
void setCallbacks(MlirPatternDescriptorOpInterfaceCallbacks callbacks) {
|
|
this->callbacks = callbacks;
|
|
}
|
|
|
|
~PatternDescriptorOpInterfaceFallbackModel() {
|
|
if (callbacks.destruct)
|
|
callbacks.destruct(callbacks.userData);
|
|
}
|
|
|
|
static TypeID getInterfaceID() {
|
|
return transform::PatternDescriptorOpInterface::getInterfaceID();
|
|
}
|
|
|
|
static bool
|
|
classof(const mlir::transform::detail::
|
|
PatternDescriptorOpInterfaceInterfaceTraits::Concept *op) {
|
|
// Enable casting back to the FallbackModel from the Interface. This is
|
|
// necessary as attachInterface(...) default-constructs the FallbackModel
|
|
// without being able to pass in the callbacks and returns just the Concept.
|
|
return true;
|
|
}
|
|
|
|
void populatePatterns(Operation *op, RewritePatternSet &patterns) const {
|
|
assert(callbacks.populatePatterns && "populatePatterns callback not set");
|
|
callbacks.populatePatterns(wrap(op), wrap(&patterns), callbacks.userData);
|
|
}
|
|
|
|
void populatePatternsWithState(Operation *op, RewritePatternSet &patterns,
|
|
transform::TransformState &state) const {
|
|
if (callbacks.populatePatternsWithState) {
|
|
callbacks.populatePatternsWithState(wrap(op), wrap(&patterns),
|
|
wrap(&state), callbacks.userData);
|
|
} else {
|
|
// Default implementation: call populatePatterns without state.
|
|
populatePatterns(op, patterns);
|
|
}
|
|
}
|
|
|
|
private:
|
|
MlirPatternDescriptorOpInterfaceCallbacks callbacks;
|
|
};
|
|
|
|
/// Attach a PatternDescriptorOpInterface FallbackModel to the given named
|
|
/// operation. The FallbackModel uses the provided callbacks to implement the
|
|
/// interface.
|
|
void mlirPatternDescriptorOpInterfaceAttachFallbackModel(
|
|
MlirContext ctx, MlirStringRef opName,
|
|
MlirPatternDescriptorOpInterfaceCallbacks callbacks) {
|
|
// Look up the operation definition in the context.
|
|
std::optional<RegisteredOperationName> opInfo =
|
|
RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
|
|
|
|
assert(opInfo.has_value() && "operation not found in context");
|
|
|
|
// NB: the following default-constructs the FallbackModel _without_ being able
|
|
// to provide arguments.
|
|
opInfo->attachInterface<PatternDescriptorOpInterfaceFallbackModel>();
|
|
// Cast to get the underlying FallbackModel and set the callbacks.
|
|
auto *model = cast<PatternDescriptorOpInterfaceFallbackModel>(
|
|
opInfo->getInterface<PatternDescriptorOpInterfaceFallbackModel>());
|
|
|
|
assert(model && "Failed to get PatternDescriptorOpInterfaceFallbackModel");
|
|
model->setCallbacks(callbacks);
|
|
}
|
|
|
|
//===---------------------------------------------------------------------===//
|
|
// MemoryEffectsOpInterface helpers
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
/// Set the effect for the operands to only read the transform handles.
|
|
void mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
|
|
MlirMemoryEffectInstancesList effects) {
|
|
MutableArrayRef<OpOperand> operandArray(unwrap(*operands), numOperands);
|
|
transform::onlyReadsHandle(operandArray, *unwrap(effects));
|
|
}
|
|
|
|
/// Set the effect for the operands to consuming the transform handles.
|
|
void mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands,
|
|
MlirMemoryEffectInstancesList effects) {
|
|
MutableArrayRef<OpOperand> operandArray(unwrap(*operands), numOperands);
|
|
transform::consumesHandle(operandArray, *unwrap(effects));
|
|
}
|
|
|
|
/// Set the effect for the results to that they produce transform handles.
|
|
void mlirTransformProducesHandle(MlirValue *results, intptr_t numResults,
|
|
MlirMemoryEffectInstancesList effects) {
|
|
// NB: calling `producesHandle()` `numResults` as we cannot cast array of
|
|
// `OpResult`s to a single `ResultRange` (and neither is `ResultRange` exposed
|
|
// to Python). `producesHandle` iterates over the given `ResultRange` anyway.
|
|
SmallVectorImpl<MemoryEffects::EffectInstance> &effectList = *unwrap(effects);
|
|
for (intptr_t i = 0; i < numResults; ++i) {
|
|
auto opResult = cast<OpResult>(unwrap(results[i]));
|
|
transform::producesHandle(ResultRange(opResult), effectList);
|
|
}
|
|
}
|
|
|
|
/// Set the effect of potentially modifying payload IR.
|
|
void mlirTransformModifiesPayload(MlirMemoryEffectInstancesList effects) {
|
|
transform::modifiesPayload(*unwrap(effects));
|
|
}
|
|
|
|
/// Set the effect of potentially reading payload IR.
|
|
void mlirTransformOnlyReadsPayload(MlirMemoryEffectInstancesList effects) {
|
|
transform::onlyReadsPayload(*unwrap(effects));
|
|
}
|