//===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir-c/Rewrite.h" #include "mlir-c/Support.h" #include "mlir-c/Transforms.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Rewrite.h" #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Wrap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/PDLPatternMatch.h.inc" #include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h" using namespace mlir; //===----------------------------------------------------------------------===// /// RewriterBase API inherited from OpBuilder //===----------------------------------------------------------------------===// MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) { return wrap(unwrap(rewriter)->getContext()); } //===----------------------------------------------------------------------===// /// Insertion points methods //===----------------------------------------------------------------------===// void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) { unwrap(rewriter)->clearInsertionPoint(); } void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter, MlirOperation op) { unwrap(rewriter)->setInsertionPoint(unwrap(op)); } void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter, MlirOperation op) { unwrap(rewriter)->setInsertionPointAfter(unwrap(op)); } void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter, MlirValue value) { unwrap(rewriter)->setInsertionPointAfterValue(unwrap(value)); } void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter, MlirBlock block) { unwrap(rewriter)->setInsertionPointToStart(unwrap(block)); } void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter, MlirBlock block) { unwrap(rewriter)->setInsertionPointToEnd(unwrap(block)); } MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) { return wrap(unwrap(rewriter)->getInsertionBlock()); } MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) { return wrap(unwrap(rewriter)->getBlock()); } MlirOperation mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter) { mlir::RewriterBase *base = unwrap(rewriter); mlir::Block *block = base->getInsertionBlock(); mlir::Block::iterator it = base->getInsertionPoint(); if (it == block->end()) return {nullptr}; return wrap(std::addressof(*it)); } //===----------------------------------------------------------------------===// /// Block and operation creation/insertion/cloning //===----------------------------------------------------------------------===// MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter, MlirBlock insertBefore, intptr_t nArgTypes, MlirType const *argTypes, MlirLocation const *locations) { SmallVector args; ArrayRef unwrappedArgs = unwrapList(nArgTypes, argTypes, args); SmallVector locs; ArrayRef unwrappedLocs = unwrapList(nArgTypes, locations, locs); return wrap(unwrap(rewriter)->createBlock(unwrap(insertBefore), unwrappedArgs, unwrappedLocs)); } MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter, MlirOperation op) { return wrap(unwrap(rewriter)->insert(unwrap(op))); } // Other methods of OpBuilder MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter, MlirOperation op) { return wrap(unwrap(rewriter)->clone(*unwrap(op))); } MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter, MlirOperation op) { return wrap(unwrap(rewriter)->cloneWithoutRegions(*unwrap(op))); } void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, MlirRegion region, MlirBlock before) { unwrap(rewriter)->cloneRegionBefore(*unwrap(region), unwrap(before)); } //===----------------------------------------------------------------------===// /// RewriterBase API //===----------------------------------------------------------------------===// void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, MlirRegion region, MlirBlock before) { unwrap(rewriter)->inlineRegionBefore(*unwrap(region), unwrap(before)); } void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, MlirOperation op, intptr_t nValues, MlirValue const *values) { SmallVector vals; ArrayRef unwrappedVals = unwrapList(nValues, values, vals); unwrap(rewriter)->replaceOp(unwrap(op), unwrappedVals); } void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, MlirOperation op, MlirOperation newOp) { unwrap(rewriter)->replaceOp(unwrap(op), unwrap(newOp)); } void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) { unwrap(rewriter)->eraseOp(unwrap(op)); } void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) { unwrap(rewriter)->eraseBlock(unwrap(block)); } void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter, MlirBlock source, MlirOperation op, intptr_t nArgValues, MlirValue const *argValues) { SmallVector vals; ArrayRef unwrappedVals = unwrapList(nArgValues, argValues, vals); unwrap(rewriter)->inlineBlockBefore(unwrap(source), unwrap(op), unwrappedVals); } void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source, MlirBlock dest, intptr_t nArgValues, MlirValue const *argValues) { SmallVector args; ArrayRef unwrappedArgs = unwrapList(nArgValues, argValues, args); unwrap(rewriter)->mergeBlocks(unwrap(source), unwrap(dest), unwrappedArgs); } void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op, MlirOperation existingOp) { unwrap(rewriter)->moveOpBefore(unwrap(op), unwrap(existingOp)); } void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op, MlirOperation existingOp) { unwrap(rewriter)->moveOpAfter(unwrap(op), unwrap(existingOp)); } void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block, MlirBlock existingBlock) { unwrap(rewriter)->moveBlockBefore(unwrap(block), unwrap(existingBlock)); } void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter, MlirOperation op) { unwrap(rewriter)->startOpModification(unwrap(op)); } void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter, MlirOperation op) { unwrap(rewriter)->finalizeOpModification(unwrap(op)); } void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter, MlirOperation op) { unwrap(rewriter)->cancelOpModification(unwrap(op)); } void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, MlirValue from, MlirValue to) { unwrap(rewriter)->replaceAllUsesWith(unwrap(from), unwrap(to)); } void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter, intptr_t nValues, MlirValue const *from, MlirValue const *to) { SmallVector fromVals; ArrayRef unwrappedFromVals = unwrapList(nValues, from, fromVals); SmallVector toVals; ArrayRef unwrappedToVals = unwrapList(nValues, to, toVals); unwrap(rewriter)->replaceAllUsesWith(unwrappedFromVals, unwrappedToVals); } void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter, MlirOperation from, intptr_t nTo, MlirValue const *to) { SmallVector toVals; ArrayRef unwrappedToVals = unwrapList(nTo, to, toVals); unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrappedToVals); } void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter, MlirOperation from, MlirOperation to) { unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrap(to)); } void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter, MlirOperation op, intptr_t nNewValues, MlirValue const *newValues, MlirBlock block) { SmallVector vals; ArrayRef unwrappedVals = unwrapList(nNewValues, newValues, vals); unwrap(rewriter)->replaceOpUsesWithinBlock(unwrap(op), unwrappedVals, unwrap(block)); } void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, MlirValue from, MlirValue to, MlirOperation exceptedUser) { unwrap(rewriter)->replaceAllUsesExcept(unwrap(from), unwrap(to), unwrap(exceptedUser)); } //===----------------------------------------------------------------------===// /// IRRewriter API //===----------------------------------------------------------------------===// MlirRewriterBase mlirIRRewriterCreate(MlirContext context) { return wrap(new IRRewriter(unwrap(context))); } MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) { return wrap(new IRRewriter(unwrap(op))); } void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { delete static_cast(unwrap(rewriter)); } //===----------------------------------------------------------------------===// /// RewritePatternSet and FrozenRewritePatternSet API //===----------------------------------------------------------------------===// MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet set) { auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set))); set.ptr = nullptr; return wrap(m); } void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) { delete unwrap(set); set.ptr = nullptr; } //===----------------------------------------------------------------------===// /// GreedyRewriteDriverConfig API //===----------------------------------------------------------------------===// inline mlir::GreedyRewriteConfig *unwrap(MlirGreedyRewriteDriverConfig config) { assert(config.ptr && "unexpected null config"); return static_cast(config.ptr); } inline MlirGreedyRewriteDriverConfig wrap(mlir::GreedyRewriteConfig *config) { return {config}; } MlirGreedyRewriteDriverConfig mlirGreedyRewriteDriverConfigCreate() { return wrap(new mlir::GreedyRewriteConfig()); } void mlirGreedyRewriteDriverConfigDestroy( MlirGreedyRewriteDriverConfig config) { delete unwrap(config); } void mlirGreedyRewriteDriverConfigSetMaxIterations( MlirGreedyRewriteDriverConfig config, int64_t maxIterations) { unwrap(config)->setMaxIterations(maxIterations); } void mlirGreedyRewriteDriverConfigSetMaxNumRewrites( MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites) { unwrap(config)->setMaxNumRewrites(maxNumRewrites); } void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal( MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal) { unwrap(config)->setUseTopDownTraversal(useTopDownTraversal); } void mlirGreedyRewriteDriverConfigEnableFolding( MlirGreedyRewriteDriverConfig config, bool enable) { unwrap(config)->enableFolding(enable); } void mlirGreedyRewriteDriverConfigSetStrictness( MlirGreedyRewriteDriverConfig config, MlirGreedyRewriteStrictness strictness) { mlir::GreedyRewriteStrictness cppStrictness; switch (strictness) { case MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP: cppStrictness = mlir::GreedyRewriteStrictness::AnyOp; break; case MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS: cppStrictness = mlir::GreedyRewriteStrictness::ExistingAndNewOps; break; case MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS: cppStrictness = mlir::GreedyRewriteStrictness::ExistingOps; break; } unwrap(config)->setStrictness(cppStrictness); } void mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel( MlirGreedyRewriteDriverConfig config, MlirGreedySimplifyRegionLevel level) { mlir::GreedySimplifyRegionLevel cppLevel; switch (level) { case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED: cppLevel = mlir::GreedySimplifyRegionLevel::Disabled; break; case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL: cppLevel = mlir::GreedySimplifyRegionLevel::Normal; break; case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE: cppLevel = mlir::GreedySimplifyRegionLevel::Aggressive; break; } unwrap(config)->setRegionSimplificationLevel(cppLevel); } void mlirGreedyRewriteDriverConfigEnableConstantCSE( MlirGreedyRewriteDriverConfig config, bool enable) { unwrap(config)->enableConstantCSE(enable); } int64_t mlirGreedyRewriteDriverConfigGetMaxIterations( MlirGreedyRewriteDriverConfig config) { return unwrap(config)->getMaxIterations(); } int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites( MlirGreedyRewriteDriverConfig config) { return unwrap(config)->getMaxNumRewrites(); } bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal( MlirGreedyRewriteDriverConfig config) { return unwrap(config)->getUseTopDownTraversal(); } bool mlirGreedyRewriteDriverConfigIsFoldingEnabled( MlirGreedyRewriteDriverConfig config) { return unwrap(config)->isFoldingEnabled(); } MlirGreedyRewriteStrictness mlirGreedyRewriteDriverConfigGetStrictness( MlirGreedyRewriteDriverConfig config) { mlir::GreedyRewriteStrictness cppStrictness = unwrap(config)->getStrictness(); switch (cppStrictness) { case mlir::GreedyRewriteStrictness::AnyOp: return MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP; case mlir::GreedyRewriteStrictness::ExistingAndNewOps: return MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS; case mlir::GreedyRewriteStrictness::ExistingOps: return MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS; } llvm_unreachable("Unknown GreedyRewriteStrictness"); } MlirGreedySimplifyRegionLevel mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel( MlirGreedyRewriteDriverConfig config) { mlir::GreedySimplifyRegionLevel cppLevel = unwrap(config)->getRegionSimplificationLevel(); switch (cppLevel) { case mlir::GreedySimplifyRegionLevel::Disabled: return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED; case mlir::GreedySimplifyRegionLevel::Normal: return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL; case mlir::GreedySimplifyRegionLevel::Aggressive: return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE; } llvm_unreachable("Unknown GreedySimplifyRegionLevel"); } bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled( MlirGreedyRewriteDriverConfig config) { return unwrap(config)->isConstantCSEEnabled(); } MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig config) { return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns), *unwrap(config))); } MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig config) { return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns), *unwrap(config))); } void mlirWalkAndApplyPatterns(MlirOperation op, MlirFrozenRewritePatternSet patterns) { mlir::walkAndApplyPatterns(unwrap(op), *unwrap(patterns)); } MlirLogicalResult mlirApplyPartialConversion(MlirOperation op, MlirConversionTarget target, MlirFrozenRewritePatternSet patterns, MlirConversionConfig config) { return wrap(mlir::applyPartialConversion(unwrap(op), *unwrap(target), *unwrap(patterns), *unwrap(config))); } MlirLogicalResult mlirApplyFullConversion(MlirOperation op, MlirConversionTarget target, MlirFrozenRewritePatternSet patterns, MlirConversionConfig config) { return wrap(mlir::applyFullConversion(unwrap(op), *unwrap(target), *unwrap(patterns), *unwrap(config))); } //===----------------------------------------------------------------------===// /// ConversionConfig API //===----------------------------------------------------------------------===// MlirConversionConfig mlirConversionConfigCreate(void) { return wrap(new mlir::ConversionConfig()); } void mlirConversionConfigDestroy(MlirConversionConfig config) { delete unwrap(config); } void mlirConversionConfigSetFoldingMode(MlirConversionConfig config, MlirDialectConversionFoldingMode mode) { mlir::DialectConversionFoldingMode cppMode; switch (mode) { case MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER: cppMode = mlir::DialectConversionFoldingMode::Never; break; case MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS: cppMode = mlir::DialectConversionFoldingMode::BeforePatterns; break; case MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS: cppMode = mlir::DialectConversionFoldingMode::AfterPatterns; break; } unwrap(config)->foldingMode = cppMode; } MlirDialectConversionFoldingMode mlirConversionConfigGetFoldingMode(MlirConversionConfig config) { switch (unwrap(config)->foldingMode) { case mlir::DialectConversionFoldingMode::Never: return MLIR_DIALECT_CONVERSION_FOLDING_MODE_NEVER; case mlir::DialectConversionFoldingMode::BeforePatterns: return MLIR_DIALECT_CONVERSION_FOLDING_MODE_BEFORE_PATTERNS; case mlir::DialectConversionFoldingMode::AfterPatterns: return MLIR_DIALECT_CONVERSION_FOLDING_MODE_AFTER_PATTERNS; } } void mlirConversionConfigEnableBuildMaterializations( MlirConversionConfig config, bool enable) { unwrap(config)->buildMaterializations = enable; } bool mlirConversionConfigIsBuildMaterializationsEnabled( MlirConversionConfig config) { return unwrap(config)->buildMaterializations; } //===----------------------------------------------------------------------===// /// PatternRewriter API //===----------------------------------------------------------------------===// MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) { return wrap(static_cast(unwrap(rewriter))); } //===----------------------------------------------------------------------===// /// ConversionPatternRewriter API //===----------------------------------------------------------------------===// MlirPatternRewriter mlirConversionPatternRewriterAsPatternRewriter( MlirConversionPatternRewriter rewriter) { return wrap(static_cast(unwrap(rewriter))); } MlirLogicalResult mlirConversionPatternRewriterConvertRegionTypes( MlirConversionPatternRewriter rewriter, MlirRegion region, MlirTypeConverter typeConverter) { return wrap(unwrap(rewriter)->convertRegionTypes(unwrap(region), *unwrap(typeConverter))); } //===----------------------------------------------------------------------===// /// ConversionTarget API //===----------------------------------------------------------------------===// MlirConversionTarget mlirConversionTargetCreate(MlirContext context) { return wrap(new mlir::ConversionTarget(*unwrap(context))); } void mlirConversionTargetDestroy(MlirConversionTarget target) { delete unwrap(target); } void mlirConversionTargetAddLegalOp(MlirConversionTarget target, MlirStringRef opName) { unwrap(target)->addLegalOp( mlir::OperationName(unwrap(opName), &unwrap(target)->getContext())); } void mlirConversionTargetAddIllegalOp(MlirConversionTarget target, MlirStringRef opName) { unwrap(target)->addIllegalOp( mlir::OperationName(unwrap(opName), &unwrap(target)->getContext())); } void mlirConversionTargetAddLegalDialect(MlirConversionTarget target, MlirStringRef dialectName) { unwrap(target)->addLegalDialect(unwrap(dialectName)); } void mlirConversionTargetAddIllegalDialect(MlirConversionTarget target, MlirStringRef dialectName) { unwrap(target)->addIllegalDialect(unwrap(dialectName)); } //===----------------------------------------------------------------------===// /// TypeConverter API //===----------------------------------------------------------------------===// MlirTypeConverter mlirTypeConverterCreate() { return wrap(new mlir::TypeConverter()); } void mlirTypeConverterDestroy(MlirTypeConverter typeConverter) { delete unwrap(typeConverter); } void mlirTypeConverterAddConversion( MlirTypeConverter typeConverter, MlirTypeConverterConversionCallback convertType, void *userData) { unwrap(typeConverter) ->addConversion( [convertType, userData](Type type) -> std::optional { MlirType converted{nullptr}; MlirLogicalResult result = convertType(wrap(type), &converted, userData); if (mlirLogicalResultIsFailure(result)) return std::nullopt; // allowed to try another conversion function if (mlirTypeIsNull(converted)) return nullptr; return unwrap(converted); }); } MlirType mlirTypeConverterConvertType(MlirTypeConverter typeConverter, MlirType type) { return wrap(unwrap(typeConverter)->convertType(unwrap(type))); } //===----------------------------------------------------------------------===// /// ConversionPattern API //===----------------------------------------------------------------------===// namespace mlir { class ExternalConversionPattern : public mlir::ConversionPattern { public: ExternalConversionPattern(MlirConversionPatternCallbacks callbacks, void *userData, StringRef rootName, PatternBenefit benefit, MLIRContext *context, TypeConverter *typeConverter, ArrayRef generatedNames) : ConversionPattern(*typeConverter, rootName, benefit, context, generatedNames), callbacks(callbacks), userData(userData) { if (callbacks.construct) callbacks.construct(userData); } ~ExternalConversionPattern() { if (callbacks.destruct) callbacks.destruct(userData); } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { std::vector wrappedOperands; for (Value val : operands) wrappedOperands.push_back(wrap(val)); return unwrap(callbacks.matchAndRewrite( wrap(static_cast(this)), wrap(op), wrappedOperands.size(), wrappedOperands.data(), wrap(&rewriter), userData)); } private: MlirConversionPatternCallbacks callbacks; void *userData; }; } // namespace mlir MlirConversionPattern mlirOpConversionPatternCreate( MlirStringRef rootName, unsigned benefit, MlirContext context, MlirTypeConverter typeConverter, MlirConversionPatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames) { std::vector generatedNamesVec; generatedNamesVec.reserve(nGeneratedNames); for (size_t i = 0; i < nGeneratedNames; ++i) generatedNamesVec.push_back(unwrap(generatedNames[i])); return wrap(new mlir::ExternalConversionPattern( callbacks, userData, unwrap(rootName), PatternBenefit(benefit), unwrap(context), unwrap(typeConverter), generatedNamesVec)); } MlirTypeConverter mlirConversionPatternGetTypeConverter(MlirConversionPattern pattern) { return wrap(const_cast(unwrap(pattern)->getTypeConverter())); } MlirRewritePattern mlirConversionPatternAsRewritePattern(MlirConversionPattern pattern) { return wrap(static_cast(unwrap(pattern))); } //===----------------------------------------------------------------------===// /// RewritePattern API //===----------------------------------------------------------------------===// namespace mlir { class ExternalRewritePattern : public mlir::RewritePattern { public: ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData, StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef generatedNames) : RewritePattern(rootName, benefit, context, generatedNames), callbacks(callbacks), userData(userData) { if (callbacks.construct) callbacks.construct(userData); } ~ExternalRewritePattern() { if (callbacks.destruct) callbacks.destruct(userData); } LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { return unwrap(callbacks.matchAndRewrite( wrap(static_cast(this)), wrap(op), wrap(&rewriter), userData)); } private: MlirRewritePatternCallbacks callbacks; void *userData; }; } // namespace mlir MlirRewritePattern mlirOpRewritePatternCreate( MlirStringRef rootName, unsigned benefit, MlirContext context, MlirRewritePatternCallbacks callbacks, void *userData, size_t nGeneratedNames, MlirStringRef *generatedNames) { std::vector generatedNamesVec; generatedNamesVec.reserve(nGeneratedNames); for (size_t i = 0; i < nGeneratedNames; ++i) { generatedNamesVec.push_back(unwrap(generatedNames[i])); } return wrap(new mlir::ExternalRewritePattern( callbacks, userData, unwrap(rootName), PatternBenefit(benefit), unwrap(context), generatedNamesVec)); } //===----------------------------------------------------------------------===// /// RewritePatternSet API //===----------------------------------------------------------------------===// MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) { return wrap(new mlir::RewritePatternSet(unwrap(context))); } MlirContext mlirRewritePatternSetGetContext(MlirRewritePatternSet set) { return wrap(unwrap(set)->getContext()); } void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) { delete unwrap(set); } void mlirRewritePatternSetAdd(MlirRewritePatternSet set, MlirRewritePattern pattern) { std::unique_ptr patternPtr( const_cast(unwrap(pattern))); pattern.ptr = nullptr; unwrap(set)->add(std::move(patternPtr)); } //===----------------------------------------------------------------------===// /// PDLPatternModule API //===----------------------------------------------------------------------===// #if MLIR_ENABLE_PDL_IN_PATTERNMATCH MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) { return wrap(new mlir::PDLPatternModule( mlir::OwningOpRef(unwrap(op)))); } void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) { delete unwrap(op); op.ptr = nullptr; } MlirRewritePatternSet mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { auto *m = new mlir::RewritePatternSet(std::move(*unwrap(op))); op.ptr = nullptr; return wrap(m); } MlirValue mlirPDLValueAsValue(MlirPDLValue value) { return wrap(unwrap(value)->dyn_cast()); } MlirType mlirPDLValueAsType(MlirPDLValue value) { return wrap(unwrap(value)->dyn_cast()); } MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) { return wrap(unwrap(value)->dyn_cast()); } MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) { return wrap(unwrap(value)->dyn_cast()); } void mlirPDLResultListPushBackValue(MlirPDLResultList results, MlirValue value) { unwrap(results)->push_back(unwrap(value)); } void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value) { unwrap(results)->push_back(unwrap(value)); } void mlirPDLResultListPushBackOperation(MlirPDLResultList results, MlirOperation value) { unwrap(results)->push_back(unwrap(value)); } void mlirPDLResultListPushBackAttribute(MlirPDLResultList results, MlirAttribute value) { unwrap(results)->push_back(unwrap(value)); } inline std::vector wrap(ArrayRef values) { std::vector mlirValues; mlirValues.reserve(values.size()); for (auto &value : values) { mlirValues.push_back(wrap(&value)); } return mlirValues; } void mlirPDLPatternModuleRegisterRewriteFunction( MlirPDLPatternModule pdlModule, MlirStringRef name, MlirPDLRewriteFunction rewriteFn, void *userData) { unwrap(pdlModule)->registerRewriteFunction( unwrap(name), [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results, ArrayRef values) -> LogicalResult { std::vector mlirValues = wrap(values); return unwrap(rewriteFn(wrap(&rewriter), wrap(&results), mlirValues.size(), mlirValues.data(), userData)); }); } void mlirPDLPatternModuleRegisterConstraintFunction( MlirPDLPatternModule pdlModule, MlirStringRef name, MlirPDLConstraintFunction constraintFn, void *userData) { unwrap(pdlModule)->registerConstraintFunction( unwrap(name), [userData, constraintFn](PatternRewriter &rewriter, PDLResultList &results, ArrayRef values) -> LogicalResult { std::vector mlirValues = wrap(values); return unwrap(constraintFn(wrap(&rewriter), wrap(&results), mlirValues.size(), mlirValues.data(), userData)); }); } #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH