[MLIR][Python] Forward the name of MLIR types to Python side (#174700)

In this PR, I added a C API for each (upstream) MLIR type to retrieve
its type name (for example, `IntegerType` -> `mlirIntegerTypeGetName()`
-> `"builtin.integer"`), and exposed a corresponding `type_name` class
attribute in the Python bindings (e.g., `IntegerType.type_name` ->
`"builtin.integer"`). This can be used in various places to avoid
hard-coded strings, such as eliminating the manual string in
`irdl.base("!builtin.integer")`.

Note that parts of this PR (mainly mechanical changes) were produced via
GitHub Copilot and GPT-5.2. I have manually reviewed the changes and
verified them with tests to ensure correctness.
This commit is contained in:
Twice
2026-01-07 16:27:31 +08:00
committed by GitHub
parent eb13822b51
commit b919d62eae
31 changed files with 451 additions and 0 deletions

View File

@@ -33,6 +33,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeGet(MlirContext ctx,
unsigned bitwidth);
MLIR_CAPI_EXPORTED MlirStringRef mlirIntegerTypeGetName(void);
/// Creates a signed integer type of the given bitwidth in the context. The type
/// is owned by the context.
MLIR_CAPI_EXPORTED MlirType mlirIntegerTypeSignedGet(MlirContext ctx,
@@ -69,6 +71,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirIndexTypeGetName(void);
//===----------------------------------------------------------------------===//
// Floating-point types.
//===----------------------------------------------------------------------===//
@@ -89,6 +93,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat4E2M1FN(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirFloat4E2M1FNTypeGetName(void);
/// Returns the typeID of an Float6E2M3FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E2M3FNTypeGetTypeID(void);
@@ -99,6 +105,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E2M3FN(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirFloat6E2M3FNTypeGetName(void);
/// Returns the typeID of an Float6E3M2FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E3M2FNTypeGetTypeID(void);
@@ -109,6 +117,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E3M2FN(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirFloat6E3M2FNTypeGetName(void);
/// Returns the typeID of an Float8E5M2 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void);
@@ -119,6 +129,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirFloat8E5M2TypeGetName(void);
/// Returns the typeID of an Float8E4M3 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3TypeGetTypeID(void);
@@ -129,6 +141,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3TypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirFloat8E4M3TypeGetName(void);
/// Returns the typeID of an Float8E4M3FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void);
@@ -139,6 +153,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirFloat8E4M3FNTypeGetName(void);
/// Returns the typeID of an Float8E5M2FNUZ type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID(void);
@@ -149,6 +165,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirFloat8E5M2FNUZTypeGetName(void);
/// Returns the typeID of an Float8E4M3FNUZ type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID(void);
@@ -159,6 +177,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirFloat8E4M3FNUZTypeGetName(void);
/// Returns the typeID of an Float8E4M3B11FNUZ type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID(void);
@@ -169,6 +189,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirFloat8E4M3B11FNUZTypeGetName(void);
/// Returns the typeID of an Float8E3M4 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E3M4TypeGetTypeID(void);
@@ -179,6 +201,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirFloat8E3M4TypeGetName(void);
/// Returns the typeID of an Float8E8M0FNU type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID(void);
@@ -189,6 +213,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E8M0FNU(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirFloat8E8M0FNUTypeGetName(void);
/// Returns the typeID of an BFloat16 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void);
@@ -199,6 +225,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirBF16TypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirBF16TypeGetName(void);
/// Returns the typeID of an Float16 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat16TypeGetTypeID(void);
@@ -209,6 +237,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirF16TypeGetName(void);
/// Returns the typeID of an Float32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat32TypeGetTypeID(void);
@@ -219,6 +249,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirF32TypeGetName(void);
/// Returns the typeID of an Float64 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat64TypeGetTypeID(void);
@@ -229,6 +261,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirF64TypeGetName(void);
/// Returns the typeID of a TF32 type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloatTF32TypeGetTypeID(void);
@@ -239,6 +273,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsATF32(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirTF32TypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirTF32TypeGetName(void);
//===----------------------------------------------------------------------===//
// None type.
//===----------------------------------------------------------------------===//
@@ -253,6 +289,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsANone(MlirType type);
/// context.
MLIR_CAPI_EXPORTED MlirType mlirNoneTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirNoneTypeGetName(void);
//===----------------------------------------------------------------------===//
// Complex type.
//===----------------------------------------------------------------------===//
@@ -267,6 +305,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAComplex(MlirType type);
/// the element type. The type is owned by the context.
MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGet(MlirType elementType);
MLIR_CAPI_EXPORTED MlirStringRef mlirComplexTypeGetName(void);
/// Returns the element type of the given complex type.
MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGetElementType(MlirType type);
@@ -341,6 +381,8 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGet(intptr_t rank,
const int64_t *shape,
MlirType elementType);
MLIR_CAPI_EXPORTED MlirStringRef mlirVectorTypeGetName(void);
/// Same as "mlirVectorTypeGet" but returns a nullptr wrapping MlirType on
/// illegal arguments, emitting appropriate diagnostics.
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
@@ -402,6 +444,8 @@ MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGet(intptr_t rank,
MlirType elementType,
MlirAttribute encoding);
MLIR_CAPI_EXPORTED MlirStringRef mlirRankedTensorTypeGetName(void);
/// Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on
/// illegal arguments, emitting appropriate diagnostics.
MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked(
@@ -416,6 +460,8 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type);
/// context as the element type. The type is owned by the context.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedTensorTypeGet(MlirType elementType);
MLIR_CAPI_EXPORTED MlirStringRef mlirUnrankedTensorTypeGetName(void);
/// Same as "mlirUnrankedTensorTypeGet" but returns a nullptr wrapping MlirType
/// on illegal arguments, emitting appropriate diagnostics.
MLIR_CAPI_EXPORTED MlirType
@@ -446,6 +492,8 @@ MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(MlirType elementType,
MlirAttribute layout,
MlirAttribute memorySpace);
MLIR_CAPI_EXPORTED MlirStringRef mlirMemRefTypeGetName(void);
/// Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType o
/// illegal arguments, emitting appropriate diagnostics.
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked(
@@ -471,6 +519,8 @@ MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGetChecked(
MLIR_CAPI_EXPORTED MlirType
mlirUnrankedMemRefTypeGet(MlirType elementType, MlirAttribute memorySpace);
MLIR_CAPI_EXPORTED MlirStringRef mlirUnrankedMemRefTypeGetName(void);
/// Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping
/// MlirType on illegal arguments, emitting appropriate diagnostics.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked(
@@ -511,6 +561,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGet(MlirContext ctx,
intptr_t numElements,
MlirType const *elements);
MLIR_CAPI_EXPORTED MlirStringRef mlirTupleTypeGetName(void);
/// Returns the number of types contained in a tuple.
MLIR_CAPI_EXPORTED intptr_t mlirTupleTypeGetNumTypes(MlirType type);
@@ -534,6 +586,8 @@ MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGet(MlirContext ctx,
intptr_t numResults,
MlirType const *results);
MLIR_CAPI_EXPORTED MlirStringRef mlirFunctionTypeGetName(void);
/// Returns the number of input types.
MLIR_CAPI_EXPORTED intptr_t mlirFunctionTypeGetNumInputs(MlirType type);
@@ -565,6 +619,8 @@ MLIR_CAPI_EXPORTED MlirType mlirOpaqueTypeGet(MlirContext ctx,
MlirStringRef dialectNamespace,
MlirStringRef typeData);
MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueTypeGetName(void);
/// Returns the namespace of the dialect with which the given opaque type
/// is associated. The namespace string is owned by the context.
MLIR_CAPI_EXPORTED MlirStringRef

View File

@@ -29,6 +29,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirAMDGPUTDMBaseTypeGetTypeID();
MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMBaseTypeGet(MlirContext ctx,
MlirType elementType);
MLIR_CAPI_EXPORTED MlirStringRef mlirAMDGPUTDMBaseTypeGetName(void);
//===---------------------------------------------------------------------===//
// TDMDescriptorType
//===---------------------------------------------------------------------===//
@@ -39,6 +41,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirAMDGPUTDMDescriptorTypeGetTypeID();
MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMDescriptorTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirAMDGPUTDMDescriptorTypeGetName(void);
//===---------------------------------------------------------------------===//
// TDMGatherBaseType
//===---------------------------------------------------------------------===//
@@ -51,6 +55,8 @@ MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMGatherBaseTypeGet(MlirContext ctx,
MlirType elementType,
MlirType indexType);
MLIR_CAPI_EXPORTED MlirStringRef mlirAMDGPUTDMGatherBaseTypeGetName(void);
#ifdef __cplusplus
}
#endif

View File

@@ -41,6 +41,8 @@ MLIR_CAPI_EXPORTED MlirType mlirEmitCArrayTypeGet(intptr_t nDims,
int64_t *shape,
MlirType elementType);
MLIR_CAPI_EXPORTED MlirStringRef mlirEmitCArrayTypeGetName(void);
//===---------------------------------------------------------------------===//
// LValueType
//===---------------------------------------------------------------------===//
@@ -51,6 +53,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCLValueTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirEmitCLValueTypeGet(MlirType valueType);
MLIR_CAPI_EXPORTED MlirStringRef mlirEmitCLValueTypeGetName(void);
//===---------------------------------------------------------------------===//
// OpaqueType
//===---------------------------------------------------------------------===//
@@ -62,6 +66,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCOpaqueTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirEmitCOpaqueTypeGet(MlirContext ctx,
MlirStringRef value);
MLIR_CAPI_EXPORTED MlirStringRef mlirEmitCOpaqueTypeGetName(void);
//===---------------------------------------------------------------------===//
// PointerType
//===---------------------------------------------------------------------===//
@@ -72,6 +78,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCPointerTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirEmitCPointerTypeGet(MlirType pointee);
MLIR_CAPI_EXPORTED MlirStringRef mlirEmitCPointerTypeGetName(void);
//===---------------------------------------------------------------------===//
// PtrDiffTType
//===---------------------------------------------------------------------===//
@@ -82,6 +90,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCPtrDiffTTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirEmitCPtrDiffTTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirEmitCPtrDiffTTypeGetName(void);
//===---------------------------------------------------------------------===//
// SignedSizeTType
//===---------------------------------------------------------------------===//
@@ -92,6 +102,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCSignedSizeTTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirEmitCSignedSizeTTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirEmitCSignedSizeTTypeGetName(void);
//===---------------------------------------------------------------------===//
// SizeTType
//===---------------------------------------------------------------------===//
@@ -102,6 +114,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCSizeTTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirEmitCSizeTTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirEmitCSizeTTypeGetName(void);
//===----------------------------------------------------------------------===//
// CmpPredicate attribute.
//===----------------------------------------------------------------------===//

View File

@@ -27,6 +27,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAGPUAsyncTokenType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirGPUAsyncTokenTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirGPUAsyncTokenTypeGetName(void);
//===---------------------------------------------------------------------===//
// ObjectAttr
//===---------------------------------------------------------------------===//

View File

@@ -23,6 +23,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm);
MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirContext ctx,
unsigned addressSpace);
MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMPointerTypeGetName(void);
MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMPointerTypeGetTypeID(void);
/// Returns `true` if the type is an LLVM dialect pointer type.
@@ -35,10 +37,14 @@ mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType);
/// Creates an llmv.void type.
MLIR_CAPI_EXPORTED MlirType mlirLLVMVoidTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMVoidTypeGetName(void);
/// Creates an llvm.array type.
MLIR_CAPI_EXPORTED MlirType mlirLLVMArrayTypeGet(MlirType elementType,
unsigned numElements);
MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMArrayTypeGetName(void);
/// Returns the element type of the llvm.array type.
MLIR_CAPI_EXPORTED MlirType mlirLLVMArrayTypeGetElementType(MlirType type);
@@ -47,6 +53,8 @@ MLIR_CAPI_EXPORTED MlirType
mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
MlirType const *argumentTypes, bool isVarArg);
MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMFunctionTypeGetName(void);
/// Returns the number of input types.
MLIR_CAPI_EXPORTED intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type);
@@ -62,6 +70,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type);
MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMStructTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMStructTypeGetName(void);
/// Returns `true` if the type is a literal (unnamed) LLVM struct type.
MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsLiteral(MlirType type);

View File

@@ -29,6 +29,8 @@ MLIR_CAPI_EXPORTED MlirType mlirNVGPUTensorMapDescriptorTypeGet(
MlirContext ctx, MlirType tensorMemrefType, int swizzle, int l2promo,
int oobFill, int interleave);
MLIR_CAPI_EXPORTED MlirStringRef mlirNVGPUTensorMapDescriptorTypeGetName(void);
#ifdef __cplusplus
}
#endif

View File

@@ -34,6 +34,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirPDLAttributeTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirPDLAttributeTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirPDLAttributeTypeGetName(void);
//===---------------------------------------------------------------------===//
// OperationType
//===---------------------------------------------------------------------===//
@@ -44,6 +46,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirPDLOperationTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirPDLOperationTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirPDLOperationTypeGetName(void);
//===---------------------------------------------------------------------===//
// RangeType
//===---------------------------------------------------------------------===//
@@ -54,6 +58,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirPDLRangeTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGet(MlirType elementType);
MLIR_CAPI_EXPORTED MlirStringRef mlirPDLRangeTypeGetName(void);
MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type);
//===---------------------------------------------------------------------===//
@@ -66,6 +72,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirPDLTypeTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirPDLTypeTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirPDLTypeTypeGetName(void);
//===---------------------------------------------------------------------===//
// ValueType
//===---------------------------------------------------------------------===//
@@ -76,6 +84,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirPDLValueTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirPDLValueTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirPDLValueTypeGetName(void);
#ifdef __cplusplus
}
#endif

View File

@@ -114,6 +114,8 @@ MLIR_CAPI_EXPORTED MlirType mlirAnyQuantizedTypeGet(unsigned flags,
int64_t storageTypeMin,
int64_t storageTypeMax);
MLIR_CAPI_EXPORTED MlirStringRef mlirAnyQuantizedTypeGetName(void);
//===---------------------------------------------------------------------===//
// UniformQuantizedType
//===---------------------------------------------------------------------===//
@@ -130,6 +132,8 @@ MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedTypeGet(
unsigned flags, MlirType storageType, MlirType expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
MLIR_CAPI_EXPORTED MlirStringRef mlirUniformQuantizedTypeGetName(void);
/// Returns the scale of the given uniform quantized type.
MLIR_CAPI_EXPORTED double mlirUniformQuantizedTypeGetScale(MlirType type);
@@ -157,6 +161,8 @@ MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedPerAxisTypeGet(
intptr_t nDims, double *scales, int64_t *zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax);
MLIR_CAPI_EXPORTED MlirStringRef mlirUniformQuantizedPerAxisTypeGetName(void);
/// Returns the number of axes in the given quantized per-axis type.
MLIR_CAPI_EXPORTED intptr_t
mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type);
@@ -200,6 +206,9 @@ MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedSubChannelTypeGet(
intptr_t blockSizeInfoLength, int32_t *quantizedDimensions,
int64_t *blockSizes, int64_t storageTypeMin, int64_t storageTypeMax);
MLIR_CAPI_EXPORTED MlirStringRef
mlirUniformQuantizedSubChannelTypeGetName(void);
/// Returns the number of block sizes provided in type.
MLIR_CAPI_EXPORTED intptr_t
mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type);
@@ -236,6 +245,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirCalibratedQuantizedTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType
mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min, double max);
MLIR_CAPI_EXPORTED MlirStringRef mlirCalibratedQuantizedTypeGetName(void);
/// Returns the min value of the given calibrated quantized type.
MLIR_CAPI_EXPORTED double mlirCalibratedQuantizedTypeGetMin(MlirType type);

