Makes it possible to include Python-defined rewrite patterns in transform-dialect schedules, inside `transform.apply_patterns`; when the schedule executes, these patterns are applied by the greedy rewrite driver. Written with the assistance of Claude.
842 lines
32 KiB
C++
842 lines
32 KiB
C++
//===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir-c/Rewrite.h"
|
|
|
|
#include "mlir-c/Support.h"
|
|
#include "mlir-c/Transforms.h"
|
|
#include "mlir/CAPI/IR.h"
|
|
#include "mlir/CAPI/Rewrite.h"
|
|
#include "mlir/CAPI/Support.h"
|
|
#include "mlir/CAPI/Wrap.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/PDLPatternMatch.h.inc"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// RewriterBase API inherited from OpBuilder
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the context reported by the rewriter's `getContext()`, wrapped as
/// a C handle.
MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) {
  auto *base = unwrap(rewriter);
  return wrap(base->getContext());
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// Insertion points methods
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) {
|
|
unwrap(rewriter)->clearInsertionPoint();
|
|
}
|
|
|
|
void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter,
|
|
MlirOperation op) {
|
|
unwrap(rewriter)->setInsertionPoint(unwrap(op));
|
|
}
|
|
|
|
void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter,
|
|
MlirOperation op) {
|
|
unwrap(rewriter)->setInsertionPointAfter(unwrap(op));
|
|
}
|
|
|
|
void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter,
|
|
MlirValue value) {
|
|
unwrap(rewriter)->setInsertionPointAfterValue(unwrap(value));
|
|
}
|
|
|
|
void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter,
|
|
MlirBlock block) {
|
|
unwrap(rewriter)->setInsertionPointToStart(unwrap(block));
|
|
}
|
|
|
|
void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter,
|
|
MlirBlock block) {
|
|
unwrap(rewriter)->setInsertionPointToEnd(unwrap(block));
|
|
}
|
|
|
|
MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) {
|
|
return wrap(unwrap(rewriter)->getInsertionBlock());
|
|
}
|
|
|
|
MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
|
|
return wrap(unwrap(rewriter)->getBlock());
|
|
}
|
|
|
|
/// Returns the operation located at the rewriter's insertion point, i.e. the
/// operation that newly inserted IR would be placed before.
///
/// Returns a null operation when the insertion point is at the end of its
/// block, or when there is no insertion block at all (e.g. after
/// mlirRewriterBaseClearInsertionPoint).
MlirOperation
mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter) {
  mlir::RewriterBase *base = unwrap(rewriter);
  mlir::Block *block = base->getInsertionBlock();
  // A cleared insertion point has no block; calling `block->end()` on it
  // would dereference null.
  if (!block)
    return {nullptr};

  mlir::Block::iterator it = base->getInsertionPoint();
  if (it == block->end())
    return {nullptr};

  return wrap(std::addressof(*it));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// Block and operation creation/insertion/cloning
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter,
|
|
MlirBlock insertBefore,
|
|
intptr_t nArgTypes,
|
|
MlirType const *argTypes,
|
|
MlirLocation const *locations) {
|
|
SmallVector<Type, 4> args;
|
|
ArrayRef<Type> unwrappedArgs = unwrapList(nArgTypes, argTypes, args);
|
|
SmallVector<Location, 4> locs;
|
|
ArrayRef<Location> unwrappedLocs = unwrapList(nArgTypes, locations, locs);
|
|
return wrap(unwrap(rewriter)->createBlock(unwrap(insertBefore), unwrappedArgs,
|
|
unwrappedLocs));
|
|
}
|
|
|
|
MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter,
|
|
MlirOperation op) {
|
|
return wrap(unwrap(rewriter)->insert(unwrap(op)));
|
|
}
|
|
|
|
// Other methods of OpBuilder
|
|
|
|
MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter,
|
|
MlirOperation op) {
|
|
return wrap(unwrap(rewriter)->clone(*unwrap(op)));
|
|
}
|
|
|
|
MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter,
|
|
MlirOperation op) {
|
|
return wrap(unwrap(rewriter)->cloneWithoutRegions(*unwrap(op)));
|
|
}
|
|
|
|
void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter,
|
|
MlirRegion region, MlirBlock before) {
|
|
|
|
unwrap(rewriter)->cloneRegionBefore(*unwrap(region), unwrap(before));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// RewriterBase API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter,
|
|
MlirRegion region, MlirBlock before) {
|
|
unwrap(rewriter)->inlineRegionBefore(*unwrap(region), unwrap(before));
|
|
}
|
|
|
|
void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter,
|
|
MlirOperation op, intptr_t nValues,
|
|
MlirValue const *values) {
|
|
SmallVector<Value, 4> vals;
|
|
ArrayRef<Value> unwrappedVals = unwrapList(nValues, values, vals);
|
|
unwrap(rewriter)->replaceOp(unwrap(op), unwrappedVals);
|
|
}
|
|
|
|
void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter,
|
|
MlirOperation op,
|
|
MlirOperation newOp) {
|
|
unwrap(rewriter)->replaceOp(unwrap(op), unwrap(newOp));
|
|
}
|
|
|
|
void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) {
|
|
unwrap(rewriter)->eraseOp(unwrap(op));
|
|
}
|
|
|
|
void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) {
|
|
unwrap(rewriter)->eraseBlock(unwrap(block));
|
|
}
|
|
|
|
void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter,
|
|
MlirBlock source, MlirOperation op,
|
|
intptr_t nArgValues,
|
|
MlirValue const *argValues) {
|
|
SmallVector<Value, 4> vals;
|
|
ArrayRef<Value> unwrappedVals = unwrapList(nArgValues, argValues, vals);
|
|
|
|
unwrap(rewriter)->inlineBlockBefore(unwrap(source), unwrap(op),
|
|
unwrappedVals);
|
|
}
|
|
|
|
void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source,
|
|
MlirBlock dest, intptr_t nArgValues,
|
|
MlirValue const *argValues) {
|
|
SmallVector<Value, 4> args;
|
|
ArrayRef<Value> unwrappedArgs = unwrapList(nArgValues, argValues, args);
|
|
unwrap(rewriter)->mergeBlocks(unwrap(source), unwrap(dest), unwrappedArgs);
|
|
}
|
|
|
|
void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op,
|
|
MlirOperation existingOp) {
|
|
unwrap(rewriter)->moveOpBefore(unwrap(op), unwrap(existingOp));
|
|
}
|
|
|
|
void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op,
|
|
MlirOperation existingOp) {
|
|
unwrap(rewriter)->moveOpAfter(unwrap(op), unwrap(existingOp));
|
|
}
|
|
|
|
void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block,
|
|
MlirBlock existingBlock) {
|
|
unwrap(rewriter)->moveBlockBefore(unwrap(block), unwrap(existingBlock));
|
|
}
|
|
|
|
void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter,
|
|
MlirOperation op) {
|
|
unwrap(rewriter)->startOpModification(unwrap(op));
|
|
}
|
|
|
|
void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter,
|
|
MlirOperation op) {
|
|
unwrap(rewriter)->finalizeOpModification(unwrap(op));
|
|
}
|
|
|
|
void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter,
|
|
MlirOperation op) {
|
|
unwrap(rewriter)->cancelOpModification(unwrap(op));
|
|
}
|
|
|
|
void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter,
|
|
MlirValue from, MlirValue to) {
|
|
unwrap(rewriter)->replaceAllUsesWith(unwrap(from), unwrap(to));
|
|
}
|
|
|
|
void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter,
|
|
intptr_t nValues,
|
|
MlirValue const *from,
|
|
MlirValue const *to) {
|
|
SmallVector<Value, 4> fromVals;
|
|
ArrayRef<Value> unwrappedFromVals = unwrapList(nValues, from, fromVals);
|
|
SmallVector<Value, 4> toVals;
|
|
ArrayRef<Value> unwrappedToVals = unwrapList(nValues, to, toVals);
|
|
unwrap(rewriter)->replaceAllUsesWith(unwrappedFromVals, unwrappedToVals);
|
|
}
|
|
|
|
void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter,
|
|
MlirOperation from,
|
|
intptr_t nTo,
|
|
MlirValue const *to) {
|
|
SmallVector<Value, 4> toVals;
|
|
ArrayRef<Value> unwrappedToVals = unwrapList(nTo, to, toVals);
|
|
unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrappedToVals);
|
|
}
|
|
|
|
void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter,
|
|
MlirOperation from,
|
|
MlirOperation to) {
|
|
unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrap(to));
|
|
}
|
|
|
|
void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter,
|
|
MlirOperation op,
|
|
intptr_t nNewValues,
|
|
MlirValue const *newValues,
|
|
MlirBlock block) {
|
|
SmallVector<Value, 4> vals;
|
|
ArrayRef<Value> unwrappedVals = unwrapList(nNewValues, newValues, vals);
|
|
unwrap(rewriter)->replaceOpUsesWithinBlock(unwrap(op), unwrappedVals,
|
|
unwrap(block));
|
|
}
|
|
|
|
void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter,
|
|
MlirValue from, MlirValue to,
|
|
MlirOperation exceptedUser) {
|
|
unwrap(rewriter)->replaceAllUsesExcept(unwrap(from), unwrap(to),
|
|
unwrap(exceptedUser));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// IRRewriter API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
MlirRewriterBase mlirIRRewriterCreate(MlirContext context) {
|
|
return wrap(new IRRewriter(unwrap(context)));
|
|
}
|
|
|
|
MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) {
|
|
return wrap(new IRRewriter(unwrap(op)));
|
|
}
|
|
|
|
void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
|
|
delete static_cast<IRRewriter *>(unwrap(rewriter));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// RewritePatternSet and FrozenRewritePatternSet API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Freezes `set` into a newly allocated FrozenRewritePatternSet and returns it.
///
/// This call consumes `set`. Its patterns are moved into the frozen set, and
/// the emptied RewritePatternSet object is deleted here. Callers must neither
/// use nor destroy `set` afterwards. Release the result with
/// mlirFrozenRewritePatternSetDestroy.
MlirFrozenRewritePatternSet
mlirFreezeRewritePattern(MlirRewritePatternSet set) {
  mlir::RewritePatternSet *patterns = unwrap(set);
  auto *frozen = new mlir::FrozenRewritePatternSet(std::move(*patterns));
  // `set` is a by-value handle: nulling `set.ptr` would only clear this local
  // copy. The emptied heap object must be released here, or it leaks.
  delete patterns;
  return wrap(frozen);
}
|
|
|
|
/// Destroys a frozen pattern set returned by mlirFreezeRewritePattern.
void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) {
  // `set` is passed by value, so nulling `set.ptr` here could never reach
  // the caller's handle. Callers must drop their own copy after this call.
  delete unwrap(set);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// GreedyRewriteDriverConfig API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
inline mlir::GreedyRewriteConfig *unwrap(MlirGreedyRewriteDriverConfig config) {
|
|
assert(config.ptr && "unexpected null config");
|
|
return static_cast<mlir::GreedyRewriteConfig *>(config.ptr);
|
|
}
|
|
|
|
inline MlirGreedyRewriteDriverConfig wrap(mlir::GreedyRewriteConfig *config) {
|
|
return {config};
|
|
}
|
|
|
|
MlirGreedyRewriteDriverConfig mlirGreedyRewriteDriverConfigCreate() {
|
|
return wrap(new mlir::GreedyRewriteConfig());
|
|
}
|
|
|
|
void mlirGreedyRewriteDriverConfigDestroy(
|
|
MlirGreedyRewriteDriverConfig config) {
|
|
delete unwrap(config);
|
|
}
|
|
|
|
void mlirGreedyRewriteDriverConfigSetMaxIterations(
|
|
MlirGreedyRewriteDriverConfig config, int64_t maxIterations) {
|
|
unwrap(config)->setMaxIterations(maxIterations);
|
|
}
|
|
|
|
void mlirGreedyRewriteDriverConfigSetMaxNumRewrites(
|
|
MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites) {
|
|
unwrap(config)->setMaxNumRewrites(maxNumRewrites);
|
|
}
|
|
|
|
void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(
|
|
MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal) {
|
|
unwrap(config)->setUseTopDownTraversal(useTopDownTraversal);
|
|
}
|
|
|
|
void mlirGreedyRewriteDriverConfigEnableFolding(
|
|
MlirGreedyRewriteDriverConfig config, bool enable) {
|
|
unwrap(config)->enableFolding(enable);
|
|
}
|
|
|
|
void mlirGreedyRewriteDriverConfigSetStrictness(
|
|
MlirGreedyRewriteDriverConfig config,
|
|
MlirGreedyRewriteStrictness strictness) {
|
|
mlir::GreedyRewriteStrictness cppStrictness;
|
|
switch (strictness) {
|
|
case MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP:
|
|
cppStrictness = mlir::GreedyRewriteStrictness::AnyOp;
|
|
break;
|
|
case MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS:
|
|
cppStrictness = mlir::GreedyRewriteStrictness::ExistingAndNewOps;
|
|
break;
|
|
case MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS:
|
|
cppStrictness = mlir::GreedyRewriteStrictness::ExistingOps;
|
|
break;
|
|
}
|
|
unwrap(config)->setStrictness(cppStrictness);
|
|
}
|
|
|
|
void mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
|
|
MlirGreedyRewriteDriverConfig config, MlirGreedySimplifyRegionLevel level) {
|
|
mlir::GreedySimplifyRegionLevel cppLevel;
|
|
switch (level) {
|
|
case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED:
|
|
cppLevel = mlir::GreedySimplifyRegionLevel::Disabled;
|
|
break;
|
|
case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL:
|
|
cppLevel = mlir::GreedySimplifyRegionLevel::Normal;
|
|
break;
|
|
case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE:
|
|
cppLevel = mlir::GreedySimplifyRegionLevel::Aggressive;
|
|
break;
|
|
}
|
|
unwrap(config)->setRegionSimplificationLevel(cppLevel);
|
|
}
|
|
|
|
/// Turns constant CSE on or off for the greedy driver.
void mlirGreedyRewriteDriverConfigEnableConstantCSE(
    MlirGreedyRewriteDriverConfig config, bool enable) {
  auto *cfg = unwrap(config);
  cfg->enableConstantCSE(enable);
}

/// Returns the configured maximum iteration count.
int64_t mlirGreedyRewriteDriverConfigGetMaxIterations(
    MlirGreedyRewriteDriverConfig config) {
  const auto *cfg = unwrap(config);
  return cfg->getMaxIterations();
}

/// Returns the configured maximum number of rewrites.
int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites(
    MlirGreedyRewriteDriverConfig config) {
  const auto *cfg = unwrap(config);
  return cfg->getMaxNumRewrites();
}

/// Returns true when top-down traversal is selected.
bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(
    MlirGreedyRewriteDriverConfig config) {
  const auto *cfg = unwrap(config);
  return cfg->getUseTopDownTraversal();
}

/// Returns true when folding is enabled.
bool mlirGreedyRewriteDriverConfigIsFoldingEnabled(
    MlirGreedyRewriteDriverConfig config) {
  const auto *cfg = unwrap(config);
  return cfg->isFoldingEnabled();
}
|
|
|
|
/// Returns the configured strictness, translated to the C-API enum.
MlirGreedyRewriteStrictness mlirGreedyRewriteDriverConfigGetStrictness(
    MlirGreedyRewriteDriverConfig config) {
  switch (unwrap(config)->getStrictness()) {
  case mlir::GreedyRewriteStrictness::AnyOp:
    return MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP;
  case mlir::GreedyRewriteStrictness::ExistingAndNewOps:
    return MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS;
  case mlir::GreedyRewriteStrictness::ExistingOps:
    return MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS;
  }
  // Every enumerator is handled above.
  llvm_unreachable("Unknown GreedyRewriteStrictness");
}
|
|
|
|
/// Returns the configured region-simplification level, translated to the
/// C-API enum.
MlirGreedySimplifyRegionLevel
mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(
    MlirGreedyRewriteDriverConfig config) {
  switch (unwrap(config)->getRegionSimplificationLevel()) {
  case mlir::GreedySimplifyRegionLevel::Disabled:
    return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED;
  case mlir::GreedySimplifyRegionLevel::Normal:
    return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL;
  case mlir::GreedySimplifyRegionLevel::Aggressive:
    return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE;
  }
  // Every enumerator is handled above.
  llvm_unreachable("Unknown GreedySimplifyRegionLevel");
}
|
|
|
|
/// Returns true when constant CSE is enabled.
bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(
    MlirGreedyRewriteDriverConfig config) {
  const auto *cfg = unwrap(config);
  return cfg->isConstantCSEEnabled();
}
|
|
|
|
MlirLogicalResult
|
|
mlirApplyPatternsAndFoldGreedily(MlirModule op,
|
|
MlirFrozenRewritePatternSet patterns,
|
|
MlirGreedyRewriteDriverConfig config) {
|
|
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns),
|
|
*unwrap(config)));
|
|
}
|
|
|
|
MlirLogicalResult
|
|
mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
|
|
MlirFrozenRewritePatternSet patterns,
|
|
MlirGreedyRewriteDriverConfig config) {
|
|
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns),
|
|
*unwrap(config)));
|
|
}
|
|
|
|
void mlirWalkAndApplyPatterns(MlirOperation op,
|
|
MlirFrozenRewritePatternSet patterns) {
|
|
mlir::walkAndApplyPatterns(unwrap(op), *unwrap(patterns));
|
|
}
|
|
|
|
/// Runs a partial dialect conversion of `op` toward `target` using
/// `patterns` and `config`.
MlirLogicalResult
mlirApplyPartialConversion(MlirOperation op, MlirConversionTarget target,
                           MlirFrozenRewritePatternSet patterns,
                           MlirConversionConfig config) {
  const auto &legality = *unwrap(target);
  const auto &frozen = *unwrap(patterns);
  return wrap(mlir::applyPartialConversion(unwrap(op), legality, frozen,
                                           *unwrap(config)));
}

