//===- ExtensibleDialect - C API for MLIR Extensible Dialect --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir-c/ExtensibleDialect.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" #include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/OperationSupport.h" using namespace mlir; DEFINE_C_API_PTR_METHODS(MlirDynamicOpTrait, DynamicOpTrait) DEFINE_C_API_PTR_METHODS(MlirDynamicTypeDefinition, DynamicTypeDefinition) DEFINE_C_API_PTR_METHODS(MlirDynamicAttrDefinition, DynamicAttrDefinition) bool mlirDynamicOpTraitAttach(MlirDynamicOpTrait dynamicOpTrait, MlirStringRef opName, MlirContext context) { std::optional opNameFound = RegisteredOperationName::lookup(unwrap(opName), unwrap(context)); assert(opNameFound && "operation name must be registered in the context"); // The original getImpl() is protected, so we create a small helper struct // here. struct RegisteredOperationNameWithImpl : RegisteredOperationName { Impl *getImpl() { return RegisteredOperationName::getImpl(); } }; OperationName::Impl *impl = static_cast(*opNameFound).getImpl(); std::unique_ptr trait(unwrap(dynamicOpTrait)); // TODO: we should enable llvm-style RTTI for `OperationName::Impl` and check // whether the `impl` is a `DynamicOpDefinition` here. return static_cast(impl)->addTrait(std::move(trait)); } MlirDynamicOpTrait mlirDynamicOpTraitIsTerminatorCreate() { return wrap(new DynamicOpTraits::IsTerminator()); } MlirTypeID mlirDynamicOpTraitIsTerminatorGetTypeID() { return wrap(DynamicOpTraits::IsTerminator::getStaticTypeID()); } MlirDynamicOpTrait mlirDynamicOpTraitNoTerminatorCreate() { return wrap(new DynamicOpTraits::NoTerminator()); } MlirTypeID mlirDynamicOpTraitNoTerminatorGetTypeID() { return wrap(DynamicOpTraits::NoTerminator::getStaticTypeID()); } void mlirDynamicOpTraitDestroy(MlirDynamicOpTrait dynamicOpTrait) { delete unwrap(dynamicOpTrait); } namespace mlir { class ExternalDynamicOpTrait : public DynamicOpTrait { public: ExternalDynamicOpTrait(TypeID typeID, MlirDynamicOpTraitCallbacks callbacks, void *userData) : typeID(typeID), callbacks(callbacks), userData(userData) { if (callbacks.construct) callbacks.construct(userData); } ~ExternalDynamicOpTrait() { if (callbacks.destruct) callbacks.destruct(userData); } LogicalResult verifyTrait(Operation *op) const override { return unwrap(callbacks.verifyTrait(wrap(op), userData)); }; LogicalResult verifyRegionTrait(Operation *op) const override { return unwrap(callbacks.verifyRegionTrait(wrap(op), userData)); }; TypeID getTypeID() const override { return typeID; }; private: TypeID typeID; MlirDynamicOpTraitCallbacks callbacks; void *userData; }; } // namespace mlir MlirDynamicOpTrait mlirDynamicOpTraitCreate( MlirTypeID typeID, MlirDynamicOpTraitCallbacks callbacks, void *userData) { return wrap( new mlir::ExternalDynamicOpTrait(unwrap(typeID), callbacks, userData)); } bool mlirDialectIsAExtensibleDialect(MlirDialect dialect) { return llvm::isa(unwrap(dialect)); } MlirDynamicTypeDefinition mlirExtensibleDialectLookupTypeDefinition(MlirDialect dialect, MlirStringRef typeName) { return wrap(llvm::cast(unwrap(dialect)) ->lookupTypeDefinition(unwrap(typeName))); } bool mlirTypeIsADynamicType(MlirType type) { return llvm::isa(unwrap(type)); } MlirTypeID mlirDynamicTypeGetTypeID() { return wrap(mlir::DynamicType::getTypeID()); } MlirType mlirDynamicTypeGet(MlirDynamicTypeDefinition typeDef, MlirAttribute *attrs, intptr_t numAttrs) { llvm::SmallVector attributes; attributes.reserve(numAttrs); for (intptr_t i = 0; i < numAttrs; ++i) attributes.push_back(unwrap(attrs[i])); return wrap(mlir::DynamicType::get(unwrap(typeDef), attributes)); } intptr_t mlirDynamicTypeGetNumParams(MlirType type) { return llvm::cast(unwrap(type)).getParams().size(); } MlirAttribute mlirDynamicTypeGetParam(MlirType type, intptr_t index) { return wrap(llvm::cast(unwrap(type)).getParams()[index]); } MlirDynamicTypeDefinition mlirDynamicTypeGetTypeDef(MlirType type) { return wrap(llvm::cast(unwrap(type)).getTypeDef()); } MlirTypeID mlirDynamicTypeDefinitionGetTypeID(MlirDynamicTypeDefinition typeDef) { return wrap(unwrap(typeDef)->getTypeID()); } MlirStringRef mlirDynamicTypeDefinitionGetName(MlirDynamicTypeDefinition typeDef) { return wrap(unwrap(typeDef)->getName()); } MlirDialect mlirDynamicTypeDefinitionGetDialect(MlirDynamicTypeDefinition typeDef) { return wrap(unwrap(typeDef)->getDialect()); } MlirDynamicAttrDefinition mlirExtensibleDialectLookupAttrDefinition(MlirDialect dialect, MlirStringRef attrName) { return wrap(llvm::cast(unwrap(dialect)) ->lookupAttrDefinition(unwrap(attrName))); } bool mlirAttributeIsADynamicAttr(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } MlirTypeID mlirDynamicAttrGetTypeID(void) { return wrap(mlir::DynamicAttr::getTypeID()); } MlirAttribute mlirDynamicAttrGet(MlirDynamicAttrDefinition attrDef, MlirAttribute *attrs, intptr_t numAttrs) { llvm::SmallVector attributes; attributes.reserve(numAttrs); for (intptr_t i = 0; i < numAttrs; ++i) attributes.push_back(unwrap(attrs[i])); return wrap(mlir::DynamicAttr::get(unwrap(attrDef), attributes)); } intptr_t mlirDynamicAttrGetNumParams(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getParams().size(); } MlirAttribute mlirDynamicAttrGetParam(MlirAttribute attr, intptr_t index) { return wrap(llvm::cast(unwrap(attr)).getParams()[index]); } MlirDynamicAttrDefinition mlirDynamicAttrGetAttrDef(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getAttrDef()); } MlirTypeID mlirDynamicAttrDefinitionGetTypeID(MlirDynamicAttrDefinition attrDef) { return wrap(unwrap(attrDef)->getTypeID()); } MlirStringRef mlirDynamicAttrDefinitionGetName(MlirDynamicAttrDefinition attrDef) { return wrap(unwrap(attrDef)->getName()); } MlirDialect mlirDynamicAttrDefinitionGetDialect(MlirDynamicAttrDefinition attrDef) { return wrap(unwrap(attrDef)->getDialect()); }