View File

@@ -46,18 +46,24 @@ MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABitVector(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBitVector(MlirContext ctx,
int32_t width);
MLIR_CAPI_EXPORTED MlirStringRef mlirSMTBitVectorTypeGetName(void);
/// Checks if the given type is a smt::BoolType.
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABool(MlirType type);
/// Creates a smt::BoolType.
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBool(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirSMTBoolTypeGetName(void);
/// Checks if the given type is a smt::IntType.
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAInt(MlirType type);
/// Creates a smt::IntType.
MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetInt(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirSMTIntTypeGetName(void);
/// Checks if the given type is a smt::FuncType.
MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASMTFunc(MlirType type);

View File

@@ -29,6 +29,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyOpTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformAnyOpTypeGetName(void);
//===---------------------------------------------------------------------===//
// AnyParamType
//===---------------------------------------------------------------------===//
@@ -39,6 +41,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyParamTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformAnyParamTypeGetName(void);
//===---------------------------------------------------------------------===//
// AnyValueType
//===---------------------------------------------------------------------===//
@@ -49,6 +53,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyValueTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyValueTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformAnyValueTypeGetName(void);
//===---------------------------------------------------------------------===//
// OperationType
//===---------------------------------------------------------------------===//
@@ -60,6 +66,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOperationTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType
mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformOperationTypeGetName(void);
MLIR_CAPI_EXPORTED MlirStringRef
mlirTransformOperationTypeGetOperationName(MlirType type);
@@ -74,6 +82,8 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirTransformParamTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx,
MlirType type);
MLIR_CAPI_EXPORTED MlirStringRef mlirTransformParamTypeGetName(void);
MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type);
#ifdef __cplusplus

View File

@@ -24,6 +24,7 @@
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Support.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
@@ -933,6 +934,7 @@ public:
using GetTypeIDFunctionTy = MlirTypeID (*)();
using Base = PyConcreteType;
static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
static inline const MlirStringRef name{};
PyConcreteType() = default;
PyConcreteType(PyMlirContextRef contextRef, MlirType t)
@@ -988,6 +990,12 @@ public:
/*replace*/ true);
}
if (DerivedTy::name.length != 0) {
cls.def_prop_ro_static("type_name", [](nanobind::object & /*self*/) {
return nanobind::str(DerivedTy::name.data, DerivedTy::name.length);
});
}
DerivedTy::bindDerived(cls);
}

View File