/// Runs a full dialect conversion of `op` toward `target` using `patterns`
/// and `config`.
MlirLogicalResult mlirApplyFullConversion(MlirOperation op,
                                          MlirConversionTarget target,
                                          MlirFrozenRewritePatternSet patterns,
                                          MlirConversionConfig config) {
  const auto &legality = *unwrap(target);
  const auto &frozen = *unwrap(patterns);
  return wrap(mlir::applyFullConversion(unwrap(op), legality, frozen,
                                        *unwrap(config)));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// ConversionConfig API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
MlirConversionConfig mlirConversionConfigCreate(void) {
|
|
return wrap(new mlir::ConversionConfig());
|
|
}
|
|
|
|
void mlirConversionConfigDestroy(MlirConversionConfig config) {
|
|
delete unwrap(config);
|
|
}
|
|
|
|
void mlirConversionConfigSetFoldingMode(MlirConversionConfig config,
|
|
MlirDialectConversionFoldingMode mode) {
|
|
mlir::DialectConversionFoldingMode cppMode;
|
|
switch (mode) {
|
|
case MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER:
|
|
cppMode = mlir::DialectConversionFoldingMode::Never;
|
|
break;
|
|
case MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS:
|
|
cppMode = mlir::DialectConversionFoldingMode::BeforePatterns;
|
|
break;
|
|
case MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS:
|
|
cppMode = mlir::DialectConversionFoldingMode::AfterPatterns;
|
|
break;
|
|
}
|
|
unwrap(config)->foldingMode = cppMode;
|
|
}
|
|
|
|
/// Returns the config's `foldingMode`, translated to the C-API enum.
MlirDialectConversionFoldingMode
mlirConversionConfigGetFoldingMode(MlirConversionConfig config) {
  switch (unwrap(config)->foldingMode) {
  case mlir::DialectConversionFoldingMode::Never:
    return MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER;
  case mlir::DialectConversionFoldingMode::BeforePatterns:
    return MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS;
  case mlir::DialectConversionFoldingMode::AfterPatterns:
    return MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS;
  }
  // Without this, control could fall off the end of a non-void function (UB,
  // -Wreturn-type). The sibling enum getters in this file use the same pattern.
  llvm_unreachable("Unknown DialectConversionFoldingMode");
}
|
|
|
|
/// Sets the config's `buildMaterializations` flag.
void mlirConversionConfigEnableBuildMaterializations(
    MlirConversionConfig config, bool enable) {
  auto &cfg = *unwrap(config);
  cfg.buildMaterializations = enable;
}

/// Reads the config's `buildMaterializations` flag.
bool mlirConversionConfigIsBuildMaterializationsEnabled(
    MlirConversionConfig config) {
  const auto &cfg = *unwrap(config);
  return cfg.buildMaterializations;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// PatternRewriter API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Upcasts a PatternRewriter handle to its RewriterBase view.
MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) {
  auto *patternRewriter = unwrap(rewriter);
  return wrap(static_cast<mlir::RewriterBase *>(patternRewriter));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// ConversionPatternRewriter API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
MlirPatternRewriter mlirConversionPatternRewriterAsPatternRewriter(
|
|
MlirConversionPatternRewriter rewriter) {
|
|
return wrap(static_cast<mlir::PatternRewriter *>(unwrap(rewriter)));
|
|
}
|
|
|
|
MlirLogicalResult mlirConversionPatternRewriterConvertRegionTypes(
|
|
MlirConversionPatternRewriter rewriter, MlirRegion region,
|
|
MlirTypeConverter typeConverter) {
|
|
return wrap(unwrap(rewriter)->convertRegionTypes(unwrap(region),
|
|
*unwrap(typeConverter)));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// ConversionTarget API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
MlirConversionTarget mlirConversionTargetCreate(MlirContext context) {
|
|
return wrap(new mlir::ConversionTarget(*unwrap(context)));
|
|
}
|
|
|
|
void mlirConversionTargetDestroy(MlirConversionTarget target) {
|
|
delete unwrap(target);
|
|
}
|
|
|
|
void mlirConversionTargetAddLegalOp(MlirConversionTarget target,
|
|
MlirStringRef opName) {
|
|
unwrap(target)->addLegalOp(
|
|
mlir::OperationName(unwrap(opName), &unwrap(target)->getContext()));
|
|
}
|
|
|
|
void mlirConversionTargetAddIllegalOp(MlirConversionTarget target,
|
|
MlirStringRef opName) {
|
|
unwrap(target)->addIllegalOp(
|
|
mlir::OperationName(unwrap(opName), &unwrap(target)->getContext()));
|
|
}
|
|
|
|
void mlirConversionTargetAddLegalDialect(MlirConversionTarget target,
|
|
MlirStringRef dialectName) {
|
|
unwrap(target)->addLegalDialect(unwrap(dialectName));
|
|
}
|
|
|
|
void mlirConversionTargetAddIllegalDialect(MlirConversionTarget target,
|
|
MlirStringRef dialectName) {
|
|
unwrap(target)->addIllegalDialect(unwrap(dialectName));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// TypeConverter API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
MlirTypeConverter mlirTypeConverterCreate() {
|
|
return wrap(new mlir::TypeConverter());
|
|
}
|
|
|
|
void mlirTypeConverterDestroy(MlirTypeConverter typeConverter) {
|
|
delete unwrap(typeConverter);
|
|
}
|
|
|
|
/// Registers a C callback as a type conversion on `typeConverter`.
///
/// The callback receives the source type, an out-parameter for the converted
/// type, and `userData`:
///  - if the callback reports failure, the conversion yields `std::nullopt`;
///  - if it succeeds but leaves the out-parameter null, a null type is
///    returned;
///  - otherwise the converted type is returned.
void mlirTypeConverterAddConversion(
    MlirTypeConverter typeConverter,
    MlirTypeConverterConversionCallback convertType, void *userData) {
  auto *converter = unwrap(typeConverter);
  converter->addConversion(
      [convertType, userData](Type type) -> std::optional<Type> {
        MlirType converted{nullptr};
        if (mlirLogicalResultIsFailure(
                convertType(wrap(type), &converted, userData)))
          return std::nullopt; // allowed to try another conversion function
        if (mlirTypeIsNull(converted))
          return nullptr;
        return unwrap(converted);
      });
}
|
|
|
|
/// Converts `type` with `typeConverter` and returns the resulting type.
MlirType mlirTypeConverterConvertType(MlirTypeConverter typeConverter,
                                      MlirType type) {
  auto *converter = unwrap(typeConverter);
  return wrap(converter->convertType(unwrap(type)));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// ConversionPattern API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace mlir {
|
|
|
|
class ExternalConversionPattern : public mlir::ConversionPattern {
|
|
public:
|
|
ExternalConversionPattern(MlirConversionPatternCallbacks callbacks,
|
|
void *userData, StringRef rootName,
|
|
PatternBenefit benefit, MLIRContext *context,
|
|
TypeConverter *typeConverter,
|
|
ArrayRef<StringRef> generatedNames)
|
|
: ConversionPattern(*typeConverter, rootName, benefit, context,
|
|
generatedNames),
|
|
callbacks(callbacks), userData(userData) {
|
|
if (callbacks.construct)
|
|
callbacks.construct(userData);
|
|
}
|
|
|
|
~ExternalConversionPattern() {
|
|
if (callbacks.destruct)
|
|
callbacks.destruct(userData);
|
|
}
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
std::vector<MlirValue> wrappedOperands;
|
|
for (Value val : operands)
|
|
wrappedOperands.push_back(wrap(val));
|
|
return unwrap(callbacks.matchAndRewrite(
|
|
wrap(static_cast<const mlir::ConversionPattern *>(this)), wrap(op),
|
|
wrappedOperands.size(), wrappedOperands.data(), wrap(&rewriter),
|
|
userData));
|
|
}
|
|
|
|
private:
|
|
MlirConversionPatternCallbacks callbacks;
|
|
void *userData;
|
|
};
|
|
|
|
} // namespace mlir
|
|
|
|
/// Heap-allocates an ExternalConversionPattern rooted at `rootName`, driven by
/// `callbacks`/`userData` and converting types with `typeConverter`.
/// `generatedNames` holds `nGeneratedNames` operation names.
MlirConversionPattern mlirOpConversionPatternCreate(
    MlirStringRef rootName, unsigned benefit, MlirContext context,
    MlirTypeConverter typeConverter, MlirConversionPatternCallbacks callbacks,
    void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames) {
  std::vector<mlir::StringRef> names;
  names.reserve(nGeneratedNames);
  const MlirStringRef *end = generatedNames + nGeneratedNames;
  for (const MlirStringRef *it = generatedNames; it != end; ++it)
    names.push_back(unwrap(*it));

  auto *pattern = new mlir::ExternalConversionPattern(
      callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
      unwrap(context), unwrap(typeConverter), names);
  return wrap(pattern);
}
|
|
|
|
MlirTypeConverter
|
|
mlirConversionPatternGetTypeConverter(MlirConversionPattern pattern) {
|
|
return wrap(const_cast<TypeConverter *>(unwrap(pattern)->getTypeConverter()));
|
|
}
|
|
|
|
MlirRewritePattern
|
|
mlirConversionPatternAsRewritePattern(MlirConversionPattern pattern) {
|
|
return wrap(static_cast<const RewritePattern *>(unwrap(pattern)));
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// RewritePattern API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace mlir {
|
|
|
|
class ExternalRewritePattern : public mlir::RewritePattern {
|
|
public:
|
|
ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData,
|
|
StringRef rootName, PatternBenefit benefit,
|
|
MLIRContext *context,
|
|
ArrayRef<StringRef> generatedNames)
|
|
: RewritePattern(rootName, benefit, context, generatedNames),
|
|
callbacks(callbacks), userData(userData) {
|
|
if (callbacks.construct)
|
|
callbacks.construct(userData);
|
|
}
|
|
|
|
~ExternalRewritePattern() {
|
|
if (callbacks.destruct)
|
|
callbacks.destruct(userData);
|
|
}
|
|
|
|
LogicalResult matchAndRewrite(Operation *op,
|
|
PatternRewriter &rewriter) const override {
|
|
return unwrap(callbacks.matchAndRewrite(
|
|
wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op),
|
|
wrap(&rewriter), userData));
|
|
}
|
|
|
|
private:
|
|
MlirRewritePatternCallbacks callbacks;
|
|
void *userData;
|
|
};
|
|
|
|
} // namespace mlir
|
|
|
|
/// Heap-allocates an ExternalRewritePattern rooted at `rootName` and driven by
/// `callbacks`/`userData`. `generatedNames` holds `nGeneratedNames` operation
/// names.
MlirRewritePattern mlirOpRewritePatternCreate(
    MlirStringRef rootName, unsigned benefit, MlirContext context,
    MlirRewritePatternCallbacks callbacks, void *userData,
    size_t nGeneratedNames, MlirStringRef *generatedNames) {
  std::vector<mlir::StringRef> names;
  names.reserve(nGeneratedNames);
  const MlirStringRef *end = generatedNames + nGeneratedNames;
  for (const MlirStringRef *it = generatedNames; it != end; ++it)
    names.push_back(unwrap(*it));

  auto *pattern = new mlir::ExternalRewritePattern(
      callbacks, userData, unwrap(rootName), PatternBenefit(benefit),
      unwrap(context), names);
  return wrap(pattern);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// RewritePatternSet API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) {
|
|
return wrap(new mlir::RewritePatternSet(unwrap(context)));
|
|
}
|
|
|
|
MlirContext mlirRewritePatternSetGetContext(MlirRewritePatternSet set) {
|
|
return wrap(unwrap(set)->getContext());
|
|
}
|
|
|
|
void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) {
|
|
delete unwrap(set);
|
|
}
|
|
|
|
/// Transfers ownership of `pattern` into `set`. The caller must not use or
/// free `pattern` afterwards: the set's unique_ptr now owns it.
void mlirRewritePatternSetAdd(MlirRewritePatternSet set,
                              MlirRewritePattern pattern) {
  // The handle exposes a const pointer; ownership is taken, so drop const.
  auto *owned = const_cast<mlir::RewritePattern *>(unwrap(pattern));
  auto *patterns = unwrap(set);
  patterns->add(std::unique_ptr<mlir::RewritePattern>(owned));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
/// PDLPatternModule API
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
|
|
/// Wraps the module `op` in an OwningOpRef and builds a heap-allocated
/// PDLPatternModule from it.
MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
  mlir::OwningOpRef<mlir::ModuleOp> owned(unwrap(op));
  return wrap(new mlir::PDLPatternModule(std::move(owned)));
}
|
|
|
|
/// Deletes a PDL pattern module created by mlirPDLPatternModuleFromModule.
void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) {
  // `op` is passed by value, so nulling `op.ptr` here could never reach the
  // caller's handle. Callers must drop their own copy after this call.
  delete unwrap(op);
}
|
|
|
|
/// Builds a heap-allocated RewritePatternSet from the PDL module `op`.
///
/// This call consumes `op`. Its contents are moved into the new set, and the
/// emptied PDLPatternModule object is deleted here. Callers must neither use
/// nor destroy `op` afterwards.
MlirRewritePatternSet
mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
  mlir::PDLPatternModule *pdlModule = unwrap(op);
  auto *set = new mlir::RewritePatternSet(std::move(*pdlModule));
  // `op` is a by-value handle: nulling `op.ptr` would only clear this local
  // copy. The emptied heap object must be released here, or it leaks.
  delete pdlModule;
  return wrap(set);
}
|
|
|
|
MlirValue mlirPDLValueAsValue(MlirPDLValue value) {
|
|
return wrap(unwrap(value)->dyn_cast<mlir::Value>());
|
|
}
|
|
|
|
MlirType mlirPDLValueAsType(MlirPDLValue value) {
|
|
return wrap(unwrap(value)->dyn_cast<mlir::Type>());
|
|
}
|
|
|
|
MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) {
|
|
return wrap(unwrap(value)->dyn_cast<mlir::Operation *>());
|
|
}
|
|
|
|
MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) {
|
|
return wrap(unwrap(value)->dyn_cast<mlir::Attribute>());
|
|
}
|
|
|
|
void mlirPDLResultListPushBackValue(MlirPDLResultList results,
|
|
MlirValue value) {
|
|
unwrap(results)->push_back(unwrap(value));
|
|
}
|
|
|
|
void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value) {
|
|
unwrap(results)->push_back(unwrap(value));
|
|
}
|
|
|
|
void mlirPDLResultListPushBackOperation(MlirPDLResultList results,
|
|
MlirOperation value) {
|
|
unwrap(results)->push_back(unwrap(value));
|
|
}
|
|
|
|
void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
|
|
MlirAttribute value) {
|
|
unwrap(results)->push_back(unwrap(value));
|
|
}
|
|
|
|
/// Produces one C handle per element of `values`. The handles point into
/// `values`, so they are only valid while that storage is alive.
inline std::vector<MlirPDLValue> wrap(ArrayRef<PDLValue> values) {
  std::vector<MlirPDLValue> handles;
  handles.reserve(values.size());
  for (const PDLValue &element : values)
    handles.push_back(wrap(&element));
  return handles;
}
|
|
|
|
/// Registers the C callback `rewriteFn` under `name` as a PDL rewrite
/// function. On each invocation, the PDL argument values are wrapped as C
/// handles and passed, with the rewriter, the result list and `userData`, to
/// `rewriteFn`.
void mlirPDLPatternModuleRegisterRewriteFunction(
    MlirPDLPatternModule pdlModule, MlirStringRef name,
    MlirPDLRewriteFunction rewriteFn, void *userData) {
  auto *patternModule = unwrap(pdlModule);
  patternModule->registerRewriteFunction(
      unwrap(name),
      [rewriteFn, userData](PatternRewriter &rewriter, PDLResultList &results,
                            ArrayRef<PDLValue> values) -> LogicalResult {
        std::vector<MlirPDLValue> handles = wrap(values);
        auto status = rewriteFn(wrap(&rewriter), wrap(&results),
                                handles.size(), handles.data(), userData);
        return unwrap(status);
      });
}
|
|
|
|
/// Registers the C callback `constraintFn` under `name` as a PDL constraint
/// function. On each invocation, the PDL argument values are wrapped as C
/// handles and passed, with the rewriter, the result list and `userData`, to
/// `constraintFn`.
void mlirPDLPatternModuleRegisterConstraintFunction(
    MlirPDLPatternModule pdlModule, MlirStringRef name,
    MlirPDLConstraintFunction constraintFn, void *userData) {
  auto *patternModule = unwrap(pdlModule);
  patternModule->registerConstraintFunction(
      unwrap(name),
      [constraintFn, userData](PatternRewriter &rewriter,
                               PDLResultList &results,
                               ArrayRef<PDLValue> values) -> LogicalResult {
        std::vector<MlirPDLValue> handles = wrap(values);
        auto status = constraintFn(wrap(&rewriter), wrap(&results),
                                   handles.size(), handles.data(), userData);
        return unwrap(status);
      });
}
|
|
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
|