@@ -10,6 +10,7 @@
#define MLIR_BINDINGS_PYTHON_IRTYPES_H
#include "mlir-c/BuiltinTypes.h"
#include "mlir/Bindings/Python/IRCore.h"
namespace mlir {
namespace python {
@@ -24,6 +25,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirIntegerTypeGetTypeID;
static constexpr const char *pyClassName = "IntegerType";
static inline const MlirStringRef name = mlirIntegerTypeGetName();
using PyConcreteType::PyConcreteType;
enum Signedness { Signless, Signed, Unsigned };
@@ -39,6 +41,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirIndexTypeGetTypeID;
static constexpr const char *pyClassName = "IndexType";
static inline const MlirStringRef name = mlirIndexTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -62,6 +65,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat4E2M1FNTypeGetTypeID;
static constexpr const char *pyClassName = "Float4E2M1FNType";
static inline const MlirStringRef name = mlirFloat4E2M1FNTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -75,6 +79,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat6E2M3FNTypeGetTypeID;
static constexpr const char *pyClassName = "Float6E2M3FNType";
static inline const MlirStringRef name = mlirFloat6E2M3FNTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -88,6 +93,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat6E3M2FNTypeGetTypeID;
static constexpr const char *pyClassName = "Float6E3M2FNType";
static inline const MlirStringRef name = mlirFloat6E3M2FNTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -101,6 +107,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat8E4M3FNTypeGetTypeID;
static constexpr const char *pyClassName = "Float8E4M3FNType";
static inline const MlirStringRef name = mlirFloat8E4M3FNTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -114,6 +121,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat8E5M2TypeGetTypeID;
static constexpr const char *pyClassName = "Float8E5M2Type";
static inline const MlirStringRef name = mlirFloat8E5M2TypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -127,6 +135,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat8E4M3TypeGetTypeID;
static constexpr const char *pyClassName = "Float8E4M3Type";
static inline const MlirStringRef name = mlirFloat8E4M3TypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -140,6 +149,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat8E4M3FNUZTypeGetTypeID;
static constexpr const char *pyClassName = "Float8E4M3FNUZType";
static inline const MlirStringRef name = mlirFloat8E4M3FNUZTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -153,6 +163,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat8E4M3B11FNUZTypeGetTypeID;
static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
static inline const MlirStringRef name = mlirFloat8E4M3B11FNUZTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -166,6 +177,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat8E5M2FNUZTypeGetTypeID;
static constexpr const char *pyClassName = "Float8E5M2FNUZType";
static inline const MlirStringRef name = mlirFloat8E5M2FNUZTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -179,6 +191,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat8E3M4TypeGetTypeID;
static constexpr const char *pyClassName = "Float8E3M4Type";
static inline const MlirStringRef name = mlirFloat8E3M4TypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -192,6 +205,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat8E8M0FNUTypeGetTypeID;
static constexpr const char *pyClassName = "Float8E8M0FNUType";
static inline const MlirStringRef name = mlirFloat8E8M0FNUTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -205,6 +219,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirBFloat16TypeGetTypeID;
static constexpr const char *pyClassName = "BF16Type";
static inline const MlirStringRef name = mlirBF16TypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -218,6 +233,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat16TypeGetTypeID;
static constexpr const char *pyClassName = "F16Type";
static inline const MlirStringRef name = mlirF16TypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -231,6 +247,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloatTF32TypeGetTypeID;
static constexpr const char *pyClassName = "FloatTF32Type";
static inline const MlirStringRef name = mlirTF32TypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -244,6 +261,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat32TypeGetTypeID;
static constexpr const char *pyClassName = "F32Type";
static inline const MlirStringRef name = mlirF32TypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -257,6 +275,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFloat64TypeGetTypeID;
static constexpr const char *pyClassName = "F64Type";
static inline const MlirStringRef name = mlirF64TypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -269,6 +288,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirNoneTypeGetTypeID;
static constexpr const char *pyClassName = "NoneType";
static inline const MlirStringRef name = mlirNoneTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -282,6 +302,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirComplexTypeGetTypeID;
static constexpr const char *pyClassName = "ComplexType";
static inline const MlirStringRef name = mlirComplexTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -309,6 +330,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirVectorTypeGetTypeID;
static constexpr const char *pyClassName = "VectorType";
static inline const MlirStringRef name = mlirVectorTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -334,6 +356,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirRankedTensorTypeGetTypeID;
static constexpr const char *pyClassName = "RankedTensorType";
static inline const MlirStringRef name = mlirRankedTensorTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -347,6 +370,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirUnrankedTensorTypeGetTypeID;
static constexpr const char *pyClassName = "UnrankedTensorType";
static inline const MlirStringRef name = mlirUnrankedTensorTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -360,6 +384,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirMemRefTypeGetTypeID;
static constexpr const char *pyClassName = "MemRefType";
static inline const MlirStringRef name = mlirMemRefTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -373,6 +398,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirUnrankedMemRefTypeGetTypeID;
static constexpr const char *pyClassName = "UnrankedMemRefType";
static inline const MlirStringRef name = mlirUnrankedMemRefTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -386,6 +412,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirTupleTypeGetTypeID;
static constexpr const char *pyClassName = "TupleType";
static inline const MlirStringRef name = mlirTupleTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -399,6 +426,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirFunctionTypeGetTypeID;
static constexpr const char *pyClassName = "FunctionType";
static inline const MlirStringRef name = mlirFunctionTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);
@@ -412,6 +440,7 @@ public:
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirOpaqueTypeGetTypeID;
static constexpr const char *pyClassName = "OpaqueType";
static inline const MlirStringRef name = mlirOpaqueTypeGetName();
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c);

View File

@@ -26,6 +26,7 @@ struct TDMBaseType : PyConcreteType<TDMBaseType> {
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirAMDGPUTDMBaseTypeGetTypeID;
static constexpr const char *pyClassName = "TDMBaseType";
static inline const MlirStringRef name = mlirAMDGPUTDMBaseTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -47,6 +48,7 @@ struct TDMDescriptorType : PyConcreteType<TDMDescriptorType> {
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirAMDGPUTDMDescriptorTypeGetTypeID;
static constexpr const char *pyClassName = "TDMDescriptorType";
static inline const MlirStringRef name = mlirAMDGPUTDMDescriptorTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -68,6 +70,7 @@ struct TDMGatherBaseType : PyConcreteType<TDMGatherBaseType> {
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirAMDGPUTDMGatherBaseTypeGetTypeID;
static constexpr const char *pyClassName = "TDMGatherBaseType";
static inline const MlirStringRef name = mlirAMDGPUTDMGatherBaseTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {

View File

@@ -29,6 +29,7 @@ namespace gpu {
struct AsyncTokenType : PyConcreteType<AsyncTokenType> {
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAGPUAsyncTokenType;
static constexpr const char *pyClassName = "AsyncTokenType";
static inline const MlirStringRef name = mlirGPUAsyncTokenTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {

View File

@@ -37,6 +37,7 @@ struct StructType : PyConcreteType<StructType> {
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirLLVMStructTypeGetTypeID;
static constexpr const char *pyClassName = "StructType";
static inline const MlirStringRef name = mlirLLVMStructTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -169,6 +170,7 @@ struct PointerType : PyConcreteType<PointerType> {
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirLLVMPointerTypeGetTypeID;
static constexpr const char *pyClassName = "PointerType";
static inline const MlirStringRef name = mlirLLVMPointerTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {

View File

@@ -24,6 +24,8 @@ struct TensorMapDescriptorType : PyConcreteType<TensorMapDescriptorType> {
static constexpr IsAFunctionTy isaFunction =
mlirTypeIsANVGPUTensorMapDescriptorType;
static constexpr const char *pyClassName = "TensorMapDescriptorType";
static inline const MlirStringRef name =
mlirNVGPUTensorMapDescriptorTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {

View File

@@ -42,6 +42,7 @@ struct AttributeType : PyConcreteType<AttributeType> {
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirPDLAttributeTypeGetTypeID;
static constexpr const char *pyClassName = "AttributeType";
static inline const MlirStringRef name = mlirPDLAttributeTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -65,6 +66,7 @@ struct OperationType : PyConcreteType<OperationType> {
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirPDLOperationTypeGetTypeID;
static constexpr const char *pyClassName = "OperationType";
static inline const MlirStringRef name = mlirPDLOperationTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -88,6 +90,7 @@ struct RangeType : PyConcreteType<RangeType> {
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirPDLRangeTypeGetTypeID;
static constexpr const char *pyClassName = "RangeType";
static inline const MlirStringRef name = mlirPDLRangeTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -118,6 +121,7 @@ struct TypeType : PyConcreteType<TypeType> {
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirPDLTypeTypeGetTypeID;
static constexpr const char *pyClassName = "TypeType";
static inline const MlirStringRef name = mlirPDLTypeTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -141,6 +145,7 @@ struct ValueType : PyConcreteType<ValueType> {
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirPDLValueTypeGetTypeID;
static constexpr const char *pyClassName = "ValueType";
static inline const MlirStringRef name = mlirPDLValueTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {

View File

@@ -198,6 +198,7 @@ struct AnyQuantizedType : PyConcreteType<AnyQuantizedType, QuantizedType> {
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirAnyQuantizedTypeGetTypeID;
static constexpr const char *pyClassName = "AnyQuantizedType";
static inline const MlirStringRef name = mlirAnyQuantizedTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -229,6 +230,7 @@ struct UniformQuantizedType
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirUniformQuantizedTypeGetTypeID;
static constexpr const char *pyClassName = "UniformQuantizedType";
static inline const MlirStringRef name = mlirUniformQuantizedTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -283,6 +285,8 @@ struct UniformQuantizedPerAxisType
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirUniformQuantizedPerAxisTypeGetTypeID;
static constexpr const char *pyClassName = "UniformQuantizedPerAxisType";
static inline const MlirStringRef name =
mlirUniformQuantizedPerAxisTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -369,6 +373,8 @@ struct UniformQuantizedSubChannelType
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirUniformQuantizedSubChannelTypeGetTypeID;
static constexpr const char *pyClassName = "UniformQuantizedSubChannelType";
static inline const MlirStringRef name =
mlirUniformQuantizedSubChannelTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -462,6 +468,7 @@ struct CalibratedQuantizedType
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirCalibratedQuantizedTypeGetTypeID;
static constexpr const char *pyClassName = "CalibratedQuantizedType";
static inline const MlirStringRef name = mlirCalibratedQuantizedTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {

View File

@@ -30,6 +30,7 @@ namespace smt {
struct BoolType : PyConcreteType<BoolType> {
static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsABool;
static constexpr const char *pyClassName = "BoolType";
static inline const MlirStringRef name = mlirSMTBoolTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -46,6 +47,7 @@ struct BoolType : PyConcreteType<BoolType> {
struct BitVectorType : PyConcreteType<BitVectorType> {
static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsABitVector;
static constexpr const char *pyClassName = "BitVectorType";
static inline const MlirStringRef name = mlirSMTBitVectorTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -63,6 +65,7 @@ struct BitVectorType : PyConcreteType<BitVectorType> {
struct IntType : PyConcreteType<IntType> {
static constexpr IsAFunctionTy isaFunction = mlirSMTTypeIsAInt;
static constexpr const char *pyClassName = "IntType";
static inline const MlirStringRef name = mlirSMTIntTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {

View File

@@ -31,6 +31,7 @@ struct AnyOpType : PyConcreteType<AnyOpType> {
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirTransformAnyOpTypeGetTypeID;
static constexpr const char *pyClassName = "AnyOpType";
static inline const MlirStringRef name = mlirTransformAnyOpTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -54,6 +55,7 @@ struct AnyParamType : PyConcreteType<AnyParamType> {
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirTransformAnyParamTypeGetTypeID;
static constexpr const char *pyClassName = "AnyParamType";
static inline const MlirStringRef name = mlirTransformAnyParamTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -77,6 +79,7 @@ struct AnyValueType : PyConcreteType<AnyValueType> {
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirTransformAnyValueTypeGetTypeID;
static constexpr const char *pyClassName = "AnyValueType";
static inline const MlirStringRef name = mlirTransformAnyValueTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -101,6 +104,7 @@ struct OperationType : PyConcreteType<OperationType> {
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirTransformOperationTypeGetTypeID;
static constexpr const char *pyClassName = "OperationType";
static inline const MlirStringRef name = mlirTransformOperationTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {
@@ -136,6 +140,7 @@ struct ParamType : PyConcreteType<ParamType> {
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
mlirTransformParamTypeGetTypeID;
static constexpr const char *pyClassName = "ParamType";
static inline const MlirStringRef name = mlirTransformParamTypeGetName();
using Base::Base;
static void bindDerived(ClassTy &c) {

View File

@@ -32,6 +32,10 @@ MlirType mlirAMDGPUTDMBaseTypeGet(MlirContext ctx, MlirType elementType) {
return wrap(amdgpu::TDMBaseType::get(unwrap(ctx), unwrap(elementType)));
}
MlirStringRef mlirAMDGPUTDMBaseTypeGetName(void) {
return wrap(amdgpu::TDMBaseType::name);
}
//===---------------------------------------------------------------------===//
// TDMDescriptorType
//===---------------------------------------------------------------------===//
@@ -48,6 +52,10 @@ MlirType mlirAMDGPUTDMDescriptorTypeGet(MlirContext ctx) {
return wrap(amdgpu::TDMDescriptorType::get(unwrap(ctx)));
}
MlirStringRef mlirAMDGPUTDMDescriptorTypeGetName(void) {
return wrap(amdgpu::TDMDescriptorType::name);
}
//===---------------------------------------------------------------------===//
// TDMGatherBaseType
//===---------------------------------------------------------------------===//
@@ -65,3 +73,7 @@ MlirType mlirAMDGPUTDMGatherBaseTypeGet(MlirContext ctx, MlirType elementType,
return wrap(amdgpu::TDMGatherBaseType::get(unwrap(ctx), unwrap(elementType),
unwrap(indexType)));
}
MlirStringRef mlirAMDGPUTDMGatherBaseTypeGetName(void) {
return wrap(amdgpu::TDMGatherBaseType::name);
}

View File

@@ -49,6 +49,10 @@ MlirType mlirEmitCArrayTypeGet(intptr_t nDims, int64_t *shape,
emitc::ArrayType::get(llvm::ArrayRef(shape, nDims), unwrap(elementType)));
}
MlirStringRef mlirEmitCArrayTypeGetName(void) {
return wrap(emitc::ArrayType::name);
}
//===---------------------------------------------------------------------===//
// LValueType
//===---------------------------------------------------------------------===//
@@ -65,6 +69,10 @@ MlirType mlirEmitCLValueTypeGet(MlirType valueType) {
return wrap(emitc::LValueType::get(unwrap(valueType)));
}
MlirStringRef mlirEmitCLValueTypeGetName(void) {
return wrap(emitc::LValueType::name);
}
//===---------------------------------------------------------------------===//
// OpaqueType
//===---------------------------------------------------------------------===//
@@ -81,6 +89,10 @@ MlirType mlirEmitCOpaqueTypeGet(MlirContext ctx, MlirStringRef value) {
return wrap(emitc::OpaqueType::get(unwrap(ctx), unwrap(value)));
}
MlirStringRef mlirEmitCOpaqueTypeGetName(void) {
return wrap(emitc::OpaqueType::name);
}
//===---------------------------------------------------------------------===//
// PointerType
//===---------------------------------------------------------------------===//
@@ -97,6 +109,10 @@ MlirType mlirEmitCPointerTypeGet(MlirType pointee) {
return wrap(emitc::PointerType::get(unwrap(pointee)));
}
MlirStringRef mlirEmitCPointerTypeGetName(void) {
return wrap(emitc::PointerType::name);
}
//===---------------------------------------------------------------------===//
// PtrDiffTType
//===---------------------------------------------------------------------===//
@@ -113,6 +129,10 @@ MlirType mlirEmitCPtrDiffTTypeGet(MlirContext ctx) {
return wrap(emitc::PtrDiffTType::get(unwrap(ctx)));
}
MlirStringRef mlirEmitCPtrDiffTTypeGetName(void) {
return wrap(emitc::PtrDiffTType::name);
}
//===---------------------------------------------------------------------===//
// SignedSizeTType
//===---------------------------------------------------------------------===//
@@ -129,6 +149,10 @@ MlirType mlirEmitCSignedSizeTTypeGet(MlirContext ctx) {
return wrap(emitc::SignedSizeTType::get(unwrap(ctx)));
}
MlirStringRef mlirEmitCSignedSizeTTypeGetName(void) {
return wrap(emitc::SignedSizeTType::name);
}
//===---------------------------------------------------------------------===//
// SizeTType
//===---------------------------------------------------------------------===//
@@ -145,6 +169,10 @@ MlirType mlirEmitCSizeTTypeGet(MlirContext ctx) {
return wrap(emitc::SizeTType::get(unwrap(ctx)));
}
MlirStringRef mlirEmitCSizeTTypeGetName(void) {
return wrap(emitc::SizeTType::name);
}
//===----------------------------------------------------------------------===//
// CmpPredicate attribute.
//===----------------------------------------------------------------------===//

View File

@@ -27,6 +27,10 @@ MlirType mlirGPUAsyncTokenTypeGet(MlirContext ctx) {
return wrap(gpu::AsyncTokenType::get(unwrap(ctx)));
}
MlirStringRef mlirGPUAsyncTokenTypeGetName(void) {
return wrap(gpu::AsyncTokenType::name);
}
//===---------------------------------------------------------------------===//
// ObjectAttr
//===---------------------------------------------------------------------===//

View File

@@ -27,6 +27,10 @@ MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) {
return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace));
}
MlirStringRef mlirLLVMPointerTypeGetName(void) {
return wrap(LLVM::LLVMPointerType::name);
}
MlirTypeID mlirLLVMPointerTypeGetTypeID() {
return wrap(LLVM::LLVMPointerType::getTypeID());
}
@@ -43,10 +47,16 @@ MlirType mlirLLVMVoidTypeGet(MlirContext ctx) {
return wrap(LLVMVoidType::get(unwrap(ctx)));
}
MlirStringRef mlirLLVMVoidTypeGetName(void) { return wrap(LLVMVoidType::name); }
MlirType mlirLLVMArrayTypeGet(MlirType elementType, unsigned numElements) {
return wrap(LLVMArrayType::get(unwrap(elementType), numElements));
}
MlirStringRef mlirLLVMArrayTypeGetName(void) {
return wrap(LLVMArrayType::name);
}
MlirType mlirLLVMArrayTypeGetElementType(MlirType type) {
return wrap(cast<LLVM::LLVMArrayType>(unwrap(type)).getElementType());
}
@@ -59,6 +69,10 @@ MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes,
unwrapList(nArgumentTypes, argumentTypes, argumentStorage), isVarArg));
}
MlirStringRef mlirLLVMFunctionTypeGetName(void) {
return wrap(LLVMFunctionType::name);
}
intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type) {
return llvm::cast<LLVM::LLVMFunctionType>(unwrap(type)).getNumParams();
}
@@ -81,6 +95,10 @@ MlirTypeID mlirLLVMStructTypeGetTypeID() {
return wrap(LLVM::LLVMStructType::getTypeID());
}
MlirStringRef mlirLLVMStructTypeGetName(void) {
return wrap(LLVM::LLVMStructType::name);
}
bool mlirLLVMStructTypeIsLiteral(MlirType type) {
return !cast<LLVM::LLVMStructType>(unwrap(type)).isIdentified();
}

View File

@@ -29,3 +29,7 @@ MlirType mlirNVGPUTensorMapDescriptorTypeGet(MlirContext ctx,
TensorMapSwizzleKind(swizzle), TensorMapL2PromoKind(l2promo),
TensorMapOOBKind(oobFill), TensorMapInterleaveKind(interleave)));
}
MlirStringRef mlirNVGPUTensorMapDescriptorTypeGetName(void) {
return wrap(nvgpu::TensorMapDescriptorType::name);
}

View File

@@ -40,6 +40,10 @@ MlirType mlirPDLAttributeTypeGet(MlirContext ctx) {
return wrap(pdl::AttributeType::get(unwrap(ctx)));
}
MlirStringRef mlirPDLAttributeTypeGetName(void) {
return wrap(pdl::AttributeType::name);
}
//===---------------------------------------------------------------------===//
// OperationType
//===---------------------------------------------------------------------===//
@@ -56,6 +60,10 @@ MlirType mlirPDLOperationTypeGet(MlirContext ctx) {
return wrap(pdl::OperationType::get(unwrap(ctx)));
}
MlirStringRef mlirPDLOperationTypeGetName(void) {
return wrap(pdl::OperationType::name);
}
//===---------------------------------------------------------------------===//
// RangeType
//===---------------------------------------------------------------------===//
@@ -72,6 +80,10 @@ MlirType mlirPDLRangeTypeGet(MlirType elementType) {
return wrap(pdl::RangeType::get(unwrap(elementType)));
}
MlirStringRef mlirPDLRangeTypeGetName(void) {
return wrap(pdl::RangeType::name);
}
MlirType mlirPDLRangeTypeGetElementType(MlirType type) {
return wrap(cast<pdl::RangeType>(unwrap(type)).getElementType());
}
@@ -92,6 +104,8 @@ MlirType mlirPDLTypeTypeGet(MlirContext ctx) {
return wrap(pdl::TypeType::get(unwrap(ctx)));
}
MlirStringRef mlirPDLTypeTypeGetName(void) { return wrap(pdl::TypeType::name); }
//===---------------------------------------------------------------------===//
// ValueType
//===---------------------------------------------------------------------===//
@@ -107,3 +121,7 @@ MlirTypeID mlirPDLValueTypeGetTypeID(void) {
MlirType mlirPDLValueTypeGet(MlirContext ctx) {
return wrap(pdl::ValueType::get(unwrap(ctx)));
}
MlirStringRef mlirPDLValueTypeGetName(void) {
return wrap(pdl::ValueType::name);
}

View File

@@ -125,6 +125,10 @@ MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
storageTypeMin, storageTypeMax));
}
MlirStringRef mlirAnyQuantizedTypeGetName(void) {
return wrap(quant::AnyQuantizedType::name);
}
//===---------------------------------------------------------------------===//
// UniformQuantizedType
//===---------------------------------------------------------------------===//
@@ -146,6 +150,10 @@ MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
storageTypeMin, storageTypeMax));
}
MlirStringRef mlirUniformQuantizedTypeGetName(void) {
return wrap(quant::UniformQuantizedType::name);
}
double mlirUniformQuantizedTypeGetScale(MlirType type) {
return cast<quant::UniformQuantizedType>(unwrap(type)).getScale();
}
@@ -181,6 +189,10 @@ MlirType mlirUniformQuantizedPerAxisTypeGet(
quantizedDimension, storageTypeMin, storageTypeMax));
}
MlirStringRef mlirUniformQuantizedPerAxisTypeGetName(void) {
return wrap(quant::UniformQuantizedPerAxisType::name);
}
intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) {
return cast<quant::UniformQuantizedPerAxisType>(unwrap(type))
.getScales()
@@ -238,6 +250,10 @@ MlirType mlirUniformQuantizedSubChannelTypeGet(
storageTypeMax));
}
MlirStringRef mlirUniformQuantizedSubChannelTypeGetName(void) {
return wrap(quant::UniformQuantizedSubChannelType::name);
}
intptr_t mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type) {
return cast<quant::UniformQuantizedSubChannelType>(unwrap(type))
.getBlockSizes()
@@ -284,6 +300,10 @@ MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
quant::CalibratedQuantizedType::get(unwrap(expressedType), min, max));
}
MlirStringRef mlirCalibratedQuantizedTypeGetName(void) {
return wrap(quant::CalibratedQuantizedType::name);
}
double mlirCalibratedQuantizedTypeGetMin(MlirType type) {
return cast<quant::CalibratedQuantizedType>(unwrap(type)).getMin();
}

View File

@@ -49,18 +49,26 @@ MlirType mlirSMTTypeGetBitVector(MlirContext ctx, int32_t width) {
return wrap(BitVectorType::get(unwrap(ctx), width));
}
MlirStringRef mlirSMTBitVectorTypeGetName(void) {
return wrap(BitVectorType::name);
}
bool mlirSMTTypeIsABool(MlirType type) { return isa<BoolType>(unwrap(type)); }
MlirType mlirSMTTypeGetBool(MlirContext ctx) {
return wrap(BoolType::get(unwrap(ctx)));
}
MlirStringRef mlirSMTBoolTypeGetName(void) { return wrap(BoolType::name); }
bool mlirSMTTypeIsAInt(MlirType type) { return isa<IntType>(unwrap(type)); }
MlirType mlirSMTTypeGetInt(MlirContext ctx) {
return wrap(IntType::get(unwrap(ctx)));
}
MlirStringRef mlirSMTIntTypeGetName(void) { return wrap(IntType::name); }
bool mlirSMTTypeIsASMTFunc(MlirType type) {
return isa<SMTFuncType>(unwrap(type));
}

View File

@@ -33,6 +33,10 @@ MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
return wrap(transform::AnyOpType::get(unwrap(ctx)));
}
MlirStringRef mlirTransformAnyOpTypeGetName(void) {
return wrap(transform::AnyOpType::name);
}
//===---------------------------------------------------------------------===//
// AnyParamType
//===---------------------------------------------------------------------===//
@@ -49,6 +53,10 @@ MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
return wrap(transform::AnyParamType::get(unwrap(ctx)));
}
MlirStringRef mlirTransformAnyParamTypeGetName(void) {
return wrap(transform::AnyParamType::name);
}
//===---------------------------------------------------------------------===//
// AnyValueType
//===---------------------------------------------------------------------===//
@@ -65,6 +73,10 @@ MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) {
return wrap(transform::AnyValueType::get(unwrap(ctx)));
}
MlirStringRef mlirTransformAnyValueTypeGetName(void) {
return wrap(transform::AnyValueType::name);
}
//===---------------------------------------------------------------------===//
// OperationType
//===---------------------------------------------------------------------===//
@@ -83,6 +95,10 @@ MlirType mlirTransformOperationTypeGet(MlirContext ctx,
transform::OperationType::get(unwrap(ctx), unwrap(operationName)));
}
MlirStringRef mlirTransformOperationTypeGetName(void) {
return wrap(transform::OperationType::name);
}
MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
return wrap(cast<transform::OperationType>(unwrap(type)).getOperationName());
}
@@ -103,6 +119,10 @@ MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
}
MlirStringRef mlirTransformParamTypeGetName(void) {
return wrap(transform::ParamType::name);
}
MlirType mlirTransformParamTypeGetType(MlirType type) {
return wrap(cast<transform::ParamType>(unwrap(type)).getType());
}

View File

@@ -35,6 +35,8 @@ MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
return wrap(IntegerType::get(unwrap(ctx), bitwidth));
}
MlirStringRef mlirIntegerTypeGetName(void) { return wrap(IntegerType::name); }
MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Signed));
}
@@ -73,6 +75,8 @@ MlirType mlirIndexTypeGet(MlirContext ctx) {
return wrap(IndexType::get(unwrap(ctx)));
}
MlirStringRef mlirIndexTypeGetName(void) { return wrap(IndexType::name); }
//===----------------------------------------------------------------------===//
// Floating-point types.
//===----------------------------------------------------------------------===//
@@ -97,6 +101,10 @@ MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) {
return wrap(Float4E2M1FNType::get(unwrap(ctx)));
}
MlirStringRef mlirFloat4E2M1FNTypeGetName(void) {
return wrap(Float4E2M1FNType::name);
}
MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() {
return wrap(Float6E2M3FNType::getTypeID());
}
@@ -109,6 +117,10 @@ MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) {
return wrap(Float6E2M3FNType::get(unwrap(ctx)));
}
MlirStringRef mlirFloat6E2M3FNTypeGetName(void) {
return wrap(Float6E2M3FNType::name);
}
MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() {
return wrap(Float6E3M2FNType::getTypeID());
}
@@ -121,6 +133,10 @@ MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) {
return wrap(Float6E3M2FNType::get(unwrap(ctx)));
}
MlirStringRef mlirFloat6E3M2FNTypeGetName(void) {
return wrap(Float6E3M2FNType::name);
}
MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
return wrap(Float8E5M2Type::getTypeID());
}
@@ -133,6 +149,10 @@ MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
return wrap(Float8E5M2Type::get(unwrap(ctx)));
}
MlirStringRef mlirFloat8E5M2TypeGetName(void) {
return wrap(Float8E5M2Type::name);
}
MlirTypeID mlirFloat8E4M3TypeGetTypeID() {
return wrap(Float8E4M3Type::getTypeID());
}
@@ -145,6 +165,10 @@ MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) {
return wrap(Float8E4M3Type::get(unwrap(ctx)));
}
MlirStringRef mlirFloat8E4M3TypeGetName(void) {
return wrap(Float8E4M3Type::name);
}
MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
return wrap(Float8E4M3FNType::getTypeID());
}
@@ -157,6 +181,10 @@ MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
return wrap(Float8E4M3FNType::get(unwrap(ctx)));
}
MlirStringRef mlirFloat8E4M3FNTypeGetName(void) {
return wrap(Float8E4M3FNType::name);
}
MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() {
return wrap(Float8E5M2FNUZType::getTypeID());
}
@@ -169,6 +197,10 @@ MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
return wrap(Float8E5M2FNUZType::get(unwrap(ctx)));
}
MlirStringRef mlirFloat8E5M2FNUZTypeGetName(void) {
return wrap(Float8E5M2FNUZType::name);
}
MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() {
return wrap(Float8E4M3FNUZType::getTypeID());
}
@@ -181,6 +213,10 @@ MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
return wrap(Float8E4M3FNUZType::get(unwrap(ctx)));
}
MlirStringRef mlirFloat8E4M3FNUZTypeGetName(void) {
return wrap(Float8E4M3FNUZType::name);
}
MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() {
return wrap(Float8E4M3B11FNUZType::getTypeID());
}
@@ -193,6 +229,10 @@ MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
return wrap(Float8E4M3B11FNUZType::get(unwrap(ctx)));
}
MlirStringRef mlirFloat8E4M3B11FNUZTypeGetName(void) {
return wrap(Float8E4M3B11FNUZType::name);
}
MlirTypeID mlirFloat8E3M4TypeGetTypeID() {
return wrap(Float8E3M4Type::getTypeID());
}
@@ -205,6 +245,10 @@ MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
return wrap(Float8E3M4Type::get(unwrap(ctx)));
}
MlirStringRef mlirFloat8E3M4TypeGetName(void) {
return wrap(Float8E3M4Type::name);
}
MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() {
return wrap(Float8E8M0FNUType::getTypeID());
}
@@ -217,6 +261,10 @@ MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) {
return wrap(Float8E8M0FNUType::get(unwrap(ctx)));
}
MlirStringRef mlirFloat8E8M0FNUTypeGetName(void) {
return wrap(Float8E8M0FNUType::name);
}
MlirTypeID mlirBFloat16TypeGetTypeID() {
return wrap(BFloat16Type::getTypeID());
}
@@ -229,6 +277,8 @@ MlirType mlirBF16TypeGet(MlirContext ctx) {
return wrap(BFloat16Type::get(unwrap(ctx)));
}
MlirStringRef mlirBF16TypeGetName(void) { return wrap(BFloat16Type::name); }
MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); }
bool mlirTypeIsAF16(MlirType type) {
@@ -239,6 +289,8 @@ MlirType mlirF16TypeGet(MlirContext ctx) {
return wrap(Float16Type::get(unwrap(ctx)));
}
MlirStringRef mlirF16TypeGetName(void) { return wrap(Float16Type::name); }
MlirTypeID mlirFloatTF32TypeGetTypeID() {
return wrap(FloatTF32Type::getTypeID());
}
@@ -251,6 +303,8 @@ MlirType mlirTF32TypeGet(MlirContext ctx) {
return wrap(FloatTF32Type::get(unwrap(ctx)));
}
MlirStringRef mlirTF32TypeGetName(void) { return wrap(FloatTF32Type::name); }
MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }
bool mlirTypeIsAF32(MlirType type) {
@@ -261,6 +315,8 @@ MlirType mlirF32TypeGet(MlirContext ctx) {
return wrap(Float32Type::get(unwrap(ctx)));
}
MlirStringRef mlirF32TypeGetName(void) { return wrap(Float32Type::name); }
MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); }
bool mlirTypeIsAF64(MlirType type) {
@@ -271,6 +327,8 @@ MlirType mlirF64TypeGet(MlirContext ctx) {
return wrap(Float64Type::get(unwrap(ctx)));
}
MlirStringRef mlirF64TypeGetName(void) { return wrap(Float64Type::name); }
//===----------------------------------------------------------------------===//
// None type.
//===----------------------------------------------------------------------===//
@@ -285,6 +343,8 @@ MlirType mlirNoneTypeGet(MlirContext ctx) {
return wrap(NoneType::get(unwrap(ctx)));
}
MlirStringRef mlirNoneTypeGetName(void) { return wrap(NoneType::name); }
//===----------------------------------------------------------------------===//
// Complex type.
//===----------------------------------------------------------------------===//
@@ -299,6 +359,8 @@ MlirType mlirComplexTypeGet(MlirType elementType) {
return wrap(ComplexType::get(unwrap(elementType)));
}
MlirStringRef mlirComplexTypeGetName(void) { return wrap(ComplexType::name); }
MlirType mlirComplexTypeGetElementType(MlirType type) {
return wrap(llvm::cast<ComplexType>(unwrap(type)).getElementType());
}
@@ -380,6 +442,8 @@ MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
unwrap(elementType)));
}
MlirStringRef mlirVectorTypeGetName(void) { return wrap(VectorType::name); }
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
const int64_t *shape, MlirType elementType) {
return wrap(VectorType::getChecked(
@@ -443,6 +507,10 @@ MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape,
unwrap(elementType), unwrap(encoding)));
}
MlirStringRef mlirRankedTensorTypeGetName(void) {
return wrap(RankedTensorType::name);
}
MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
const int64_t *shape,
MlirType elementType,
@@ -460,6 +528,10 @@ MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
return wrap(UnrankedTensorType::get(unwrap(elementType)));
}
MlirStringRef mlirUnrankedTensorTypeGetName(void) {
return wrap(UnrankedTensorType::name);
}
MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
MlirType elementType) {
return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
@@ -486,6 +558,8 @@ MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
unwrap(memorySpace)));
}
MlirStringRef mlirMemRefTypeGetName(void) { return wrap(MemRefType::name); }
MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType,
intptr_t rank, const int64_t *shape,
MlirAttribute layout,
@@ -554,6 +628,10 @@ MlirType mlirUnrankedMemRefTypeGet(MlirType elementType,
UnrankedMemRefType::get(unwrap(elementType), unwrap(memorySpace)));
}
MlirStringRef mlirUnrankedMemRefTypeGetName(void) {
return wrap(UnrankedMemRefType::name);
}
MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,
MlirType elementType,
MlirAttribute memorySpace) {
@@ -582,6 +660,8 @@ MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
return wrap(TupleType::get(unwrap(ctx), typeRef));
}
MlirStringRef mlirTupleTypeGetName(void) { return wrap(TupleType::name); }
intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
return llvm::cast<TupleType>(unwrap(type)).size();
}
@@ -613,6 +693,8 @@ MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
return wrap(FunctionType::get(unwrap(ctx), inputsList, resultsList));
}
MlirStringRef mlirFunctionTypeGetName(void) { return wrap(FunctionType::name); }
intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
return llvm::cast<FunctionType>(unwrap(type)).getNumInputs();
}
@@ -650,6 +732,8 @@ MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace,
unwrap(typeData)));
}
MlirStringRef mlirOpaqueTypeGetName(void) { return wrap(OpaqueType::name); }
MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type) {
return wrap(
llvm::cast<OpaqueType>(unwrap(type)).getDialectNamespace().strref());

View File

@@ -759,6 +759,49 @@ def testTypeIDs():
print(ShapedType(vector_type).typeid == vector_type.typeid)
# CHECK-LABEL: TEST: testTypeName
@run
def testTypeName():
with Context():
# CHECK: builtin.integer
print(IntegerType.type_name)
# CHECK: builtin.index
print(IndexType.type_name)
# CHECK: builtin.f32
print(F32Type.type_name)
# CHECK: builtin.bf16
print(BF16Type.type_name)
# CHECK: builtin.none
print(NoneType.type_name)
# CHECK: builtin.complex
print(ComplexType.type_name)
# CHECK: builtin.vector
print(VectorType.type_name)
# CHECK: builtin.tensor
print(RankedTensorType.type_name)
# CHECK: builtin.unranked_tensor
print(UnrankedTensorType.type_name)
# CHECK: builtin.memref
print(MemRefType.type_name)
# CHECK: builtin.unranked_memref
print(UnrankedMemRefType.type_name)
# CHECK: builtin.tuple
print(TupleType.type_name)
# CHECK: builtin.function
print(FunctionType.type_name)
# CHECK: builtin.opaque
print(OpaqueType.type_name)
# CHECK-LABEL: TEST: testConcreteTypesRoundTrip
@run
def testConcreteTypesRoundTrip():