diff --git a/mlir/Maintainers.md b/mlir/Maintainers.md index a023ee0ea1bb..181541a0f3a9 100644 --- a/mlir/Maintainers.md +++ b/mlir/Maintainers.md @@ -104,7 +104,6 @@ available, should be contacted first, as they're more active in those areas. * ‘arm_neon’ Dialect ([@banach-space](https://github.com/banach-space)) * ‘arm_sve’ Dialect ([@banach-space](https://github.com/banach-space)) * ‘ArmSME’ Dialect ([@banach-space](https://github.com/banach-space)) -* ‘amx’ Dialect ([@adam-smnk](https://github.com/adam-smnk)) * ‘x86’ Dialect ([@adam-smnk](https://github.com/adam-smnk)) * ‘vcix’ Dialect ([@mshockwave](https://github.com/mshockwave)) diff --git a/mlir/docs/TargetLLVMIR.md b/mlir/docs/TargetLLVMIR.md index 2bdf400a7759..2bcacbb4ee94 100644 --- a/mlir/docs/TargetLLVMIR.md +++ b/mlir/docs/TargetLLVMIR.md @@ -5,8 +5,8 @@ overall flow is two-stage: 1. **conversion** of the IR to a set of dialects translatable to LLVM IR, for example [LLVM Dialect](Dialects/LLVM.md) or one of the hardware-specific - dialects derived from LLVM IR intrinsics such as [AMX](Dialects/AMX.md), - [X86](Dialects/X86.md) or [ArmNeon](Dialects/ArmNeon.md); + dialects derived from LLVM IR intrinsics such as [X86](Dialects/X86.md) + or [ArmNeon](Dialects/ArmNeon.md); 2. **translation** of MLIR dialects to LLVM IR. This flow allows the non-trivial transformation to be performed within MLIR diff --git a/mlir/include/mlir-c/Dialect/AMX.h b/mlir/include/mlir-c/Dialect/AMX.h deleted file mode 100644 index ac4695a107ae..000000000000 --- a/mlir/include/mlir-c/Dialect/AMX.h +++ /dev/null @@ -1,25 +0,0 @@ -//===-- mlir-c/Dialect/AMX.h - C API for AMX Dialect --------*- C -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM -// Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_C_DIALECT_AMX_H -#define MLIR_C_DIALECT_AMX_H - -#include "mlir-c/IR.h" - -#ifdef __cplusplus -extern "C" { -#endif - -MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(AMX, amx); - -#ifdef __cplusplus -} -#endif - -#endif // MLIR_C_DIALECT_AMX_H diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index ecc22abb0f93..e77860897399 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1521,8 +1521,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { operations. The lowering pass provides several options to control the kinds of optimizations that are allowed. It also provides options that enable the use of one or more architectural-specific dialects - (AMX, X86, ArmNeon, ArmSVE, etc.) in combination with the - architectural-neutral vector dialect lowering. + (X86, ArmNeon, ArmSVE, etc.) in combination with the architectural-neutral + vector dialect lowering. }]; // Override explicitly in C++ to allow conditional dialect dependence. @@ -1544,10 +1544,6 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "vector access are naturally aligned. 
If operations have an " "alignment attribute set, the alignment attribute takes priority " "over this option ">, - Option<"amx", "enable-amx", - "bool", /*default=*/"false", - "Enables the use of AMX dialect while lowering the vector " - "dialect.">, Option<"armNeon", "enable-arm-neon", "bool", /*default=*/"false", "Enables the use of ArmNeon dialect while lowering the vector " @@ -1626,10 +1622,10 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> { //===----------------------------------------------------------------------===// def ConvertVectorToAMX : Pass<"convert-vector-to-amx"> { - let summary = "Lower the operations from the vector dialect into the AMX " - "dialect"; + let summary = "Lower the operations from the vector dialect into the X86 " + "dialect AMX operations"; let dependentDialects = [ - "affine::AffineDialect", "amx::AMXDialect", "arith::ArithDialect", + "affine::AffineDialect", "x86::X86Dialect", "arith::ArithDialect", "memref::MemRefDialect", "scf::SCFDialect", "vector::VectorDialect" ]; } diff --git a/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h index b075ac92990a..6b178e02684c 100644 --- a/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h +++ b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h @@ -1,4 +1,4 @@ -//===- VectorToAMX.h - Convert vector to AMX dialect ------------*- C++ -*-===// +//===- VectorToAMX.h - Convert vector to X86 dialect AMX ops ----*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -18,7 +18,7 @@ class RewritePatternSet; #define GEN_PASS_DECL_CONVERTVECTORTOAMX #include "mlir/Conversion/Passes.h.inc" -/// Collect a set of patterns to convert from the vector to AMX ops. +/// Collect a set of patterns to convert from the vector to X86 AMX ops. void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns); } // namespace mlir diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td deleted file mode 100644 index cace63d32fd8..000000000000 --- a/mlir/include/mlir/Dialect/AMX/AMX.td +++ /dev/null @@ -1,440 +0,0 @@ -//===-- AMX.td - AMX dialect operation definitions *- tablegen -*----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the basic operations for the AMX dialect. -// -// The Intel Advanced Matrix Extensions (AMX) provide a tile matrix -// multiply unit (TMUL), a tile control register (TILECFG), and eight -// tile registers TMM0 through TMM7 (TILEDATA). -// -// The AMX dialect provides a bridge between MLIR concepts, such as -// 2-d vector, operations, and memrefs, and the lower level details -// of Intel AMX, such as configuration setup, tile sizes, instructions, -// and tile release. -// -// Note that since configuration changes (implicit at dialect level) are -// costly, it is highly recommended to use the AMX dialect on same-shaped -// vectors, at least within a single method. 
-// -// https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html -// -//===----------------------------------------------------------------------===// - -#ifndef AMX -#define AMX - -include "mlir/Dialect/LLVMIR/LLVMOpBase.td" -include "mlir/Dialect/AMX/AMXInterfaces.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/IR/AttrTypeBase.td" -include "mlir/IR/BuiltinTypes.td" - -//===----------------------------------------------------------------------===// -// AMX dialect definition. -//===----------------------------------------------------------------------===// - -def AMX_Dialect : Dialect { - let name = "amx"; - let cppNamespace = "::mlir::amx"; - let description = [{ - The Intel Advanced Matrix Extensions (AMX) provide a tile matrix - multiply unit (TMUL), a tile control register (TILECFG), and eight - tile registers TMM0 through TMM7 (TILEDATA). - - This `AMX` dialect provides a bridge between MLIR concepts such as - vectors and memrefs and the lower level LLVM IR support of AMX. - - Note that since configuration changes (implicit at dialect level) are - costly, it is highly recommended to use the AMX dialect on same-shaped - vectors, at least within a single method. - - For details, see the Intel documentation: - https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html - }]; - let useDefaultTypePrinterParser = 1; -} - -//===----------------------------------------------------------------------===// -// AMX Tile definition. -//===----------------------------------------------------------------------===// - -class AMX_Type traits = []> - : TypeDef { - let mnemonic = typeMnemonic; -} - -def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> { - let cppFunctionName = "isValidTileTypeElementType"; -} - -def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> { - let summary = "AMX 2D tile to be used by AMX opertaions."; - - let description = [{ - This type is used to represent values in AMX tile registers. All AMX operations - work on AMX tiles and these tiles cannot be used in other operations directly. - LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and - element type for IR verification and lowering to LLVMIR dialect. - }]; - - let parameters = (ins - ArrayRefParameter<"int64_t">:$shape, - AMX_TileTypeElementType:$elementType - ); - - let builders = [ - TypeBuilderWithInferredContext<(ins - "ArrayRef":$shape, "Type":$elementType), [{ - return $_get(elementType.getContext(), shape, elementType); - }]> - ]; - - let extraClassDeclaration = [{ - /// Returns if this type is ranked (always true). - bool hasRank() const { return true; } - - /// Clone this tile type with the given shape and element type. If the - /// provided shape is `std::nullopt`, the current shape of the type is used. 
- TileType cloneWith(std::optional> shape, - Type elementType) const { - return get(shape.value_or(getShape()), elementType); - } - }]; - - let hasCustomAssemblyFormat = 1; - let skipDefaultBuilders = 1; -} - -def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">, - CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>; - -class AMXTileOf allowedTypes> : - ShapedContainerType; - -def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>; - -def AMXTileF32 : AMXTileOf<[F32]>; - -def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>; - -def AMXTileI32 : AMXTileOf<[I32]>; - -def AMXTileI8 : AMXTileOf<[I8]>; - -//===----------------------------------------------------------------------===// -// AMX Op and IntrOp definitions. -//===----------------------------------------------------------------------===// - -class AMX_Op traits = []> : - Op {} - -//===----------------------------------------------------------------------===// -// AMX Op definitions -//===----------------------------------------------------------------------===// - -// -// Tile reset. -// - -def TileZeroOp : AMX_Op<"tile_zero", [ - AMXIntrinsicOpInterface, - MemoryEffects<[MemWrite]> - ]> { - let summary = "tile zero operation"; - let description = [{ - Zeroes the destination tile, with the shape defined by the 2-dim - vector type of the result. - - The operation is eventually lowered into the "tilezero" instruction - with the corresponding tile configuration. - - With the write memory effect, each `amx.tile_zero` operation serves as - a compilation hint to use a separate tile register. - - Example: - - ```mlir - %0 = amx.tile_zero : !amx.tile<16x16xbf16> - ``` - }]; - let results = (outs AnyAMXTile:$res); - let extraClassDeclaration = [{ - TileType getTileType() { - return ::llvm::cast(getRes().getType()); - } - - std::string getIntrinsicName() { - return "llvm.x86.tilezero.internal"; - } - SmallVector getIntrinsicOperands( - ::mlir::ArrayRef operands, - const ::mlir::LLVMTypeConverter &typeConverter, - ::mlir::RewriterBase &rewriter); - }]; - let assemblyFormat = "attr-dict `:` qualified(type($res))"; - let hasVerifier = 1; -} - -// -// Tile memory operations. -// - -def TileLoadOp : AMX_Op<"tile_load", [ - AMXIntrinsicOpInterface, - MemoryEffects<[MemWrite]>, - AttrSizedOperandSegments - ]> { - let summary = "tile load operation"; - let description = [{ - Loads a tile from memory defined by a `base` and `indices`, with the - shape defined by the 2-dim vector type of the result. - The tile's rows are populated by reading contiguous elements starting - at the `base`. For each tile row, the `base` is incremented by `stride` - number of elements. - - The tile is loaded using the following indexing scheme: - - ``` - for row in enumerate(tile_rows): - mem_row = base[i0, i1, ..., iN + row * stride] - for col in enumerate(tile_cols): - tile[row, col] = mem_row[col] - ``` - - If the `stride` is not provided, then the `base` buffer must be at least - 2-dimensional, and the `stride` is automatically inferred and corresponds - to the stride of the buffer's second innermost dimension. - - The operation is eventually lowered into the "tileloadd" instruction - with the corresponding tile configuration. - - With the write memory effect, each `amx.tile_load` operation serves as - a compilation hint to use a separate tile register. - - Example: - - ```mlir - // Tile load from a 2-D memref with implicit stride. 
- %0 = amx.tile_load %arg0[%c0, %c0] : memref into !amx.tile<16x64xi8> - - // Tile load from a 1-D memref with explicit stride. - %0 = amx.tile_load %arg0[%c0], %stride : memref into !amx.tile<16x64xi8> - ``` - }]; - let arguments = (ins Arg:$base, - Variadic:$indices, - Optional:$stride); - let results = (outs AnyAMXTile:$res); - let builders = [ - OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)> - ]; - let extraClassDeclaration = [{ - MemRefType getMemRefType() { - return ::llvm::cast(getBase().getType()); - } - TileType getTileType() { - return ::llvm::cast(getRes().getType()); - } - - std::string getIntrinsicName() { - return "llvm.x86.tileloadd64.internal"; - } - SmallVector getIntrinsicOperands( - ::mlir::ArrayRef operands, - const ::mlir::LLVMTypeConverter &typeConverter, - ::mlir::RewriterBase &rewriter); - }]; - let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict" - "`:` type($base) `into` qualified(type($res))"; - let hasVerifier = 1; -} - -def TileStoreOp : AMX_Op<"tile_store", [ - AMXIntrinsicOpInterface, - AttrSizedOperandSegments - ]> { - let summary = "tile store operation"; - let description = [{ - Stores a tile to memory defined by a `base` and `indices`, with the - shape defined by the 2-dim vector type of the value. - The tile's rows are written contiguously to the buffer starting at - the `base`. For each tile row, the `base` is incremented by `stride` - number of elements. - - The tile is stored using the following indexing scheme: - - ``` - for row in enumerate(tile_rows): - mem_row = base[i0, i1, ..., iN + row * stride] - for col in enumerate(tile_cols): - mem_row[col] = tile[row, col] - ``` - - If the `stride` is not provided, then the `base` buffer must be at least - 2-dimensional, and the `stride` is automatically inferred and corresponds - to the stride of the buffer's second innermost dimension. - - The operation is eventually lowered into the "tilestored" instruction - with the corresponding tile configuration. - - Example: - - ```mlir - // Tile store to a 2-D memref with implicit stride. - amx.tile_store %arg1[%c0, %c0], %0 : memref, !amx.tile<16x64xi8> - - // Tile store to a 1-D memref with explicit stride. - amx.tile_store %arg1[%c0], %0, %stride : memref, !amx.tile<16x64xi8> - ``` - }]; - let arguments = (ins Arg:$base, - Variadic:$indices, - AnyAMXTile:$val, - Optional:$stride); - let builders = [ - OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)> - ]; - let extraClassDeclaration = [{ - MemRefType getMemRefType() { - return ::llvm::cast(getBase().getType()); - } - TileType getTileType() { - return ::llvm::cast(getVal().getType()); - } - - std::string getIntrinsicName() { - return "llvm.x86.tilestored64.internal"; - } - SmallVector getIntrinsicOperands( - ::mlir::ArrayRef operands, - const ::mlir::LLVMTypeConverter &typeConverter, - ::mlir::RewriterBase &rewriter); - }]; - let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?" - "attr-dict `:` type($base) `,` qualified(type($val))"; - let hasVerifier = 1; -} - -// -// Tile arithmetic operations. -// - -def TileMulFOp : AMX_Op<"tile_mulf", [Pure, - AMXIntrinsicOpInterface, - AllTypesMatch<["acc", "res"]> - ]> { - let summary = "tile multiplication operation (floating-point)"; - let description = [{ - Multiplies a "m x k" tile with a "k x n" tile and accumulates the results - into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with - pairs of "bf16"). 
- - The operation is eventually lowered into the "tdpbf16ps" instruction with - the corresponding tile configuration. - - Example: - - ```mlir - %0 = amx.tile_mulf %a, %b, %c - : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> - ``` - }]; - let arguments = (ins AMXTileF16OrBF16:$lhs, - AMXTileF16OrBF16:$rhs, - AMXTileF32:$acc); - let results = (outs AMXTileF32:$res); - let extraClassDeclaration = [{ - TileType getLhsTileType() { - return ::llvm::cast(getLhs().getType()); - } - TileType getRhsTileType() { - return ::llvm::cast(getRhs().getType()); - } - TileType getTileType() { - return ::llvm::cast(getRes().getType()); - } - - std::string getIntrinsicName() { - std::string intr = "llvm.x86.tdp"; - auto elementType = - getLhsTileType().getElementType(); - intr += elementType.isF16() ? "fp16" : "bf16"; - intr += "ps.internal"; - return intr; - } - SmallVector getIntrinsicOperands( - ::mlir::ArrayRef operands, - const ::mlir::LLVMTypeConverter &typeConverter, - ::mlir::RewriterBase &rewriter); - }]; - let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` " - "qualified(type($lhs)) `,` qualified(type($rhs))" - " `,` qualified(type($acc)) "; - let hasVerifier = 1; -} - -def TileMulIOp : AMX_Op<"tile_muli", [Pure, - AMXIntrinsicOpInterface, - AllTypesMatch<["acc", "res"]> - ]> { - let summary = "tile multiplication operation (integer)"; - let description = [{ - Multiplies a "m x k" tile with a "k x n" tile and accumulates the results - into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8" - combinations (4 bytes packed into dwords in the columns of both the - source operand tiles; the zero or sign extension is specified with - the attributes and default to sign extended). - - The operation is eventually lowered into one of the "tdpbssd", - "tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding - tile configuration. - - Example: - - ```mlir - %0 = amx.tile_muli %a zext, %b zext, %c - : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> - ``` - }]; - let arguments = (ins AMXTileI8:$lhs, - AMXTileI8:$rhs, - AMXTileI32:$acc, - UnitAttr:$isZextLhs, - UnitAttr:$isZextRhs - ); - let results = (outs AMXTileI32:$res); - let extraClassDeclaration = [{ - TileType getLhsTileType() { - return ::llvm::cast(getLhs().getType()); - } - TileType getRhsTileType() { - return ::llvm::cast(getRhs().getType()); - } - TileType getTileType() { - return ::llvm::cast(getRes().getType()); - } - - std::string getIntrinsicName() { - std::string intr = "llvm.x86.tdpb"; - intr += getIsZextLhs() ? "u" : "s"; - intr += getIsZextRhs() ? "u" : "s"; - intr += "d.internal"; - return intr; - } - SmallVector getIntrinsicOperands( - ::mlir::ArrayRef operands, - const ::mlir::LLVMTypeConverter &typeConverter, - ::mlir::RewriterBase &rewriter); - }]; - let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` " - "qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) "; - let hasVerifier = 1; -} - -#endif // AMX diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h deleted file mode 100644 index c79f31d4c994..000000000000 --- a/mlir/include/mlir/Dialect/AMX/AMXDialect.h +++ /dev/null @@ -1,34 +0,0 @@ -//===- AMXDialect.h - MLIR Dialect for AMX ----------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares the Target dialect for AMX in MLIR. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_AMX_AMXDIALECT_H_ -#define MLIR_DIALECT_AMX_AMXDIALECT_H_ - -#include "mlir/Bytecode/BytecodeOpInterface.h" -#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" - -/// Include the generated interface declarations. -#include "mlir/Dialect/AMX/AMXInterfaces.h.inc" - -#include "mlir/Dialect/AMX/AMXDialect.h.inc" - -#define GET_TYPEDEF_CLASSES -#include "mlir/Dialect/AMX/AMXTypes.h.inc" - -#define GET_OP_CLASSES -#include "mlir/Dialect/AMX/AMX.h.inc" - -#endif // MLIR_DIALECT_AMX_AMXDIALECT_H_ diff --git a/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td b/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td deleted file mode 100644 index 012d1ba7368f..000000000000 --- a/mlir/include/mlir/Dialect/AMX/AMXInterfaces.td +++ /dev/null @@ -1,31 +0,0 @@ -//===- AMXInterfaces.td - AMX interfaces -------------------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines interfaces for the AMX dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef AMX_INTERFACES -#define AMX_INTERFACES - -include "mlir/IR/Interfaces.td" -include "mlir/Dialect/LLVMIR/LLVMInterfaces.td" - -//===----------------------------------------------------------------------===// -// AMX Intrinsic Interface -//===----------------------------------------------------------------------===// - -def AMXIntrinsicOpInterface - : OpInterface<"AMXIntrinsicOp", [OneToOneIntrinsicOpInterface]> { - let description = [{ - A wrapper interface for operations representing AMX LLVM intrinsics. - }]; - let cppNamespace = "::mlir::amx"; -} - -#endif // AMX_INTERFACES diff --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt deleted file mode 100644 index f875c78d240c..000000000000 --- a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -add_mlir_dialect(AMX amx) -add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx) - -add_mlir_interface(AMXInterfaces) -add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen) diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h deleted file mode 100644 index 7391ec2ff6b1..000000000000 --- a/mlir/include/mlir/Dialect/AMX/Transforms.h +++ /dev/null @@ -1,33 +0,0 @@ -//===- Transforms.h - AMX Dialect Transformation Entrypoints ----*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_AMX_TRANSFORMS_H -#define MLIR_DIALECT_AMX_TRANSFORMS_H - -namespace mlir { - -class LLVMConversionTarget; -class LLVMTypeConverter; -class RewritePatternSet; -class DialectRegistry; - -/// Collect a set of patterns to lower AMX ops to ops that map to LLVM -/// intrinsics. -void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns); - -/// Configure the target to support lowering AMX ops to ops that map to LLVM -/// intrinsics. -void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target); - -/// Register LLVM conversion interface for AMX dialect. -void registerConvertAMXToLLVMInterface(DialectRegistry ®istry); - -} // namespace mlir - -#endif // MLIR_DIALECT_AMX_TRANSFORMS_H diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt index ae9a18046c10..d2505877e2dd 100644 --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -1,6 +1,5 @@ add_subdirectory(Affine) add_subdirectory(AMDGPU) -add_subdirectory(AMX) add_subdirectory(Arith) add_subdirectory(ArmNeon) add_subdirectory(ArmSME) diff --git a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h index 6d1d63005662..2e76985e92e1 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h @@ -104,10 +104,6 @@ struct SparsifierOptions : public PassPipelineOptions { desc("Allows compiler to assume indices fit in 32-bit if that yields " "faster code"), init(true)}; - PassOptions::Option amx{ - *this, "enable-amx", - desc("Enables the use of AMX dialect while lowering the vector dialect"), - init(false)}; PassOptions::Option armNeon{ *this, "enable-arm-neon", desc("Enables the use of ArmNeon dialect while lowering the vector " @@ -168,7 +164,6 @@ struct SparsifierOptions : public PassPipelineOptions { opts.force32BitVectorIndices = force32BitVectorIndices; opts.armNeon = armNeon; opts.armSVE = armSVE; - opts.amx = amx; opts.x86 = x86; return opts; } diff --git a/mlir/include/mlir/Dialect/X86/Transforms.h b/mlir/include/mlir/Dialect/X86/Transforms.h index 7ab3a0b0b562..2862e83f06f7 100644 --- a/mlir/include/mlir/Dialect/X86/Transforms.h +++ b/mlir/include/mlir/Dialect/X86/Transforms.h @@ -200,13 +200,16 @@ void populateSpecializedTransposeLoweringPatterns( /// Collect a set of patterns to lower X86 ops to ops that map to LLVM /// intrinsics. -void populateX86LegalizeForLLVMExportPatterns( - const LLVMTypeConverter &converter, RewritePatternSet &patterns); +void populateX86LegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); /// Configure the target to support lowering X86 ops to ops that map to /// LLVM intrinsics. void configureX86LegalizeForExportTarget(LLVMConversionTarget &target); +/// Register LLVM conversion interface for X86 dialect. 
+void registerConvertX86ToLLVMInterface(DialectRegistry &registry);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_X86_TRANSFORMS_H
diff --git a/mlir/include/mlir/Dialect/X86/X86.td b/mlir/include/mlir/Dialect/X86/X86.td
index 8b5973985a4b..e8965d04c214 100644
--- a/mlir/include/mlir/Dialect/X86/X86.td
+++ b/mlir/include/mlir/Dialect/X86/X86.td
@@ -17,6 +17,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Dialect/X86/X86Interfaces.td"
+include "mlir/IR/BuiltinTypes.td"
 
 //===----------------------------------------------------------------------===//
 // X86 dialect definition
@@ -25,6 +26,8 @@ include "mlir/Dialect/X86/X86Interfaces.td"
 def X86_Dialect : Dialect {
   let name = "x86";
   let cppNamespace = "::mlir::x86";
+
+  let useDefaultTypePrinterParser = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -673,4 +676,385 @@ def CvtPackedOddIndexedToF32Op
       ::mlir::RewriterBase &rewriter);
   }];
 }
+
+//===----------------------------------------------------------------------===//
+// AMX Tile definition
+//===----------------------------------------------------------------------===//
+
+class AMX_Type<string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<X86_Dialect, typeMnemonic, traits> {
+  let mnemonic = "amx." # typeMnemonic;
+}
+
+def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
+  let cppFunctionName = "isValidTileTypeElementType";
+}
+
+def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
+  let summary = "AMX 2D tile to be used by AMX operations.";
+
+  let description = [{
+    This type is used to represent values in AMX tile registers. All AMX operations
+    work on AMX tiles and these tiles cannot be used in other operations directly.
+    LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
+    element type for IR verification and lowering to LLVMIR dialect.
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$shape,
+    AMX_TileTypeElementType:$elementType
+  );
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
+      return $_get(elementType.getContext(), shape, elementType);
+    }]>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns if this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Clone this tile type with the given shape and element type. If the
+    /// provided shape is `std::nullopt`, the current shape of the type is used.
+ AMXTileType cloneWith(std::optional> shape, + Type elementType) const { + return get(shape.value_or(getShape()), elementType); + } + }]; + + let hasCustomAssemblyFormat = 1; + let skipDefaultBuilders = 1; +} + +def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::x86::AMXTileType>($_self)">, + CPred<[{::llvm::cast<::mlir::x86::AMXTileType>($_self).getRank() == 2}]>]>; + +class AMXTileOf allowedTypes> : + ShapedContainerType; + +def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>; + +def AMXTileF32 : AMXTileOf<[F32]>; + +def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>; + +def AMXTileI32 : AMXTileOf<[I32]>; + +def AMXTileI8 : AMXTileOf<[I8]>; + +//===----------------------------------------------------------------------===// +// AMX Op definitions +//===----------------------------------------------------------------------===// + +class AMX_Op traits = []> + : Op { + let cppNamespace = X86_Dialect.cppNamespace # "::amx"; +} + +//===----------------------------------------------------------------------===// +// AMX Tile Zero +//===----------------------------------------------------------------------===// + +def TileZeroOp : AMX_Op<"tile_zero", [ + X86IntrinsicOpInterface, + MemoryEffects<[MemWrite]> + ]> { + let summary = "tile zero operation"; + let description = [{ + Zeroes the destination tile, with the shape defined by the 2-dim + vector type of the result. + + The operation is eventually lowered into the "tilezero" instruction + with the corresponding tile configuration. + + With the write memory effect, each `x86.amx.tile_zero` operation serves as + a compilation hint to use a separate tile register. + + Example: + + ```mlir + %0 = x86.amx.tile_zero : !x86.amx.tile<16x16xbf16> + ``` + }]; + let results = (outs AnyAMXTile:$res); + let extraClassDeclaration = [{ + AMXTileType getTileType() { + return ::llvm::cast(getRes().getType()); + } + + std::string getIntrinsicName() { + return "llvm.x86.tilezero.internal"; + } + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); + }]; + let assemblyFormat = "attr-dict `:` qualified(type($res))"; + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// AMX Tile Load +//===----------------------------------------------------------------------===// + +def TileLoadOp : AMX_Op<"tile_load", [ + X86IntrinsicOpInterface, + MemoryEffects<[MemWrite]>, + AttrSizedOperandSegments + ]> { + let summary = "tile load operation"; + let description = [{ + Loads a tile from memory defined by a `base` and `indices`, with the + shape defined by the 2-dim vector type of the result. + The tile's rows are populated by reading contiguous elements starting + at the `base`. For each tile row, the `base` is incremented by `stride` + number of elements. + + The tile is loaded using the following indexing scheme: + + ``` + for row in enumerate(tile_rows): + mem_row = base[i0, i1, ..., iN + row * stride] + for col in enumerate(tile_cols): + tile[row, col] = mem_row[col] + ``` + + If the `stride` is not provided, then the `base` buffer must be at least + 2-dimensional, and the `stride` is automatically inferred and corresponds + to the stride of the buffer's second innermost dimension. + + The operation is eventually lowered into the "tileloadd" instruction + with the corresponding tile configuration. 
+ + With the write memory effect, each `x86.amx.tile_load` operation serves as + a compilation hint to use a separate tile register. + + Example: + + ```mlir + // Tile load from a 2-D memref with implicit stride. + %0 = x86.amx.tile_load %arg0[%c0, %c0] : memref into !x86.amx.tile<16x64xi8> + + // Tile load from a 1-D memref with explicit stride. + %0 = x86.amx.tile_load %arg0[%c0], %stride : memref into !x86.amx.tile<16x64xi8> + ``` + }]; + let arguments = (ins Arg:$base, + Variadic:$indices, + Optional:$stride); + let results = (outs AnyAMXTile:$res); + let builders = [ + OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)> + ]; + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return ::llvm::cast(getBase().getType()); + } + AMXTileType getTileType() { + return ::llvm::cast(getRes().getType()); + } + + std::string getIntrinsicName() { + return "llvm.x86.tileloadd64.internal"; + } + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); + }]; + let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict" + "`:` type($base) `into` qualified(type($res))"; + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// AMX Tile Store +//===----------------------------------------------------------------------===// + +def TileStoreOp : AMX_Op<"tile_store", [ + X86IntrinsicOpInterface, + AttrSizedOperandSegments + ]> { + let summary = "tile store operation"; + let description = [{ + Stores a tile to memory defined by a `base` and `indices`, with the + shape defined by the 2-dim vector type of the value. + The tile's rows are written contiguously to the buffer starting at + the `base`. For each tile row, the `base` is incremented by `stride` + number of elements. + + The tile is stored using the following indexing scheme: + + ``` + for row in enumerate(tile_rows): + mem_row = base[i0, i1, ..., iN + row * stride] + for col in enumerate(tile_cols): + mem_row[col] = tile[row, col] + ``` + + If the `stride` is not provided, then the `base` buffer must be at least + 2-dimensional, and the `stride` is automatically inferred and corresponds + to the stride of the buffer's second innermost dimension. + + The operation is eventually lowered into the "tilestored" instruction + with the corresponding tile configuration. + + Example: + + ```mlir + // Tile store to a 2-D memref with implicit stride. + x86.amx.tile_store %arg1[%c0, %c0], %0 : memref, !x86.amx.tile<16x64xi8> + + // Tile store to a 1-D memref with explicit stride. + x86.amx.tile_store %arg1[%c0], %0, %stride : memref, !x86.amx.tile<16x64xi8> + ``` + }]; + let arguments = (ins Arg:$base, + Variadic:$indices, + AnyAMXTile:$val, + Optional:$stride); + let builders = [ + OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)> + ]; + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return ::llvm::cast(getBase().getType()); + } + AMXTileType getTileType() { + return ::llvm::cast(getVal().getType()); + } + + std::string getIntrinsicName() { + return "llvm.x86.tilestored64.internal"; + } + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); + }]; + let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?" 
+ "attr-dict `:` type($base) `,` qualified(type($val))"; + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// AMX Tile Multiply +//===----------------------------------------------------------------------===// + +def TileMulFOp : AMX_Op<"tile_mulf", [Pure, + X86IntrinsicOpInterface, + AllTypesMatch<["acc", "res"]> + ]> { + let summary = "tile multiplication operation (floating-point)"; + let description = [{ + Multiplies a "m x k" tile with a "k x n" tile and accumulates the results + into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with + pairs of "bf16"). + + The operation is eventually lowered into the "tdpbf16ps" instruction with + the corresponding tile configuration. + + Example: + + ```mlir + %0 = x86.amx.tile_mulf %a, %b, %c + : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32> + ``` + }]; + let arguments = (ins AMXTileF16OrBF16:$lhs, + AMXTileF16OrBF16:$rhs, + AMXTileF32:$acc); + let results = (outs AMXTileF32:$res); + let extraClassDeclaration = [{ + AMXTileType getLhsTileType() { + return ::llvm::cast(getLhs().getType()); + } + AMXTileType getRhsTileType() { + return ::llvm::cast(getRhs().getType()); + } + AMXTileType getTileType() { + return ::llvm::cast(getRes().getType()); + } + + std::string getIntrinsicName() { + std::string intr = "llvm.x86.tdp"; + auto elementType = + getLhsTileType().getElementType(); + intr += elementType.isF16() ? "fp16" : "bf16"; + intr += "ps.internal"; + return intr; + } + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); + }]; + let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` " + "qualified(type($lhs)) `,` qualified(type($rhs))" + " `,` qualified(type($acc)) "; + let hasVerifier = 1; +} + +def TileMulIOp : AMX_Op<"tile_muli", [Pure, + X86IntrinsicOpInterface, + AllTypesMatch<["acc", "res"]> + ]> { + let summary = "tile multiplication operation (integer)"; + let description = [{ + Multiplies a "m x k" tile with a "k x n" tile and accumulates the results + into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8" + combinations (4 bytes packed into dwords in the columns of both the + source operand tiles; the zero or sign extension is specified with + the attributes and default to sign extended). + + The operation is eventually lowered into one of the "tdpbssd", + "tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding + tile configuration. + + Example: + + ```mlir + %0 = x86.amx.tile_muli %a zext, %b zext, %c + : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32> + ``` + }]; + let arguments = (ins AMXTileI8:$lhs, + AMXTileI8:$rhs, + AMXTileI32:$acc, + UnitAttr:$isZextLhs, + UnitAttr:$isZextRhs + ); + let results = (outs AMXTileI32:$res); + let extraClassDeclaration = [{ + AMXTileType getLhsTileType() { + return ::llvm::cast(getLhs().getType()); + } + AMXTileType getRhsTileType() { + return ::llvm::cast(getRhs().getType()); + } + AMXTileType getTileType() { + return ::llvm::cast(getRes().getType()); + } + + std::string getIntrinsicName() { + std::string intr = "llvm.x86.tdpb"; + intr += getIsZextLhs() ? "u" : "s"; + intr += getIsZextRhs() ? 
"u" : "s"; + intr += "d.internal"; + return intr; + } + SmallVector getIntrinsicOperands( + ::mlir::ArrayRef operands, + const ::mlir::LLVMTypeConverter &typeConverter, + ::mlir::RewriterBase &rewriter); + }]; + let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` " + "qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) "; + let hasVerifier = 1; +} + #endif // X86_OPS diff --git a/mlir/include/mlir/Dialect/X86/X86Dialect.h b/mlir/include/mlir/Dialect/X86/X86Dialect.h index dbce51e64115..6b1358b31e66 100644 --- a/mlir/include/mlir/Dialect/X86/X86Dialect.h +++ b/mlir/include/mlir/Dialect/X86/X86Dialect.h @@ -29,6 +29,19 @@ #include "mlir/Dialect/X86/X86Dialect.h.inc" +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/X86/X86Types.h.inc" + +namespace mlir { +namespace x86 { +namespace amx { +// Alias to allow access to AMX type through nested namespaces +// analogously to AMX operations. +using TileType = mlir::x86::AMXTileType; +} // namespace amx +} // namespace x86 +} // namespace mlir + #define GET_OP_CLASSES #include "mlir/Dialect/X86/X86.h.inc" diff --git a/mlir/lib/CAPI/Dialect/AMX.cpp b/mlir/lib/CAPI/Dialect/AMX.cpp deleted file mode 100644 index ed208c9b4b72..000000000000 --- a/mlir/lib/CAPI/Dialect/AMX.cpp +++ /dev/null @@ -1,13 +0,0 @@ -//===- AMX.cpp - C Interface for AMX dialect ------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir-c/Dialect/AMX.h" -#include "mlir/CAPI/Registration.h" -#include "mlir/Dialect/AMX/AMXDialect.h" - -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(AMX, amx, mlir::amx::AMXDialect) diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 46b83b3d4f79..551f5a5a3df7 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -26,15 +26,6 @@ add_mlir_upstream_c_api_library(MLIRCAPIAMDGPU MLIRAMDGPUTransforms ) -add_mlir_upstream_c_api_library(MLIRCAPIAMX - AMX.cpp - - PARTIAL_SOURCES_INTENDED - LINK_LIBS PUBLIC - MLIRCAPIIR - MLIRAMXDialect -) - add_mlir_upstream_c_api_library(MLIRCAPIArith Arith.cpp ArithPasses.cpp diff --git a/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt index 2d4b2b6e9283..2ed864c519cb 100644 --- a/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt @@ -8,7 +8,6 @@ add_mlir_conversion_library(MLIRVectorToAMX MLIRConversionPassIncGen LINK_LIBS PUBLIC - MLIRAMXDialect MLIRAffineUtils MLIRArithDialect MLIRLinalgUtils @@ -16,4 +15,5 @@ add_mlir_conversion_library(MLIRVectorToAMX MLIRSCFDialect MLIRTransforms MLIRVectorDialect + MLIRX86Dialect ) diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp index 245a3efe98ec..bce67b3e4748 100644 --- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp +++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp @@ -1,4 +1,4 @@ -//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===// +//===- VectorToAMX.cpp - Convert vector to X86 dialect AMX ops --*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -8,7 +8,6 @@ #include "mlir/Conversion/VectorToAMX/VectorToAMX.h" -#include "mlir/Dialect/AMX/AMXDialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" @@ -16,6 +15,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/X86/X86Dialect.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -197,7 +197,7 @@ static TypedValue collapseLastDim(PatternRewriter &rewriter, static Operation * loadStoreFromTransfer(PatternRewriter &rewriter, VectorTransferOpInterface xferOp, bool isPacked, - TypedValue tileToStore = nullptr) { + TypedValue tileToStore = nullptr) { if (!xferOp || !isa(xferOp)) return nullptr; if (xferOp.hasOutOfBoundsDim() || @@ -267,18 +267,18 @@ loadStoreFromTransfer(PatternRewriter &rewriter, src = collapseLastDim(rewriter, src); int64_t rows = vecShape[0]; int64_t cols = llvm::product_of(vecShape.drop_front()); - auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType()); + auto tileType = x86::amx::TileType::get({rows, cols}, vecTy.getElementType()); Value zeroIndex = rewriter.createOrFold(loc, 0); SmallVector tileIndicides(src.getType().getRank(), zeroIndex); Operation *amxTileOp = nullptr; if (isa(xferOp)) { - amxTileOp = - amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndicides); + amxTileOp = x86::amx::TileLoadOp::create(rewriter, loc, tileType, src, + tileIndicides); } else if (isa(xferOp)) { - amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndicides, - tileToStore); + amxTileOp = x86::amx::TileStoreOp::create(rewriter, loc, src, tileIndicides, + tileToStore); } else { llvm_unreachable("unsupported vector transfer op"); } @@ -289,10 +289,10 @@ loadStoreFromTransfer(PatternRewriter &rewriter, /// Attempt to create an AMX tile load operation equivalent to the given /// vector transfer `readOp`. /// Returns loaded AMX tile if successful. -static FailureOr> +static FailureOr> loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp, bool isPacked) { - amx::TileLoadOp loadOp = dyn_cast_if_present( + x86::amx::TileLoadOp loadOp = dyn_cast_if_present( loadStoreFromTransfer(rewriter, readOp, isPacked)); if (!loadOp) return failure(); @@ -301,16 +301,16 @@ loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp, /// Attempt to create an AMX tile store operation equivalent to the given /// vector transfer `writeOp`. -static LogicalResult storeFromTransfer(PatternRewriter &rewriter, - vector::TransferWriteOp writeOp, - TypedValue tileToStore) { +static LogicalResult +storeFromTransfer(PatternRewriter &rewriter, vector::TransferWriteOp writeOp, + TypedValue tileToStore) { return success(loadStoreFromTransfer(rewriter, writeOp, /*isPacked=*/false, tileToStore)); } /// Load vector values to an AMX tile. -static TypedValue loadTile(PatternRewriter &rewriter, - TypedValue vec) { +static TypedValue loadTile(PatternRewriter &rewriter, + TypedValue vec) { Location loc = vec.getLoc(); VectorType vecTy = vec.getType(); @@ -318,7 +318,7 @@ static TypedValue loadTile(PatternRewriter &rewriter, // Try to load tile directly from vector producer's buffer. 
auto readOp = vec.getDefiningOp(); - FailureOr> tile = + FailureOr> tile = loadFromTransfer(rewriter, readOp, isPacked); if (succeeded(tile)) return *tile; @@ -337,25 +337,25 @@ static TypedValue loadTile(PatternRewriter &rewriter, ArrayRef shape = vecTy.getShape(); int64_t rows = shape[0]; int64_t cols = llvm::product_of(shape.drop_front()); - auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType()); + auto tileType = x86::amx::TileType::get({rows, cols}, vecTy.getElementType()); - return amx::TileLoadOp::create(rewriter, loc, tileType, buf, - {zeroIndex, zeroIndex}); + return x86::amx::TileLoadOp::create(rewriter, loc, tileType, buf, + {zeroIndex, zeroIndex}); } /// Store an AMX tile in a vector. static TypedValue storeTile(PatternRewriter &rewriter, - TypedValue tile) { + TypedValue tile) { Location loc = tile.getLoc(); // Transfer the tile to a vector through an intermediate buffer. - amx::TileType tileTy = tile.getType(); + x86::amx::TileType tileTy = tile.getType(); Value buf = memref::AllocaOp::create( rewriter, loc, MemRefType::get(tileTy.getShape(), tileTy.getElementType())); Value zeroIndex = rewriter.createOrFold(loc, 0); SmallVector indices(2, zeroIndex); - amx::TileStoreOp::create(rewriter, loc, buf, indices, tile); + x86::amx::TileStoreOp::create(rewriter, loc, buf, indices, tile); auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType()); return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {}); @@ -374,19 +374,21 @@ struct ContractionToAMX : public OpRewritePattern { if (failed(validateOperands(rewriter, contractOp))) return failure(); - TypedValue lhsTile = loadTile(rewriter, contractOp.getLhs()); - TypedValue rhsTile = loadTile(rewriter, contractOp.getRhs()); + TypedValue lhsTile = + loadTile(rewriter, contractOp.getLhs()); + TypedValue rhsTile = + loadTile(rewriter, contractOp.getRhs()); auto acc = dyn_cast>(contractOp.getAcc()); assert(acc && "Invalid accumulator type"); - TypedValue accTile = loadTile(rewriter, acc); + TypedValue accTile = loadTile(rewriter, acc); - TypedValue tileMul; + TypedValue tileMul; if (acc.getType().getElementType().isFloat()) { - tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(), - lhsTile, rhsTile, accTile); + tileMul = x86::amx::TileMulFOp::create(rewriter, loc, accTile.getType(), + lhsTile, rhsTile, accTile); } else { - tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(), - lhsTile, rhsTile, accTile); + tileMul = x86::amx::TileMulIOp::create(rewriter, loc, accTile.getType(), + lhsTile, rhsTile, accTile); } // If the contraction result is only written back to memory, try to replace diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt index 4b1e72788bec..0d700ea65eb4 100644 --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -38,8 +38,6 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass MLIRArmNeonTransforms MLIRArmSVEDialect MLIRArmSVETransforms - MLIRAMXDialect - MLIRAMXTransforms MLIRX86Dialect MLIRX86Transforms ) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 19c42ed7e9ed..4cc570435338 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -10,8 +10,6 @@ #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include 
"mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/AMX/AMXDialect.h" -#include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/ArmNeon/Transforms.h" @@ -51,8 +49,6 @@ struct ConvertVectorToLLVMPass registry.insert(); if (armSVE) registry.insert(); - if (amx) - registry.insert(); if (x86) registry.insert(); } @@ -136,10 +132,6 @@ void ConvertVectorToLLVMPass::runOnOperation() { configureArmSVELegalizeForExportTarget(target); populateArmSVELegalizeForLLVMExportPatterns(converter, patterns); } - if (amx) { - configureAMXLegalizeForExportTarget(target); - populateAMXLegalizeForLLVMExportPatterns(converter, patterns); - } if (x86) { configureX86LegalizeForExportTarget(target); populateX86LegalizeForLLVMExportPatterns(converter, patterns); diff --git a/mlir/lib/Dialect/AMX/CMakeLists.txt b/mlir/lib/Dialect/AMX/CMakeLists.txt deleted file mode 100644 index 9f57627c321f..000000000000 --- a/mlir/lib/Dialect/AMX/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(IR) -add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp deleted file mode 100644 index d9c097c9a3c6..000000000000 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ /dev/null @@ -1,318 +0,0 @@ -//===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the AMX dialect and its operations. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/AMX/AMXDialect.h" -#include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/TypeUtilities.h" - -#include "llvm/ADT/TypeSwitch.h" - -using namespace mlir; - -#include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc" - -#include "mlir/Dialect/AMX/AMXDialect.cpp.inc" - -void amx::AMXDialect::initialize() { - addTypes< -#define GET_TYPEDEF_LIST -#include "mlir/Dialect/AMX/AMXTypes.cpp.inc" - >(); - - addOperations< -#define GET_OP_LIST -#include "mlir/Dialect/AMX/AMX.cpp.inc" - >(); -} - -/// Verify that AMX supports the implied tile shape. -static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) { - const unsigned kMaxRows = 16; - const unsigned kBitsPerRow = 64 * 8; - unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth(); - if (tp.getDimSize(0) > kMaxRows) - return op->emitOpError("bad row height: ") << tp.getDimSize(0); - if (col > kBitsPerRow || col & 0x1f) - return op->emitOpError("bad column width: ") << (col >> 3); - return success(); -} - -/// Verify that AMX supports the multiplication. 
-static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, - amx::TileType btp, amx::TileType ctp, - unsigned scale) { - unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale; - unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale; - unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1); - if (cm != am || cn != bn || ak != bk) - return op->emitOpError("bad mult shape: ") - << cm << " x " << cn << " x " << ak; - return success(); -} - -/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first -/// dimension directly translates into the number of rows of the tiles. -/// The second dimensions needs to be scaled by the number of bytes. -static SmallVector getTileSizes(Location loc, amx::TileType tType, - RewriterBase &rewriter) { - Type llvmInt16Type = rewriter.getIntegerType(16); - unsigned width = tType.getElementType().getIntOrFloatBitWidth(); - assert(llvm::isPowerOf2_64(width) && width >= 8); - unsigned bytes = width >> 3; - auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0)); - auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes); - return SmallVector{ - LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr), - LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)}; -} - -/// Returns stride expressed in number of bytes for the given `elementStride` -/// stride encoded in number of elements of the type `mType`. -static Value computeStrideInBytes(Location loc, MemRefType mType, - Value elementStride, RewriterBase &rewriter) { - Type llvmInt64Type = rewriter.getIntegerType(64); - unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8; - auto attr = rewriter.getI64IntegerAttr(bytes); - Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr); - return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride) - .getResult(); -} - -/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer -/// shape may "envelop" the actual tile shape, and may be dynamically sized. -static Value inferStride(Location loc, MemRefType mType, Value base, - RewriterBase &rewriter) { - assert(mType.getRank() >= 2 && "Invalid shape for AMX strides"); - int64_t preLast = mType.getRank() - 2; - Type llvmInt64Type = rewriter.getIntegerType(64); - unsigned width = mType.getElementType().getIntOrFloatBitWidth(); - assert(llvm::isPowerOf2_64(width) && width >= 8); - unsigned bytes = width >> 3; - auto [strides, offset] = mType.getStridesAndOffset(); - if (strides[preLast] == ShapedType::kDynamic) { - // Dynamic stride needs code to compute the stride at runtime. - MemRefDescriptor memrefDescriptor(base); - return computeStrideInBytes( - loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter); - } - // Use direct constant for static stride. 
- auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes); - return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr) - .getResult(); -} - -LogicalResult amx::TileZeroOp::verify() { - return verifyTileSize(*this, getTileType()); -} - -SmallVector -amx::TileZeroOp::getIntrinsicOperands(ArrayRef operands, - const LLVMTypeConverter &typeConverter, - RewriterBase &rewriter) { - return getTileSizes(getLoc(), getTileType(), rewriter); -} - -template || - std::is_same_v>> -static LogicalResult tileTransferVerifier(OpTy op) { - MemRefType memrefTy = op.getMemRefType(); - unsigned rank = memrefTy.getRank(); - if (op.getIndices().size() != rank) - return op.emitOpError("requires ") << rank << " indices"; - - if (failed(verifyTileSize(op, op.getTileType()))) - return failure(); - - // Validate basic buffer properties when the stride is implicit. - if (!op.getStride()) { - if (rank < 2) - return op.emitOpError("requires at least 2D memref"); - SmallVector strides; - int64_t offset; - if (failed(memrefTy.getStridesAndOffset(strides, offset)) || - strides.back() != 1) - return op.emitOpError("requires memref with unit innermost stride"); - } - - return success(); -} - -void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res, - Value base, ValueRange indices) { - build(builder, state, res, base, indices, /*stride=*/nullptr); -} - -LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); } - -SmallVector -amx::TileLoadOp::getIntrinsicOperands(ArrayRef operands, - const LLVMTypeConverter &typeConverter, - RewriterBase &rewriter) { - auto loc = getLoc(); - Adaptor adaptor(operands, *this); - - SmallVector intrinsicOperands; - intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter)); - intrinsicOperands.push_back( - LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), - adaptor.getBase(), adaptor.getIndices())); - if (Value stride = adaptor.getStride()) - intrinsicOperands.push_back( - computeStrideInBytes(loc, getMemRefType(), stride, rewriter)); - else - intrinsicOperands.push_back( - inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); - - return intrinsicOperands; -} - -void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state, - Value base, ValueRange indices, Value val) { - build(builder, state, base, indices, val, /*stride=*/nullptr); -} - -LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); } - -SmallVector -amx::TileStoreOp::getIntrinsicOperands(ArrayRef operands, - const LLVMTypeConverter &typeConverter, - RewriterBase &rewriter) { - auto loc = getLoc(); - Adaptor adaptor(operands, *this); - - SmallVector intrinsicOperands; - intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter)); - intrinsicOperands.push_back( - LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), - adaptor.getBase(), adaptor.getIndices())); - if (Value stride = adaptor.getStride()) - intrinsicOperands.push_back( - computeStrideInBytes(loc, getMemRefType(), stride, rewriter)); - else - intrinsicOperands.push_back( - inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); - intrinsicOperands.push_back(adaptor.getVal()); - - return intrinsicOperands; -} - -LogicalResult amx::TileMulFOp::verify() { - amx::TileType aType = getLhsTileType(); - amx::TileType bType = getRhsTileType(); - amx::TileType cType = getTileType(); - if (failed(verifyTileSize(*this, aType)) || - failed(verifyTileSize(*this, bType)) || - failed(verifyTileSize(*this, cType)) 
|| - failed(verifyMultShape(*this, aType, bType, cType, 1))) - return failure(); - Type ta = aType.getElementType(); - Type tb = bType.getElementType(); - Type tc = cType.getElementType(); - if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32()) - return emitOpError("unsupported type combination"); - return success(); -} - -SmallVector -amx::TileMulFOp::getIntrinsicOperands(ArrayRef operands, - const LLVMTypeConverter &typeConverter, - RewriterBase &rewriter) { - auto loc = getLoc(); - Adaptor adaptor(operands, *this); - - amx::TileType aType = getLhsTileType(); - amx::TileType bType = getRhsTileType(); - SmallVector tsza = getTileSizes(loc, aType, rewriter); - SmallVector tszb = getTileSizes(loc, bType, rewriter); - - SmallVector intrinsicOperands = {tsza[0], tszb[1], - tsza[1], adaptor.getAcc(), - adaptor.getLhs(), adaptor.getRhs()}; - - return intrinsicOperands; -} - -LogicalResult amx::TileMulIOp::verify() { - amx::TileType aType = getLhsTileType(); - amx::TileType bType = getRhsTileType(); - amx::TileType cType = getTileType(); - if (failed(verifyTileSize(*this, aType)) || - failed(verifyTileSize(*this, bType)) || - failed(verifyTileSize(*this, cType)) || - failed(verifyMultShape(*this, aType, bType, cType, 2))) - return failure(); - Type ta = aType.getElementType(); - Type tb = bType.getElementType(); - Type tc = cType.getElementType(); - if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32)) - return emitOpError("unsupported type combination"); - return success(); -} - -SmallVector -amx::TileMulIOp::getIntrinsicOperands(ArrayRef operands, - const LLVMTypeConverter &typeConverter, - RewriterBase &rewriter) { - auto loc = getLoc(); - Adaptor adaptor(operands, *this); - - amx::TileType aType = getLhsTileType(); - amx::TileType bType = getRhsTileType(); - SmallVector tsza = getTileSizes(loc, aType, rewriter); - SmallVector tszb = getTileSizes(loc, bType, rewriter); - - SmallVector intrinsicOperands = {tsza[0], tszb[1], - tsza[1], adaptor.getAcc(), - adaptor.getLhs(), adaptor.getRhs()}; - - return intrinsicOperands; -} - -Type amx::TileType::parse(AsmParser &parser) { - if (parser.parseLess()) - return nullptr; - - SmallVector shape; - if (parser.parseDimensionList(shape, false, true)) - return nullptr; - - Type elementType; - if (parser.parseType(elementType)) - return nullptr; - - if (parser.parseGreater()) - return nullptr; - - return TileType::getChecked( - [&] { return parser.emitError(parser.getNameLoc()); }, shape, - elementType); -} - -void amx::TileType::print(AsmPrinter &os) const { - os << "<"; - os.printDimensionList(getShape()); - os << 'x'; - os.printType(getElementType()); - os << '>'; -} - -#define GET_OP_CLASSES -#include "mlir/Dialect/AMX/AMX.cpp.inc" - -#define GET_TYPEDEF_CLASSES -#include "mlir/Dialect/AMX/AMXTypes.cpp.inc" diff --git a/mlir/lib/Dialect/AMX/IR/CMakeLists.txt b/mlir/lib/Dialect/AMX/IR/CMakeLists.txt deleted file mode 100644 index b6e2759843d5..000000000000 --- a/mlir/lib/Dialect/AMX/IR/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -add_mlir_dialect_library(MLIRAMXDialect - AMXDialect.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMX - - DEPENDS - MLIRAMXIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRLLVMCommonConversion - MLIRLLVMDialect - MLIRSideEffectInterfaces - ) diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt deleted file mode 100644 index e827bc475e93..000000000000 --- a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ 
-add_mlir_dialect_library(MLIRAMXTransforms - LegalizeForLLVMExport.cpp - - LINK_LIBS PUBLIC - MLIRAMXDialect - MLIRIR - MLIRLLVMCommonConversion - MLIRLLVMDialect - ) diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp deleted file mode 100644 index 6483af222e91..000000000000 --- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp +++ /dev/null @@ -1,70 +0,0 @@ -//===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/AMX/Transforms.h" - -#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" -#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" -#include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Dialect/AMX/AMXDialect.h" -#include "mlir/IR/PatternMatch.h" - -using namespace mlir; -using namespace mlir::amx; - -namespace { - -/// Generic one-to-one conversion of simply mappable operations into calls -/// to their respective LLVM intrinsics. -struct AMXIntrinsicOpConversion - : public ConvertOpInterfaceToLLVMPattern { - using ConvertOpInterfaceToLLVMPattern::ConvertOpInterfaceToLLVMPattern; - - LogicalResult - matchAndRewrite(amx::AMXIntrinsicOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - const LLVMTypeConverter &typeConverter = *getTypeConverter(); - return LLVM::detail::intrinsicRewrite( - op, rewriter.getStringAttr(op.getIntrinsicName()), - op.getIntrinsicOperands(operands, typeConverter, rewriter), - typeConverter, rewriter); - } -}; - -} // namespace - -void mlir::populateAMXLegalizeForLLVMExportPatterns( - LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add(converter); - converter.addConversion([&](amx::TileType type) { - return LLVM::LLVMX86AMXType::get(&converter.getContext()); - }); -} - -void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) { - target.addIllegalDialect(); -} - -namespace { -/// Implement the interface to convert AMX to LLVM. 
-struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface { - using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; - - void populateConvertToLLVMConversionPatterns( - ConversionTarget &target, LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns) const final { - populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns); - } -}; -} // namespace - -void mlir::registerConvertAMXToLLVMInterface(DialectRegistry ®istry) { - registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) { - dialect->addInterfaces(); - }); -} diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index 65dada6ac4bf..66f68c369f81 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -1,6 +1,5 @@ add_subdirectory(Affine) add_subdirectory(AMDGPU) -add_subdirectory(AMX) add_subdirectory(Arith) add_subdirectory(ArmNeon) add_subdirectory(ArmSME) diff --git a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp index e1714bdc8dc1..47ee5d272a89 100644 --- a/mlir/lib/Dialect/X86/IR/X86Dialect.cpp +++ b/mlir/lib/Dialect/X86/IR/X86Dialect.cpp @@ -11,10 +11,16 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/X86/X86Dialect.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/TypeSwitch.h" + using namespace mlir; #include "mlir/Dialect/X86/X86Interfaces.cpp.inc" @@ -22,6 +28,11 @@ using namespace mlir; #include "mlir/Dialect/X86/X86Dialect.cpp.inc" void x86::X86Dialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/X86/X86Types.cpp.inc" + >(); + addOperations< #define GET_OP_LIST #include "mlir/Dialect/X86/X86.cpp.inc" @@ -107,5 +118,279 @@ SmallVector x86::CvtPackedOddIndexedToF32Op::getIntrinsicOperands( typeConverter, rewriter)}; } +/// Verify that AMX supports the implied tile shape. +static LogicalResult verifyTileSize(Operation *op, x86::amx::TileType tp) { + const unsigned kMaxRows = 16; + const unsigned kBitsPerRow = 64 * 8; + unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth(); + if (tp.getDimSize(0) > kMaxRows) + return op->emitOpError("bad row height: ") << tp.getDimSize(0); + if (col > kBitsPerRow || col & 0x1f) + return op->emitOpError("bad column width: ") << (col >> 3); + return success(); +} + +/// Verify that AMX supports the multiplication. +static LogicalResult verifyMultShape(Operation *op, x86::amx::TileType atp, + x86::amx::TileType btp, + x86::amx::TileType ctp, unsigned scale) { + unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale; + unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale; + unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1); + if (cm != am || cn != bn || ak != bk) + return op->emitOpError("bad mult shape: ") + << cm << " x " << cn << " x " << ak; + return success(); +} + +/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first +/// dimension directly translates into the number of rows of the tiles. +/// The second dimensions needs to be scaled by the number of bytes. 
+static SmallVector getTileSizes(Location loc, x86::amx::TileType tType, + RewriterBase &rewriter) { + Type llvmInt16Type = rewriter.getIntegerType(16); + unsigned width = tType.getElementType().getIntOrFloatBitWidth(); + assert(llvm::isPowerOf2_64(width) && width >= 8); + unsigned bytes = width >> 3; + auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0)); + auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes); + return SmallVector{ + LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr), + LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)}; +} + +/// Returns stride expressed in number of bytes for the given `elementStride` +/// stride encoded in number of elements of the type `mType`. +static Value computeStrideInBytes(Location loc, MemRefType mType, + Value elementStride, RewriterBase &rewriter) { + Type llvmInt64Type = rewriter.getIntegerType(64); + unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8; + auto attr = rewriter.getI64IntegerAttr(bytes); + Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr); + return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride) + .getResult(); +} + +/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer +/// shape may "envelop" the actual tile shape, and may be dynamically sized. +static Value inferStride(Location loc, MemRefType mType, Value base, + RewriterBase &rewriter) { + assert(mType.getRank() >= 2 && "Invalid shape for AMX strides"); + int64_t preLast = mType.getRank() - 2; + Type llvmInt64Type = rewriter.getIntegerType(64); + unsigned width = mType.getElementType().getIntOrFloatBitWidth(); + assert(llvm::isPowerOf2_64(width) && width >= 8); + unsigned bytes = width >> 3; + auto [strides, offset] = mType.getStridesAndOffset(); + if (strides[preLast] == ShapedType::kDynamic) { + // Dynamic stride needs code to compute the stride at runtime. + MemRefDescriptor memrefDescriptor(base); + return computeStrideInBytes( + loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter); + } + // Use direct constant for static stride. + auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes); + return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr) + .getResult(); +} + +LogicalResult x86::amx::TileZeroOp::verify() { + return verifyTileSize(*this, getTileType()); +} + +SmallVector x86::amx::TileZeroOp::getIntrinsicOperands( + ArrayRef operands, const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + return getTileSizes(getLoc(), getTileType(), rewriter); +} + +template || + std::is_same_v>> +static LogicalResult tileTransferVerifier(OpTy op) { + MemRefType memrefTy = op.getMemRefType(); + unsigned rank = memrefTy.getRank(); + if (op.getIndices().size() != rank) + return op.emitOpError("requires ") << rank << " indices"; + + if (failed(verifyTileSize(op, op.getTileType()))) + return failure(); + + // Validate basic buffer properties when the stride is implicit. 
+ if (!op.getStride()) { + if (rank < 2) + return op.emitOpError("requires at least 2D memref"); + SmallVector strides; + int64_t offset; + if (failed(memrefTy.getStridesAndOffset(strides, offset)) || + strides.back() != 1) + return op.emitOpError("requires memref with unit innermost stride"); + } + + return success(); +} + +void x86::amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, + Type res, Value base, ValueRange indices) { + build(builder, state, res, base, indices, /*stride=*/nullptr); +} + +LogicalResult x86::amx::TileLoadOp::verify() { + return tileTransferVerifier(*this); +} + +SmallVector x86::amx::TileLoadOp::getIntrinsicOperands( + ArrayRef operands, const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + auto loc = getLoc(); + Adaptor adaptor(operands, *this); + + SmallVector intrinsicOperands; + intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter)); + intrinsicOperands.push_back( + LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), + adaptor.getBase(), adaptor.getIndices())); + if (Value stride = adaptor.getStride()) + intrinsicOperands.push_back( + computeStrideInBytes(loc, getMemRefType(), stride, rewriter)); + else + intrinsicOperands.push_back( + inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); + + return intrinsicOperands; +} + +void x86::amx::TileStoreOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange indices, Value val) { + build(builder, state, base, indices, val, /*stride=*/nullptr); +} + +LogicalResult x86::amx::TileStoreOp::verify() { + return tileTransferVerifier(*this); +} + +SmallVector x86::amx::TileStoreOp::getIntrinsicOperands( + ArrayRef operands, const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + auto loc = getLoc(); + Adaptor adaptor(operands, *this); + + SmallVector intrinsicOperands; + intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter)); + intrinsicOperands.push_back( + LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), + adaptor.getBase(), adaptor.getIndices())); + if (Value stride = adaptor.getStride()) + intrinsicOperands.push_back( + computeStrideInBytes(loc, getMemRefType(), stride, rewriter)); + else + intrinsicOperands.push_back( + inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); + intrinsicOperands.push_back(adaptor.getVal()); + + return intrinsicOperands; +} + +LogicalResult x86::amx::TileMulFOp::verify() { + x86::amx::TileType aType = getLhsTileType(); + x86::amx::TileType bType = getRhsTileType(); + x86::amx::TileType cType = getTileType(); + if (failed(verifyTileSize(*this, aType)) || + failed(verifyTileSize(*this, bType)) || + failed(verifyTileSize(*this, cType)) || + failed(verifyMultShape(*this, aType, bType, cType, 1))) + return failure(); + Type ta = aType.getElementType(); + Type tb = bType.getElementType(); + Type tc = cType.getElementType(); + if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32()) + return emitOpError("unsupported type combination"); + return success(); +} + +SmallVector x86::amx::TileMulFOp::getIntrinsicOperands( + ArrayRef operands, const LLVMTypeConverter &typeConverter, + RewriterBase &rewriter) { + auto loc = getLoc(); + Adaptor adaptor(operands, *this); + + x86::amx::TileType aType = getLhsTileType(); + x86::amx::TileType bType = getRhsTileType(); + SmallVector tsza = getTileSizes(loc, aType, rewriter); + SmallVector tszb = getTileSizes(loc, bType, rewriter); + + SmallVector intrinsicOperands = {tsza[0], tszb[1], + 
tsza[1], adaptor.getAcc(),
+                                   adaptor.getLhs(), adaptor.getRhs()};
+
+  return intrinsicOperands;
+}
+
+LogicalResult x86::amx::TileMulIOp::verify() {
+  x86::amx::TileType aType = getLhsTileType();
+  x86::amx::TileType bType = getRhsTileType();
+  x86::amx::TileType cType = getTileType();
+  if (failed(verifyTileSize(*this, aType)) ||
+      failed(verifyTileSize(*this, bType)) ||
+      failed(verifyTileSize(*this, cType)) ||
+      failed(verifyMultShape(*this, aType, bType, cType, 2)))
+    return failure();
+  Type ta = aType.getElementType();
+  Type tb = bType.getElementType();
+  Type tc = cType.getElementType();
+  if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
+    return emitOpError("unsupported type combination");
+  return success();
+}
+
+SmallVector x86::amx::TileMulIOp::getIntrinsicOperands(
+    ArrayRef operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  auto loc = getLoc();
+  Adaptor adaptor(operands, *this);
+
+  x86::amx::TileType aType = getLhsTileType();
+  x86::amx::TileType bType = getRhsTileType();
+  SmallVector tsza = getTileSizes(loc, aType, rewriter);
+  SmallVector tszb = getTileSizes(loc, bType, rewriter);
+
+  SmallVector intrinsicOperands = {tsza[0], tszb[1],
+                                   tsza[1], adaptor.getAcc(),
+                                   adaptor.getLhs(), adaptor.getRhs()};
+
+  return intrinsicOperands;
+}
+
+Type x86::amx::TileType::parse(AsmParser &parser) {
+  if (parser.parseLess())
+    return nullptr;
+
+  SmallVector shape;
+  if (parser.parseDimensionList(shape, false, true))
+    return nullptr;
+
+  Type elementType;
+  if (parser.parseType(elementType))
+    return nullptr;
+
+  if (parser.parseGreater())
+    return nullptr;
+
+  return TileType::getChecked(
+      [&] { return parser.emitError(parser.getNameLoc()); }, shape,
+      elementType);
+}
+
+void x86::amx::TileType::print(AsmPrinter &os) const {
+  os << "<";
+  os.printDimensionList(getShape());
+  os << 'x';
+  os.printType(getElementType());
+  os << '>';
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/X86/X86.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/X86/X86Types.cpp.inc"
diff --git a/mlir/lib/Dialect/X86/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86/Transforms/LegalizeForLLVMExport.cpp
index c07559dc295f..8907b5f482e9 100644
--- a/mlir/lib/Dialect/X86/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86/Transforms/LegalizeForLLVMExport.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/X86/Transforms.h"
 
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/X86/X86Dialect.h"
@@ -39,10 +40,32 @@ struct X86IntrinsicOpConversion
 
 /// Populate the given list with patterns that convert from X86 to LLVM.
 void mlir::populateX86LegalizeForLLVMExportPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add(converter);
+  converter.addConversion([&](x86::amx::TileType type) {
+    return LLVM::LLVMX86AMXType::get(&converter.getContext());
+  });
 }
 
 void mlir::configureX86LegalizeForExportTarget(LLVMConversionTarget &target) {
   target.addIllegalDialect();
 }
+
+namespace {
+/// Implement the interface to convert X86 to LLVM.
+struct X86ToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + + void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + populateX86LegalizeForLLVMExportPatterns(typeConverter, patterns); + } +}; +} // namespace + +void mlir::registerConvertX86ToLLVMInterface(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, x86::X86Dialect *dialect) { + dialect->addInterfaces(); + }); +} diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp index 10944f72aa3c..ea5698f39c0b 100644 --- a/mlir/lib/RegisterAllDialects.cpp +++ b/mlir/lib/RegisterAllDialects.cpp @@ -14,7 +14,6 @@ #include "mlir/InitAllDialects.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" -#include "mlir/Dialect/AMX/AMXDialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -111,7 +110,6 @@ void mlir::registerAllDialects(DialectRegistry ®istry) { registry.insert, %B: vector<8x16x2xf16>, // CHECK: vector.transfer_write %[[A]], %[[A_BUF]] // CHECK: %[[A_BUF_2D:.+]] = memref.collapse_shape %[[A_BUF]] // CHECK-SAME: {{\[}}[0], [1, 2]] : memref<4x8x2xf16> into memref<4x16xf16> -// CHECK: %[[A_TILE:.+]] = amx.tile_load %[[A_BUF_2D]] +// CHECK: %[[A_TILE:.+]] = x86.amx.tile_load %[[A_BUF_2D]] /// Load B vector into an AMX tile // CHECK: %[[B_BUF:.+]] = memref.alloca() : memref<8x16x2xf16> // CHECK: vector.transfer_write %[[B]], %[[B_BUF]] // CHECK: %[[B_BUF_2D:.+]] = memref.collapse_shape %[[B_BUF]] // CHECK-SAME: {{\[}}[0], [1, 2]] : memref<8x16x2xf16> into memref<8x32xf16> -// CHECK: %[[B_TILE:.+]] = amx.tile_load %[[B_BUF_2D]] +// CHECK: %[[B_TILE:.+]] = x86.amx.tile_load %[[B_BUF_2D]] /// Load C vector into an AMX tile // CHECK: %[[C_BUF:.+]] = memref.alloca() : memref<4x16xf32> // CHECK: vector.transfer_write %[[C]], %[[C_BUF]] -// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_BUF]] +// CHECK: %[[C_TILE:.+]] = x86.amx.tile_load %[[C_BUF]] /// Perform tile multiplication -// CHECK: %[[RES:.+]] = amx.tile_mulf +// CHECK: %[[RES:.+]] = x86.amx.tile_mulf // CHECK-SAME: %[[A_TILE]], %[[B_TILE]], %[[C_TILE]] /// Load the result back into a vector // CHECK: %[[RES_BUF:.+]] = memref.alloca() : memref<4x16xf32> -// CHECK: amx.tile_store %[[RES_BUF]]{{.*}}, %[[RES]] +// CHECK: x86.amx.tile_store %[[RES_BUF]]{{.*}}, %[[RES]] // CHECK: %[[RES_VEC:.+]] = vector.transfer_read %[[RES_BUF]] // CHECK: return %[[RES_VEC]] @@ -75,9 +75,9 @@ func.func @contract_vnni_bf16(%A: vector<4x8x2xbf16>, %B: vector<8x16x2xbf16>, } // CHECK-LABEL: @contract_vnni_bf16( -// CHECK-COUNT-3: amx.tile_load -// CHECK: amx.tile_mulf -// CHECK: amx.tile_store +// CHECK-COUNT-3: x86.amx.tile_load +// CHECK: x86.amx.tile_mulf +// CHECK: x86.amx.tile_store // ----- @@ -95,9 +95,9 @@ func.func @contract_vnni_i8(%A: vector<4x16x4xi8>, %B: vector<16x8x4xi8>, } // CHECK-LABEL: @contract_vnni_i8( -// CHECK-COUNT-3: amx.tile_load -// CHECK: amx.tile_muli -// CHECK: amx.tile_store +// CHECK-COUNT-3: x86.amx.tile_load +// CHECK: x86.amx.tile_muli +// CHECK: x86.amx.tile_store // ----- @@ -115,9 +115,9 @@ func.func @contract_shuffled_iterators(%A: vector<4x16x4xi8>, %B: vector<16x8x4x } // CHECK-LABEL: @contract_shuffled_iterators( -// CHECK-COUNT-3: amx.tile_load -// CHECK: amx.tile_muli -// CHECK: amx.tile_store +// CHECK-COUNT-3: x86.amx.tile_load +// 
CHECK: x86.amx.tile_muli +// CHECK: x86.amx.tile_store // ----- diff --git a/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir b/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir index 8fab4cf1f7ed..120f13bd5d87 100644 --- a/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir +++ b/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir @@ -38,7 +38,7 @@ func.func @transfers_static_dims(%A: memref<64x32x16x2xf16>, // CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}} // CHECK: %[[A_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[A_SUBVIEW]] // CHECK-SAME: {{\[}}[0], [1], [2, 3]] : memref<1x4x8x2xf16{{.*}}into memref<1x4x16xf16 -// CHECK: %[[A_TILE:.+]] = amx.tile_load %[[A_PACKED_DIM_COLLAPSE]] +// CHECK: %[[A_TILE:.+]] = x86.amx.tile_load %[[A_PACKED_DIM_COLLAPSE]] // CHECK-SAME: {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}} // CHECK-NOT: vector.transfer_read %[[A]] @@ -47,25 +47,25 @@ func.func @transfers_static_dims(%A: memref<64x32x16x2xf16>, // CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}} // CHECK: %[[B_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[B_SUBVIEW]] // CHECK-SAME: {{\[}}[0], [1], [2, 3]] : memref<1x8x16x2xf16{{.*}}into memref<1x8x32xf16 -// CHECK: %[[B_TILE:.+]] = amx.tile_load %[[B_PACKED_DIM_COLLAPSE]] +// CHECK: %[[B_TILE:.+]] = x86.amx.tile_load %[[B_PACKED_DIM_COLLAPSE]] // CHECK-SAME: {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}} // CHECK-NOT: vector.transfer_read %[[B]] /// Load C into an AMX tile // CHECK: %[[C_SUBVIEW:.+]] = memref.subview %[[C]] // CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]]{{\]}} -// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]] +// CHECK: %[[C_TILE:.+]] = x86.amx.tile_load %[[C_SUBVIEW]] // CHECK-SAME: {{\[}}%[[C0]], %[[C0]]{{\]}} // CHECK-NOT: vector.transfer_read %[[C]] /// Perform tile multiplication -// CHECK: %[[RES:.+]] = amx.tile_mulf +// CHECK: %[[RES:.+]] = x86.amx.tile_mulf // CHECK-SAME: %[[A_TILE]], %[[B_TILE]], %[[C_TILE]] /// Store the result back // CHECK: %[[RES_SUBVIEW:.+]] = memref.subview %[[C]] // CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]]{{\]}} -// CHECK: amx.tile_store %[[RES_SUBVIEW]]{{\[}}%[[C0]], %[[C0]]{{\]}}, %[[RES]] +// CHECK: x86.amx.tile_store %[[RES_SUBVIEW]]{{\[}}%[[C0]], %[[C0]]{{\]}}, %[[RES]] // CHECK-NOT: vector.transfer_write{{.*}}%[[C]] // ----- @@ -130,17 +130,17 @@ func.func @transfer_read_multiple_users(%C: memref<64x64xf32>, /// Load to AMX tile directly from buffer. // CHECK: %[[C_SUBVIEW:.+]] = memref.subview %[[C]] -// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]] +// CHECK: %[[C_TILE:.+]] = x86.amx.tile_load %[[C_SUBVIEW]] /// Vector read remains to load data for the other non-AMX consumer. // CHECK: %[[C_VEC:.+]] = vector.transfer_read %[[C]] /// Contraction uses the directly loaded tile. -// CHECK: %[[TILE_MUL:.+]] = amx.tile_mulf{{.*}}%[[C_TILE]] +// CHECK: %[[TILE_MUL:.+]] = x86.amx.tile_mulf{{.*}}%[[C_TILE]] /// Consumer uses original C value and the updated one after contraction. 
// CHECK: %[[RES_BUF:.+]] = memref.alloca -// CHECK: amx.tile_store %[[RES_BUF]] +// CHECK: x86.amx.tile_store %[[RES_BUF]] // CHECK: %[[RES_VEC:.+]] = vector.transfer_read %[[RES_BUF]] // CHECK: %[[VEC_MUL:.+]] = arith.mulf %[[C_VEC]], %[[RES_VEC]] @@ -168,7 +168,7 @@ func.func @negative_contract_multiple_users(%C: memref<64x64xf32>, // CHECK-LABEL: @negative_contract_multiple_users( // CHECK-SAME: %[[C:.+]]: memref<64x64xf32> -// CHECK: %[[TILE_MUL:.+]] = amx.tile_mulf +// CHECK: %[[TILE_MUL:.+]] = x86.amx.tile_mulf // CHECK: vector.transfer_write{{.*}}%[[C]] // ----- diff --git a/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir b/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir index 8070aee19f94..e457e318b078 100644 --- a/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir +++ b/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir @@ -18,7 +18,6 @@ // CHECK: builtin.module( // CHECK-SAME: convert-vector-to-llvm{ -// CHECK-SAME: enable-amx={{[aA-zZ0-9]+}} // CHECK-SAME: enable-arm-neon={{[aA-zZ0-9]+}} // CHECK-SAME: enable-arm-sve={{[aA-zZ0-9]+}} // CHECK-SAME: enable-x86={{[aA-zZ0-9]+}} diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir deleted file mode 100644 index 5de9b3f82a86..000000000000 --- a/mlir/test/Dialect/AMX/invalid.mlir +++ /dev/null @@ -1,158 +0,0 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics - -func.func @tile_row_height() { - // expected-error@+1 {{'amx.tile_zero' op bad row height: 17}} - %0 = amx.tile_zero : !amx.tile<17x16xbf16> - return -} - -// ----- - -func.func @tile_col_width() { - // expected-error@+1 {{'amx.tile_zero' op bad column width: 65}} - %0 = amx.tile_zero : !amx.tile<16x65xi8> - return -} - -// ----- - -func.func @tile_element_type() { - // expected-error@+1 {{failed to verify 'elementType'}} - %0 = amx.tile_zero : !amx.tile<8x8xi16> - return -} - -// ----- - -func.func @tile_rank() { - // expected-error@+1 {{'amx.tile_zero' op result #0 must be tile of}} - %0 = amx.tile_zero : !amx.tile<32xi8> - return -} - -// ----- - -func.func @tile_col_4_byte_multiple() { - // expected-error@+1 {{'amx.tile_zero' op bad column width: 5}} - %0 = amx.tile_zero : !amx.tile<16x5xi8> - return -} - -// ----- - -func.func @load_base_tile_size(%arg0: memref) { - %0 = arith.constant 0 : index - // expected-error@+1 {{'amx.tile_load' op bad column width: 68}} - %1 = amx.tile_load %arg0[%0, %0] : memref into !amx.tile<16x17xf32> - return -} - -// ----- - -func.func @store_base_tile_size(%arg0: memref, %arg1: !amx.tile<16x17xf32>) { - %0 = arith.constant 0 : index - // expected-error@+1 {{'amx.tile_store' op bad column width: 68}} - amx.tile_store %arg0[%0, %0], %arg1 : memref, !amx.tile<16x17xf32> - return -} - -// ----- - -func.func @load_base_index_size(%arg0: memref) { - %0 = arith.constant 0 : index - // expected-error@+1 {{'amx.tile_load' op requires 2 indices}} - %1 = amx.tile_load %arg0[%0] : memref into !amx.tile<16x16xf32> - return -} - -// ----- - -func.func @store_base_index_size(%arg0: memref, %arg1: !amx.tile<16x16xf32>) { - %0 = arith.constant 0 : index - // expected-error@+1 {{'amx.tile_store' op requires 2 indices}} - amx.tile_store %arg0[%0], %arg1 : memref, !amx.tile<16x16xf32> - return -} - -// ----- - -func.func @load_base_rank(%arg0: memref) { - %0 = arith.constant 0 : index - // expected-error@+1 {{'amx.tile_load' op requires at least 2D memref}} - %1 = amx.tile_load %arg0[%0] : memref into !amx.tile<16x16xf32> - return -} - -// ----- - -func.func 
@store_base_rank(%arg0: memref, %arg1: !amx.tile<16x16xf32>) { - %0 = arith.constant 0 : index - // expected-error@+1 {{'amx.tile_store' op requires at least 2D memref}} - amx.tile_store %arg0[%0], %arg1 : memref, !amx.tile<16x16xf32> - return -} - -// ----- - -func.func @load_base_non_unit_stride(%arg0: memref>) { - %0 = arith.constant 0 : index - // expected-error@+1 {{'amx.tile_load' op requires memref with unit innermost stride}} - %1 = amx.tile_load %arg0[%0, %0] - : memref> into !amx.tile<16x16xf32> - return -} - -// ----- - -func.func @store_base_non_unit_stride(%arg0: memref>, - %arg1: !amx.tile<16x16xf32>) { - %0 = arith.constant 0 : index - // expected-error@+1 {{'amx.tile_store' op requires memref with unit innermost stride}} - amx.tile_store %arg0[%0, %0], %arg1 - : memref>, !amx.tile<16x16xf32> - return -} - -// ----- - -func.func @mulf_shape() { - %0 = amx.tile_zero : !amx.tile<8x8xbf16> - %1 = amx.tile_zero : !amx.tile<8x8xbf16> - %2 = amx.tile_zero : !amx.tile<4x4xf32> - // expected-error@+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}} - %3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<8x8xbf16>, !amx.tile<4x4xf32> - return -} - -// ----- - -func.func @mulf_type_combination() { - %0 = amx.tile_zero : !amx.tile<8x8xbf16> - %1 = amx.tile_zero : !amx.tile<4x8xf16> - %2 = amx.tile_zero : !amx.tile<8x4xf32> - // expected-error@+1 {{'amx.tile_mulf' op unsupported type combination}} - %3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<4x8xf16>, !amx.tile<8x4xf32> - return -} - -// ----- - -func.func @muli_shape() { - %0 = amx.tile_zero : !amx.tile<8x8xi8> - %1 = amx.tile_zero : !amx.tile<8x8xi8> - %2 = amx.tile_zero : !amx.tile<4x4xi32> - // expected-error@+1 {{'amx.tile_muli' op bad mult shape: 4 x 4 x 2}} - %3 = amx.tile_muli %0, %1, %2 : !amx.tile<8x8xi8>, !amx.tile<8x8xi8>, !amx.tile<4x4xi32> - return -} - -// ----- - -func.func @muli_type_combination() { - %0 = amx.tile_zero : !amx.tile<8x16xi8> - %1 = amx.tile_zero : !amx.tile<8x16xi32> - %2 = amx.tile_zero : !amx.tile<2x2xi32> - // expected-error@+1 {{'amx.tile_muli' op operand #1 must be tile of 8-bit signless integer values}} - %3 = amx.tile_muli %0, %1, %2 : !amx.tile<8x16xi8>, !amx.tile<8x16xi32>, !amx.tile<2x2xi32> - return -} diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir deleted file mode 100644 index 3d0f276df6a2..000000000000 --- a/mlir/test/Dialect/AMX/roundtrip.mlir +++ /dev/null @@ -1,77 +0,0 @@ -// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s - -// CHECK-LABEL: tloadstore -// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}], %{{.*}} : -// CHECK-SAME: memref into !amx.tile<16x32xbf16> -// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : -// CHECK-SAME: memref into !amx.tile<16x32xbf16> -// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : -// CHECK-SAME: memref> into !amx.tile<16x32xbf16> -// CHECK: amx.tile_store %{{.*}}[%{{.*}}], %[[z]], %{{.*}} : -// CHECK-SAME: memref, !amx.tile<16x32xbf16> -// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[x]], %{{.*}} : -// CHECK-SAME: memref, !amx.tile<16x32xbf16> -// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[y]] : -// CHECK-SAME: memref>, !amx.tile<16x32xbf16> -func.func @tloadstore(%stride: index, - %arg0: memref, - %arg1: memref, - %arg2: memref>) { - %0 = arith.constant 0 : index - %c64 = arith.constant 64 : index - %1 = amx.tile_load %arg0[%0], %stride : memref into !amx.tile<16x32xbf16> - %2 = amx.tile_load %arg1[%0, %0], %stride 
: memref into !amx.tile<16x32xbf16> - %3 = amx.tile_load %arg2[%0, %0] : memref> into !amx.tile<16x32xbf16> - amx.tile_store %arg0[%0], %3, %stride : memref, !amx.tile<16x32xbf16> - amx.tile_store %arg1[%0, %0], %1, %stride : memref, !amx.tile<16x32xbf16> - amx.tile_store %arg2[%0, %0], %2 : memref>, !amx.tile<16x32xbf16> - return -} - -// CHECK-LABEL: tzero -// CHECK: amx.tile_zero : !amx.tile<16x16xbf16> -// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref, !amx.tile<16x16xbf16> -func.func @tzero(%arg0: memref) { - %0 = arith.constant 0 : index - %1 = amx.tile_zero : !amx.tile<16x16xbf16> - amx.tile_store %arg0[%0, %0], %1 : memref, !amx.tile<16x16xbf16> - return -} - -// CHECK-LABEL: tmulf -// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into !amx.tile<16x32xbf16> -// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into !amx.tile<16x16xf32> -// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> -// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref, !amx.tile<16x16xf32> -func.func @tmulf(%arg0: memref, %arg1: memref) { - %0 = arith.constant 0 : index - %1 = amx.tile_load %arg0[%0, %0] : memref into !amx.tile<16x32xbf16> - %2 = amx.tile_load %arg1[%0, %0] : memref into !amx.tile<16x16xf32> - %3 = amx.tile_mulf %1, %1, %2 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> - amx.tile_store %arg1[%0, %0], %3 : memref, !amx.tile<16x16xf32> - return -} - -// CHECK-LABEL: tmuli -// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into !amx.tile<16x64xi8> -// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into !amx.tile<16x64xi8> -// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into !amx.tile<16x16xi32> -// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> -// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref, !amx.tile<16x16xi32> -// Verify the parsing/printing of the sign-extension annotation. -// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}} -// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}} -// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}} -func.func @tmuli(%arg0: memref, %arg1: memref, %arg2: memref) { - %0 = arith.constant 0 : index - %1 = amx.tile_load %arg0[%0, %0] : memref into !amx.tile<16x64xi8> - %2 = amx.tile_load %arg1[%0, %0] : memref into !amx.tile<16x64xi8> - %3 = amx.tile_load %arg2[%0, %0] : memref into !amx.tile<16x16xi32> - %4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> - amx.tile_store %arg2[%0, %0], %4 : memref, !amx.tile<16x16xi32> - // Verify the various `zext` combinations. 
- %5 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> - %6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> - %7 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> - return -} diff --git a/mlir/test/Dialect/AMX/side-effects.mlir b/mlir/test/Dialect/AMX/side-effects.mlir deleted file mode 100644 index 22c76d98c699..000000000000 --- a/mlir/test/Dialect/AMX/side-effects.mlir +++ /dev/null @@ -1,32 +0,0 @@ -// RUN: mlir-opt %s -cse -convert-vector-to-llvm="enable-amx" | FileCheck %s - -// With inclusion of memory side-effects, it is expected CSE not to fold multiple -// "tileload" and "tilezero". -// CHECK-LABEL: do_not_fold_tiles( -// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal" -// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal" -// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" -// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" -// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" -// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" -func.func @do_not_fold_tiles(%arg0: memref<2x32x32xbf16>, %arg1: memref<2x16x32xbf16>) -> memref<16x32xf32> { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c2 = arith.constant 2 : index - %c16 = arith.constant 16 : index - %alloca = memref.alloca() : memref<16x32xf32> - %0 = amx.tile_zero : !amx.tile<16x16xf32> - %1 = amx.tile_zero : !amx.tile<16x16xf32> - %2:2 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %0, %arg4 = %1) -> (!amx.tile<16x16xf32>, !amx.tile<16x16xf32>) { - %3 = amx.tile_load %arg0[%arg2, %c0, %c0] : memref<2x32x32xbf16> into !amx.tile<16x32xbf16> - %4 = amx.tile_load %arg0[%arg2, %c16, %c0] : memref<2x32x32xbf16> into !amx.tile<16x32xbf16> - %5 = amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !amx.tile<16x32xbf16> - %6 = amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !amx.tile<16x32xbf16> - %7 = amx.tile_mulf %3, %5, %arg3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> - %8 = amx.tile_mulf %4, %6, %arg4 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> - scf.yield %7, %8 : !amx.tile<16x16xf32>, !amx.tile<16x16xf32> - } - amx.tile_store %alloca[%c0, %c0], %2#0 : memref<16x32xf32>, !amx.tile<16x16xf32> - amx.tile_store %alloca[%c0, %c16], %2#1 : memref<16x32xf32>, !amx.tile<16x16xf32> - return %alloca : memref<16x32xf32> -} diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index bb868afe08cb..6d6ebe0b5e60 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -635,10 +635,10 @@ func.func @invalid_indexing_maps_placement_matmul(%lhs: tensor<4x1xf32>, %rhs: t // ----- -func.func @invalid_type_matmul(%arg0 : !amx.tile<16x16xbf16>) +func.func @invalid_type_matmul(%arg0 : !x86.amx.tile<16x16xbf16>) { - // expected-error @below {{custom op 'linalg.matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}} - %0 = linalg.matmul ins(%arg0, %arg0 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%arg0 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16> + // expected-error @below {{custom op 'linalg.matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!x86.amx.tile<16x16xbf16>' and '!x86.amx.tile<16x16xbf16>'}} + %0 = 
linalg.matmul ins(%arg0, %arg0 : !x86.amx.tile<16x16xbf16>, !x86.amx.tile<16x16xbf16>) outs(%arg0 : !x86.amx.tile<16x16xbf16>) -> !x86.amx.tile<16x16xbf16> return } @@ -1582,10 +1582,10 @@ func.func @invalid_C_map_result_dim_batch_matmul(%arg0: memref, %arg1 // ----- -func.func @invalid_type_batch_matmul(%arg0 : !amx.tile<16x16xbf16>) +func.func @invalid_type_batch_matmul(%arg0 : !x86.amx.tile<16x16xbf16>) { - // expected-error @below {{custom op 'linalg.batch_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}} - %0 = linalg.batch_matmul ins(%arg0, %arg0 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%arg0 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16> + // expected-error @below {{custom op 'linalg.batch_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!x86.amx.tile<16x16xbf16>' and '!x86.amx.tile<16x16xbf16>'}} + %0 = linalg.batch_matmul ins(%arg0, %arg0 : !x86.amx.tile<16x16xbf16>, !x86.amx.tile<16x16xbf16>) outs(%arg0 : !x86.amx.tile<16x16xbf16>) -> !x86.amx.tile<16x16xbf16> return } @@ -1790,10 +1790,10 @@ func.func @invalid_C_map_result_dim(%A: memref, %B: memref // ----- -func.func @batch_reduce_matmul_invalid_type(%arg0 : !amx.tile<16x16xbf16>) +func.func @batch_reduce_matmul_invalid_type(%arg0 : !x86.amx.tile<16x16xbf16>) { - // expected-error @below {{custom op 'linalg.batch_reduce_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!amx.tile<16x16xbf16>' and '!amx.tile<16x16xbf16>'}} - %0 = linalg.batch_reduce_matmul ins(%arg0, %arg0 : !amx.tile<16x16xbf16>, !amx.tile<16x16xbf16>) outs(%arg0 : !amx.tile<16x16xbf16>) -> !amx.tile<16x16xbf16> + // expected-error @below {{custom op 'linalg.batch_reduce_matmul' Cannot build binary Linalg operation: expects allComplex, allFloatingPoint, or allInteger, got '!x86.amx.tile<16x16xbf16>' and '!x86.amx.tile<16x16xbf16>'}} + %0 = linalg.batch_reduce_matmul ins(%arg0, %arg0 : !x86.amx.tile<16x16xbf16>, !x86.amx.tile<16x16xbf16>) outs(%arg0 : !x86.amx.tile<16x16xbf16>) -> !x86.amx.tile<16x16xbf16> return } diff --git a/mlir/test/Dialect/X86/AMX/invalid.mlir b/mlir/test/Dialect/X86/AMX/invalid.mlir new file mode 100644 index 000000000000..25033090808b --- /dev/null +++ b/mlir/test/Dialect/X86/AMX/invalid.mlir @@ -0,0 +1,158 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +func.func @tile_row_height() { + // expected-error@+1 {{'x86.amx.tile_zero' op bad row height: 17}} + %0 = x86.amx.tile_zero : !x86.amx.tile<17x16xbf16> + return +} + +// ----- + +func.func @tile_col_width() { + // expected-error@+1 {{'x86.amx.tile_zero' op bad column width: 65}} + %0 = x86.amx.tile_zero : !x86.amx.tile<16x65xi8> + return +} + +// ----- + +func.func @tile_element_type() { + // expected-error@+1 {{failed to verify 'elementType'}} + %0 = x86.amx.tile_zero : !x86.amx.tile<8x8xi16> + return +} + +// ----- + +func.func @tile_rank() { + // expected-error@+1 {{'x86.amx.tile_zero' op result #0 must be tile of}} + %0 = x86.amx.tile_zero : !x86.amx.tile<32xi8> + return +} + +// ----- + +func.func @tile_col_4_byte_multiple() { + // expected-error@+1 {{'x86.amx.tile_zero' op bad column width: 5}} + %0 = x86.amx.tile_zero : !x86.amx.tile<16x5xi8> + return +} + +// ----- + +func.func @load_base_tile_size(%arg0: memref) { + %0 = arith.constant 0 : index + // expected-error@+1 {{'x86.amx.tile_load' op bad column width: 68}} + %1 = 
x86.amx.tile_load %arg0[%0, %0] : memref into !x86.amx.tile<16x17xf32> + return +} + +// ----- + +func.func @store_base_tile_size(%arg0: memref, %arg1: !x86.amx.tile<16x17xf32>) { + %0 = arith.constant 0 : index + // expected-error@+1 {{'x86.amx.tile_store' op bad column width: 68}} + x86.amx.tile_store %arg0[%0, %0], %arg1 : memref, !x86.amx.tile<16x17xf32> + return +} + +// ----- + +func.func @load_base_index_size(%arg0: memref) { + %0 = arith.constant 0 : index + // expected-error@+1 {{'x86.amx.tile_load' op requires 2 indices}} + %1 = x86.amx.tile_load %arg0[%0] : memref into !x86.amx.tile<16x16xf32> + return +} + +// ----- + +func.func @store_base_index_size(%arg0: memref, %arg1: !x86.amx.tile<16x16xf32>) { + %0 = arith.constant 0 : index + // expected-error@+1 {{'x86.amx.tile_store' op requires 2 indices}} + x86.amx.tile_store %arg0[%0], %arg1 : memref, !x86.amx.tile<16x16xf32> + return +} + +// ----- + +func.func @load_base_rank(%arg0: memref) { + %0 = arith.constant 0 : index + // expected-error@+1 {{'x86.amx.tile_load' op requires at least 2D memref}} + %1 = x86.amx.tile_load %arg0[%0] : memref into !x86.amx.tile<16x16xf32> + return +} + +// ----- + +func.func @store_base_rank(%arg0: memref, %arg1: !x86.amx.tile<16x16xf32>) { + %0 = arith.constant 0 : index + // expected-error@+1 {{'x86.amx.tile_store' op requires at least 2D memref}} + x86.amx.tile_store %arg0[%0], %arg1 : memref, !x86.amx.tile<16x16xf32> + return +} + +// ----- + +func.func @load_base_non_unit_stride(%arg0: memref>) { + %0 = arith.constant 0 : index + // expected-error@+1 {{'x86.amx.tile_load' op requires memref with unit innermost stride}} + %1 = x86.amx.tile_load %arg0[%0, %0] + : memref> into !x86.amx.tile<16x16xf32> + return +} + +// ----- + +func.func @store_base_non_unit_stride(%arg0: memref>, + %arg1: !x86.amx.tile<16x16xf32>) { + %0 = arith.constant 0 : index + // expected-error@+1 {{'x86.amx.tile_store' op requires memref with unit innermost stride}} + x86.amx.tile_store %arg0[%0, %0], %arg1 + : memref>, !x86.amx.tile<16x16xf32> + return +} + +// ----- + +func.func @mulf_shape() { + %0 = x86.amx.tile_zero : !x86.amx.tile<8x8xbf16> + %1 = x86.amx.tile_zero : !x86.amx.tile<8x8xbf16> + %2 = x86.amx.tile_zero : !x86.amx.tile<4x4xf32> + // expected-error@+1 {{'x86.amx.tile_mulf' op bad mult shape: 4 x 4 x 4}} + %3 = x86.amx.tile_mulf %0, %1, %2 : !x86.amx.tile<8x8xbf16>, !x86.amx.tile<8x8xbf16>, !x86.amx.tile<4x4xf32> + return +} + +// ----- + +func.func @mulf_type_combination() { + %0 = x86.amx.tile_zero : !x86.amx.tile<8x8xbf16> + %1 = x86.amx.tile_zero : !x86.amx.tile<4x8xf16> + %2 = x86.amx.tile_zero : !x86.amx.tile<8x4xf32> + // expected-error@+1 {{'x86.amx.tile_mulf' op unsupported type combination}} + %3 = x86.amx.tile_mulf %0, %1, %2 : !x86.amx.tile<8x8xbf16>, !x86.amx.tile<4x8xf16>, !x86.amx.tile<8x4xf32> + return +} + +// ----- + +func.func @muli_shape() { + %0 = x86.amx.tile_zero : !x86.amx.tile<8x8xi8> + %1 = x86.amx.tile_zero : !x86.amx.tile<8x8xi8> + %2 = x86.amx.tile_zero : !x86.amx.tile<4x4xi32> + // expected-error@+1 {{'x86.amx.tile_muli' op bad mult shape: 4 x 4 x 2}} + %3 = x86.amx.tile_muli %0, %1, %2 : !x86.amx.tile<8x8xi8>, !x86.amx.tile<8x8xi8>, !x86.amx.tile<4x4xi32> + return +} + +// ----- + +func.func @muli_type_combination() { + %0 = x86.amx.tile_zero : !x86.amx.tile<8x16xi8> + %1 = x86.amx.tile_zero : !x86.amx.tile<8x16xi32> + %2 = x86.amx.tile_zero : !x86.amx.tile<2x2xi32> + // expected-error@+1 {{'x86.amx.tile_muli' op operand #1 must be tile of 8-bit signless integer values}} + 
%3 = x86.amx.tile_muli %0, %1, %2 : !x86.amx.tile<8x16xi8>, !x86.amx.tile<8x16xi32>, !x86.amx.tile<2x2xi32> + return +} diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir similarity index 64% rename from mlir/test/Dialect/AMX/legalize-for-llvm.mlir rename to mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir index a109f42e9dea..eb12e20b699b 100644 --- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-vector-to-llvm="enable-amx" | mlir-opt | FileCheck %s +// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86" | mlir-opt | FileCheck %s // CHECK-LABEL: muli( // CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal" @@ -14,17 +14,17 @@ // CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal" func.func @muli(%arg0: memref, %arg1: memref) { %0 = arith.constant 0 : index - %1 = amx.tile_zero : !amx.tile<16x64xi8> - %2 = amx.tile_load %arg0[%0, %0] : memref into !amx.tile<16x64xi8> - %3 = amx.tile_load %arg1[%0, %0] : memref into !amx.tile<16x16xi32> - %4 = amx.tile_muli %1 zext, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> - amx.tile_store %arg1[%0, %0], %4 : memref, !amx.tile<16x16xi32> - %5 = amx.tile_muli %1, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> - amx.tile_store %arg1[%0, %0], %5 : memref, !amx.tile<16x16xi32> - %6 = amx.tile_muli %1 zext, %2, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> - amx.tile_store %arg1[%0, %0], %6 : memref, !amx.tile<16x16xi32> - %7 = amx.tile_muli %1, %2 zext, %3 : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> - amx.tile_store %arg1[%0, %0], %7 : memref, !amx.tile<16x16xi32> + %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xi8> + %2 = x86.amx.tile_load %arg0[%0, %0] : memref into !x86.amx.tile<16x64xi8> + %3 = x86.amx.tile_load %arg1[%0, %0] : memref into !x86.amx.tile<16x16xi32> + %4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32> + x86.amx.tile_store %arg1[%0, %0], %4 : memref, !x86.amx.tile<16x16xi32> + %5 = x86.amx.tile_muli %1, %2, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32> + x86.amx.tile_store %arg1[%0, %0], %5 : memref, !x86.amx.tile<16x16xi32> + %6 = x86.amx.tile_muli %1 zext, %2, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32> + x86.amx.tile_store %arg1[%0, %0], %6 : memref, !x86.amx.tile<16x16xi32> + %7 = x86.amx.tile_muli %1, %2 zext, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32> + x86.amx.tile_store %arg1[%0, %0], %7 : memref, !x86.amx.tile<16x16xi32> return } @@ -36,11 +36,11 @@ func.func @muli(%arg0: memref, %arg1: memref) { // CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal" func.func @mulbf16(%arg0: memref, %arg1: memref) { %0 = arith.constant 0 : index - %1 = amx.tile_zero : !amx.tile<16x32xbf16> - %2 = amx.tile_load %arg0[%0, %0] : memref into !amx.tile<16x32xbf16> - %3 = amx.tile_load %arg1[%0, %0] : memref into !amx.tile<16x16xf32> - %4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> - amx.tile_store %arg1[%0, %0], %4 : memref, !amx.tile<16x16xf32> + %1 = x86.amx.tile_zero : !x86.amx.tile<16x32xbf16> + %2 = x86.amx.tile_load %arg0[%0, %0] : memref into !x86.amx.tile<16x32xbf16> + %3 = x86.amx.tile_load %arg1[%0, %0] : memref into !x86.amx.tile<16x16xf32> + %4 
= x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32> + x86.amx.tile_store %arg1[%0, %0], %4 : memref, !x86.amx.tile<16x16xf32> return } @@ -52,11 +52,11 @@ func.func @mulbf16(%arg0: memref, %arg1: memref) { // CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal" func.func @mulfp16(%arg0: memref, %arg1: memref) { %0 = arith.constant 0 : index - %1 = amx.tile_zero : !amx.tile<16x32xf16> - %2 = amx.tile_load %arg0[%0, %0] : memref into !amx.tile<16x32xf16> - %3 = amx.tile_load %arg1[%0, %0] : memref into !amx.tile<16x16xf32> - %4 = amx.tile_mulf %1, %2, %3 : !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32> - amx.tile_store %arg1[%0, %0], %4 : memref, !amx.tile<16x16xf32> + %1 = x86.amx.tile_zero : !x86.amx.tile<16x32xf16> + %2 = x86.amx.tile_load %arg0[%0, %0] : memref into !x86.amx.tile<16x32xf16> + %3 = x86.amx.tile_load %arg1[%0, %0] : memref into !x86.amx.tile<16x16xf32> + %4 = x86.amx.tile_mulf %1, %2, %3 : !x86.amx.tile<16x32xf16>, !x86.amx.tile<16x32xf16>, !x86.amx.tile<16x16xf32> + x86.amx.tile_store %arg1[%0, %0], %4 : memref, !x86.amx.tile<16x16xf32> return } @@ -84,12 +84,12 @@ func.func @strides_implicit(%arg0: memref<16x32xi8>, %arg1: memref<32x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xf32, strided<[?, 1]>>) { %0 = arith.constant 0 : index - %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xi8> into !amx.tile<16x32xi8> - %2 = amx.tile_load %arg1[%0, %0] : memref<32x32xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16> - %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xf32, strided<[?, 1]>> into !amx.tile<16x16xf32> - amx.tile_store %arg0[%0, %0], %1 : memref<16x32xi8>, !amx.tile<16x32xi8> - amx.tile_store %arg1[%0, %0], %2 : memref<32x32xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16> - amx.tile_store %arg2[%0, %0], %3 : memref<16x32xf32, strided<[?, 1]>>, !amx.tile<16x16xf32> + %1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x32xi8> into !x86.amx.tile<16x32xi8> + %2 = x86.amx.tile_load %arg1[%0, %0] : memref<32x32xbf16, strided<[64, 1]>> into !x86.amx.tile<16x32xbf16> + %3 = x86.amx.tile_load %arg2[%0, %0] : memref<16x32xf32, strided<[?, 1]>> into !x86.amx.tile<16x16xf32> + x86.amx.tile_store %arg0[%0, %0], %1 : memref<16x32xi8>, !x86.amx.tile<16x32xi8> + x86.amx.tile_store %arg1[%0, %0], %2 : memref<32x32xbf16, strided<[64, 1]>>, !x86.amx.tile<16x32xbf16> + x86.amx.tile_store %arg2[%0, %0], %3 : memref<16x32xf32, strided<[?, 1]>>, !x86.amx.tile<16x16xf32> return } @@ -123,11 +123,11 @@ func.func @strides_explicit(%stride: index, %arg2: memref<32x32xf32, strided<[64, 1]>>) { %0 = arith.constant 0 : index %c64 = arith.constant 64 : index - %1 = amx.tile_load %arg0[%0], %stride : memref into !amx.tile<16x32xi8> - %2 = amx.tile_load %arg1[%0, %0], %stride : memref<16x32xbf16> into !amx.tile<16x32xbf16> - %3 = amx.tile_load %arg2[%0, %0], %c64 : memref<32x32xf32, strided<[64, 1]>> into !amx.tile<16x16xf32> - amx.tile_store %arg0[%0], %1, %stride : memref, !amx.tile<16x32xi8> - amx.tile_store %arg1[%0, %0], %2, %stride : memref<16x32xbf16>, !amx.tile<16x32xbf16> - amx.tile_store %arg2[%0, %0], %3, %c64 : memref<32x32xf32, strided<[64, 1]>>, !amx.tile<16x16xf32> + %1 = x86.amx.tile_load %arg0[%0], %stride : memref into !x86.amx.tile<16x32xi8> + %2 = x86.amx.tile_load %arg1[%0, %0], %stride : memref<16x32xbf16> into !x86.amx.tile<16x32xbf16> + %3 = x86.amx.tile_load %arg2[%0, %0], %c64 : memref<32x32xf32, strided<[64, 1]>> into !x86.amx.tile<16x16xf32> + x86.amx.tile_store %arg0[%0], %1, %stride : 
memref, !x86.amx.tile<16x32xi8> + x86.amx.tile_store %arg1[%0, %0], %2, %stride : memref<16x32xbf16>, !x86.amx.tile<16x32xbf16> + x86.amx.tile_store %arg2[%0, %0], %3, %c64 : memref<32x32xf32, strided<[64, 1]>>, !x86.amx.tile<16x16xf32> return } diff --git a/mlir/test/Dialect/X86/AMX/roundtrip.mlir b/mlir/test/Dialect/X86/AMX/roundtrip.mlir new file mode 100644 index 000000000000..300c3aa054a7 --- /dev/null +++ b/mlir/test/Dialect/X86/AMX/roundtrip.mlir @@ -0,0 +1,77 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s + +// CHECK-LABEL: tloadstore +// CHECK: %[[x:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}], %{{.*}} : +// CHECK-SAME: memref into !x86.amx.tile<16x32xbf16> +// CHECK: %[[y:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : +// CHECK-SAME: memref into !x86.amx.tile<16x32xbf16> +// CHECK: %[[z:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : +// CHECK-SAME: memref> into !x86.amx.tile<16x32xbf16> +// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}], %[[z]], %{{.*}} : +// CHECK-SAME: memref, !x86.amx.tile<16x32xbf16> +// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[x]], %{{.*}} : +// CHECK-SAME: memref, !x86.amx.tile<16x32xbf16> +// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[y]] : +// CHECK-SAME: memref>, !x86.amx.tile<16x32xbf16> +func.func @tloadstore(%stride: index, + %arg0: memref, + %arg1: memref, + %arg2: memref>) { + %0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %1 = x86.amx.tile_load %arg0[%0], %stride : memref into !x86.amx.tile<16x32xbf16> + %2 = x86.amx.tile_load %arg1[%0, %0], %stride : memref into !x86.amx.tile<16x32xbf16> + %3 = x86.amx.tile_load %arg2[%0, %0] : memref> into !x86.amx.tile<16x32xbf16> + x86.amx.tile_store %arg0[%0], %3, %stride : memref, !x86.amx.tile<16x32xbf16> + x86.amx.tile_store %arg1[%0, %0], %1, %stride : memref, !x86.amx.tile<16x32xbf16> + x86.amx.tile_store %arg2[%0, %0], %2 : memref>, !x86.amx.tile<16x32xbf16> + return +} + +// CHECK-LABEL: tzero +// CHECK: x86.amx.tile_zero : !x86.amx.tile<16x16xbf16> +// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref, !x86.amx.tile<16x16xbf16> +func.func @tzero(%arg0: memref) { + %0 = arith.constant 0 : index + %1 = x86.amx.tile_zero : !x86.amx.tile<16x16xbf16> + x86.amx.tile_store %arg0[%0, %0], %1 : memref, !x86.amx.tile<16x16xbf16> + return +} + +// CHECK-LABEL: tmulf +// CHECK: %[[x:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into !x86.amx.tile<16x32xbf16> +// CHECK: %[[z:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into !x86.amx.tile<16x16xf32> +// CHECK: %[[m:.*]] = x86.amx.tile_mulf %[[x]], %[[x]], %[[z]] : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32> +// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref, !x86.amx.tile<16x16xf32> +func.func @tmulf(%arg0: memref, %arg1: memref) { + %0 = arith.constant 0 : index + %1 = x86.amx.tile_load %arg0[%0, %0] : memref into !x86.amx.tile<16x32xbf16> + %2 = x86.amx.tile_load %arg1[%0, %0] : memref into !x86.amx.tile<16x16xf32> + %3 = x86.amx.tile_mulf %1, %1, %2 : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32> + x86.amx.tile_store %arg1[%0, %0], %3 : memref, !x86.amx.tile<16x16xf32> + return +} + +// CHECK-LABEL: tmuli +// CHECK: %[[x:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into !x86.amx.tile<16x64xi8> +// CHECK: %[[y:.*]] = x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into !x86.amx.tile<16x64xi8> +// CHECK: %[[z:.*]] = 
x86.amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into !x86.amx.tile<16x16xi32>
+// CHECK: %[[m:.*]] = x86.amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+// CHECK: x86.amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref, !x86.amx.tile<16x16xi32>
+// Verify the parsing/printing of the sign-extension annotation.
+// CHECK: x86.amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}}
+// CHECK: x86.amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}}
+// CHECK: x86.amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}}
+func.func @tmuli(%arg0: memref, %arg1: memref, %arg2: memref) {
+  %0 = arith.constant 0 : index
+  %1 = x86.amx.tile_load %arg0[%0, %0] : memref into !x86.amx.tile<16x64xi8>
+  %2 = x86.amx.tile_load %arg1[%0, %0] : memref into !x86.amx.tile<16x64xi8>
+  %3 = x86.amx.tile_load %arg2[%0, %0] : memref into !x86.amx.tile<16x16xi32>
+  %4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+  x86.amx.tile_store %arg2[%0, %0], %4 : memref, !x86.amx.tile<16x16xi32>
+  // Verify the various `zext` combinations.
+  %5 = x86.amx.tile_muli %1, %2 zext, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+  %6 = x86.amx.tile_muli %1 zext, %2, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+  %7 = x86.amx.tile_muli %1, %2, %3 : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32>
+  return
+}
diff --git a/mlir/test/Dialect/X86/AMX/side-effects.mlir b/mlir/test/Dialect/X86/AMX/side-effects.mlir
new file mode 100644
index 000000000000..fa475f34068e
--- /dev/null
+++ b/mlir/test/Dialect/X86/AMX/side-effects.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -cse -convert-vector-to-llvm="enable-x86" | FileCheck %s
+
+// With the inclusion of memory side-effects, CSE is expected not to fold multiple
+// "tileload" and "tilezero".
+// CHECK-LABEL: do_not_fold_tiles( +// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tilezero.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal" +func.func @do_not_fold_tiles(%arg0: memref<2x32x32xbf16>, %arg1: memref<2x16x32xbf16>) -> memref<16x32xf32> { + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %alloca = memref.alloca() : memref<16x32xf32> + %0 = x86.amx.tile_zero : !x86.amx.tile<16x16xf32> + %1 = x86.amx.tile_zero : !x86.amx.tile<16x16xf32> + %2:2 = scf.for %arg2 = %c0 to %c2 step %c1 iter_args(%arg3 = %0, %arg4 = %1) -> (!x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32>) { + %3 = x86.amx.tile_load %arg0[%arg2, %c0, %c0] : memref<2x32x32xbf16> into !x86.amx.tile<16x32xbf16> + %4 = x86.amx.tile_load %arg0[%arg2, %c16, %c0] : memref<2x32x32xbf16> into !x86.amx.tile<16x32xbf16> + %5 = x86.amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !x86.amx.tile<16x32xbf16> + %6 = x86.amx.tile_load %arg1[%arg2, %c0, %c0] : memref<2x16x32xbf16> into !x86.amx.tile<16x32xbf16> + %7 = x86.amx.tile_mulf %3, %5, %arg3 : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32> + %8 = x86.amx.tile_mulf %4, %6, %arg4 : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32> + scf.yield %7, %8 : !x86.amx.tile<16x16xf32>, !x86.amx.tile<16x16xf32> + } + x86.amx.tile_store %alloca[%c0, %c0], %2#0 : memref<16x32xf32>, !x86.amx.tile<16x16xf32> + x86.amx.tile_store %alloca[%c0, %c16], %2#1 : memref<16x32xf32>, !x86.amx.tile<16x16xf32> + return %alloca : memref<16x32xf32> +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/lit.local.cfg similarity index 91% rename from mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg rename to mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/lit.local.cfg index 70b4b66f4378..df9057d8933c 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/lit.local.cfg @@ -1,7 +1,7 @@ import sys # AMX tests must be enabled via build flag. -if not config.mlir_run_amx_tests: +if not config.mlir_run_x86_amx_tests: config.unsupported = True # No JIT on win32. 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/mulf-full.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf-full.mlir similarity index 95% rename from mlir/test/Integration/Dialect/Vector/CPU/AMX/mulf-full.mlir rename to mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf-full.mlir index 8014bb7d2dcc..bd67a50cffb2 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/mulf-full.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf-full.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s -convert-vector-to-scf -lower-affine \ // RUN: -one-shot-bufferize="bufferize-function-boundaries" \ -// RUN: -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" \ +// RUN: -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" \ // RUN: -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ // RUN: mlir-translate -mlir-to-llvmir | \ // RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" \ @@ -14,11 +14,11 @@ func.func @kernel(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16>, %arg2: memref<16x16xf32>) { %0 = arith.constant 0 : index - %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16> - %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16> - %3 = amx.tile_zero : vector<16x16xf32> - %4 = amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> - amx.tile_store %arg2[%0, %0], %4 : memref<16x16xf32>, vector<16x16xf32> + %1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16> + %2 = x86.amx.tile_load %arg1[%0, %0] : memref<16x32xbf16> into vector<16x32xbf16> + %3 = x86.amx.tile_zero : vector<16x16xf32> + %4 = x86.amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> + x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x16xf32>, vector<16x16xf32> return } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/mulf.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf.mlir similarity index 74% rename from mlir/test/Integration/Dialect/Vector/CPU/AMX/mulf.mlir rename to mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf.mlir index 5f7250f4d4cc..f1ff2bdf902f 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/mulf.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ +// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ // RUN: mlir-translate -mlir-to-llvmir | \ // RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \ // RUN: FileCheck %s @@ -10,11 +10,11 @@ func.func @kernel1(%arg0: memref<2x4xbf16>, %arg1: memref<2x4xbf16>, %arg2: memref<2x2xf32>) { %0 = arith.constant 0 : index - %1 = amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16> - %2 = amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16> - %3 = amx.tile_zero : vector<2x2xf32> - %4 = amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32> - amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32> + %1 = x86.amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16> + %2 = x86.amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16> 
+ %3 = x86.amx.tile_zero : vector<2x2xf32> + %4 = x86.amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32> + x86.amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32> return } @@ -23,11 +23,11 @@ func.func @kernel2(%arg0: memref<2x4xbf16>, %arg1: memref<2x4xbf16>, %arg2: memref<2x2xf32>) { %0 = arith.constant 0 : index - %1 = amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16> - %2 = amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16> - %3 = amx.tile_load %arg2[%0, %0] : memref<2x2xf32> into vector<2x2xf32> - %4 = amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32> - amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32> + %1 = x86.amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16> + %2 = x86.amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16> + %3 = x86.amx.tile_load %arg2[%0, %0] : memref<2x2xf32> into vector<2x2xf32> + %4 = x86.amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32> + x86.amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32> return } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/muli-ext.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-ext.mlir similarity index 83% rename from mlir/test/Integration/Dialect/Vector/CPU/AMX/muli-ext.mlir rename to mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-ext.mlir index 5c0618c2e5e5..c572cff40f28 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/muli-ext.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-ext.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ +// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ // RUN: mlir-translate -mlir-to-llvmir | \ // RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \ // RUN: FileCheck %s @@ -21,11 +21,11 @@ func.func @kernel1(%arg0: memref<16x16xi8>, %arg1: memref<4x16xi8>, %arg2: memref<16x4xi32>) { %0 = arith.constant 0 : index - %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8> - %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8> - %3 = amx.tile_zero : vector<16x4xi32> - %4 = amx.tile_muli %1, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32> - amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32> + %1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8> + %2 = x86.amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8> + %3 = x86.amx.tile_zero : vector<16x4xi32> + %4 = x86.amx.tile_muli %1, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32> + x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32> return } @@ -33,11 +33,11 @@ func.func @kernel2(%arg0: memref<16x16xi8>, %arg1: memref<4x16xi8>, %arg2: memref<16x4xi32>) { %0 = arith.constant 0 : index - %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8> - %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8> - %3 = amx.tile_zero : vector<16x4xi32> - %4 = amx.tile_muli %1, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32> - amx.tile_store %arg2[%0, %0], %4 : 
memref<16x4xi32>, vector<16x4xi32> + %1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8> + %2 = x86.amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8> + %3 = x86.amx.tile_zero : vector<16x4xi32> + %4 = x86.amx.tile_muli %1, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32> + x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32> return } @@ -45,11 +45,11 @@ func.func @kernel3(%arg0: memref<16x16xi8>, %arg1: memref<4x16xi8>, %arg2: memref<16x4xi32>) { %0 = arith.constant 0 : index - %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8> - %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8> - %3 = amx.tile_zero : vector<16x4xi32> - %4 = amx.tile_muli %1 zext, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32> - amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32> + %1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8> + %2 = x86.amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8> + %3 = x86.amx.tile_zero : vector<16x4xi32> + %4 = x86.amx.tile_muli %1 zext, %2, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32> + x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32> return } @@ -57,11 +57,11 @@ func.func @kernel4(%arg0: memref<16x16xi8>, %arg1: memref<4x16xi8>, %arg2: memref<16x4xi32>) { %0 = arith.constant 0 : index - %1 = amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8> - %2 = amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8> - %3 = amx.tile_zero : vector<16x4xi32> - %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32> - amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32> + %1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x16xi8> into vector<16x16xi8> + %2 = x86.amx.tile_load %arg1[%0, %0] : memref<4x16xi8> into vector<4x16xi8> + %3 = x86.amx.tile_zero : vector<16x4xi32> + %4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : vector<16x16xi8>, vector<4x16xi8>, vector<16x4xi32> + x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x4xi32>, vector<16x4xi32> return } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/muli-full.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-full.mlir similarity index 95% rename from mlir/test/Integration/Dialect/Vector/CPU/AMX/muli-full.mlir rename to mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-full.mlir index a0076db6660d..7208389f4cbf 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/muli-full.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-full.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -convert-vector-to-scf -lower-affine \ // RUN: -one-shot-bufferize="bufferize-function-boundaries" \ // RUN: -convert-scf-to-cf \ -// RUN: -convert-vector-to-llvm="enable-amx" \ +// RUN: -convert-vector-to-llvm="enable-x86" \ // RUN: -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ // RUN: mlir-translate -mlir-to-llvmir | \ // RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" \ @@ -15,11 +15,11 @@ func.func @kernel(%arg0: memref<16x64xi8>, %arg1: memref<16x64xi8>, %arg2: memref<16x16xi32>) { %0 = arith.constant 0 : index - %1 = amx.tile_load %arg0[%0, %0] : memref<16x64xi8> into vector<16x64xi8> - %2 = amx.tile_load %arg1[%0, %0] : memref<16x64xi8> into vector<16x64xi8> - %3 = amx.tile_zero : vector<16x16xi32> - %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, 
vector<16x64xi8>, vector<16x16xi32> - amx.tile_store %arg2[%0, %0], %4 : memref<16x16xi32>, vector<16x16xi32> + %1 = x86.amx.tile_load %arg0[%0, %0] : memref<16x64xi8> into vector<16x64xi8> + %2 = x86.amx.tile_load %arg1[%0, %0] : memref<16x64xi8> into vector<16x64xi8> + %3 = x86.amx.tile_zero : vector<16x16xi32> + %4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + x86.amx.tile_store %arg2[%0, %0], %4 : memref<16x16xi32>, vector<16x16xi32> return } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/muli.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli.mlir similarity index 74% rename from mlir/test/Integration/Dialect/Vector/CPU/AMX/muli.mlir rename to mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli.mlir index 7b14df8dbd85..cd0b84c3a188 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/muli.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ +// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ // RUN: mlir-translate -mlir-to-llvmir | \ // RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \ // RUN: FileCheck %s @@ -10,11 +10,11 @@ func.func @kernel1(%arg0: memref<2x8xi8>, %arg1: memref<2x8xi8>, %arg2: memref<2x2xi32>) { %0 = arith.constant 0 : index - %1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8> - %2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8> - %3 = amx.tile_zero : vector<2x2xi32> - %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32> - amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32> + %1 = x86.amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8> + %2 = x86.amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8> + %3 = x86.amx.tile_zero : vector<2x2xi32> + %4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32> + x86.amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32> return } @@ -23,11 +23,11 @@ func.func @kernel2(%arg0: memref<2x8xi8>, %arg1: memref<2x8xi8>, %arg2: memref<2x2xi32>) { %0 = arith.constant 0 : index - %1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8> - %2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8> - %3 = amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32> - %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32> - amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32> + %1 = x86.amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8> + %2 = x86.amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8> + %3 = x86.amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32> + %4 = x86.amx.tile_muli %1 zext, %2 zext, %3 : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32> + x86.amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32> return } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/tilezero-block.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero-block.mlir similarity index 94% rename from 
mlir/test/Integration/Dialect/Vector/CPU/AMX/tilezero-block.mlir rename to mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero-block.mlir index e35c555f0a85..e6676c441124 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/tilezero-block.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero-block.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ +// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ // RUN: mlir-translate -mlir-to-llvmir | \ // RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \ // RUN: FileCheck %s @@ -25,8 +25,8 @@ func.func @kernel(%arg0: memref<4x32xf32>) { %c32 = arith.constant 32 : index scf.for %i = %c0 to %c4 step %c2 { scf.for %j = %c0 to %c32 step %c16 { - %0 = amx.tile_zero : vector<2x16xf32> - amx.tile_store %arg0[%i, %j], %0 : memref<4x32xf32>, vector<2x16xf32> + %0 = x86.amx.tile_zero : vector<2x16xf32> + x86.amx.tile_store %arg0[%i, %j], %0 : memref<4x32xf32>, vector<2x16xf32> func.call @print(%arg0) : (memref<4x32xf32>) -> () } } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/tilezero.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero.mlir similarity index 96% rename from mlir/test/Integration/Dialect/Vector/CPU/AMX/tilezero.mlir rename to mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero.mlir index 37db0333e3f5..09ae8f1a9514 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/AMX/tilezero.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ +// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-cf -convert-vector-to-llvm="enable-x86" -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ // RUN: mlir-translate -mlir-to-llvmir | \ // RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_c_runner_utils | \ // RUN: FileCheck %s @@ -6,8 +6,8 @@ // Note: To run this test, your CPU must support AMX. 
func.func @tilezero(%arg0: memref<?x?xi32>, %i: index, %j: index) { - %1 = amx.tile_zero : vector<16x16xi32> - amx.tile_store %arg0[%i, %j], %1 : memref<?x?xi32>, vector<16x16xi32> + %1 = x86.amx.tile_zero : vector<16x16xi32> + x86.amx.tile_store %arg0[%i, %j], %1 : memref<?x?xi32>, vector<16x16xi32> return } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86/dot.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/dot.mlir index 5570da8fe04b..c375350de50d 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/X86/dot.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/dot.mlir @@ -5,14 +5,15 @@ func.func @entry() -> i32 { %i0 = arith.constant 0 : i32 - %i4 = arith.constant 4 : i32 + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index %a = arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]> : vector<8xf32> %b = arith.constant dense<[9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : vector<8xf32> %r = x86.avx.intr.dot %a, %b : vector<8xf32> - %1 = vector.extract %r[%i0] : f32 from vector<8xf32> - %2 = vector.extract %r[%i4] : f32 from vector<8xf32> + %1 = vector.extract %r[%c0] : f32 from vector<8xf32> + %2 = vector.extract %r[%c4] : f32 from vector<8xf32> %d = arith.addf %1, %2 : f32 // CHECK: ( 110, 110, 110, 110, 382, 382, 382, 382 ) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86/sparse-dot-product.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86/sparse-dot-product.mlir index 4f3f70a45a50..7b0f505a4778 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/X86/sparse-dot-product.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86/sparse-dot-product.mlir @@ -180,8 +180,7 @@ func.func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>, -> f64 { // Helper constants for loops. %c0 = arith.constant 0 : index - %i0 = arith.constant 0 : i32 - %i7 = arith.constant 7 : i32 + %c7 = arith.constant 7 : index %c8 = arith.constant 8 : index %data_zero = arith.constant 0.0 : f64 @@ -196,13 +195,13 @@ func.func @memref_dot_optimized(%m_A : memref<?xi64>, %m_B : memref<?xf64>, iter_args(%sum0 = %data_zero, %b_start0 = %c0) -> (f64, index) { %v_A = vector.transfer_read %m_A[%a], %index_padding : memref<?xi64>, vector<8xi64> - %segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64> + %segA_min = vector.extract %v_A[%c0] : i64 from vector<8xi64> %r1, %next_b_start0 = scf.for %b = %b_start0 to %N step %c8 iter_args(%sum1 = %sum0, %b_start1 = %b_start0) -> (f64, index) { %v_C = vector.transfer_read %m_C[%b], %index_padding : memref<?xi64>, vector<8xi64> - %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64> + %segB_max = vector.extract %v_C[%c7] : i64 from vector<8xi64> %seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64 %r2, %next_b_start1 = scf.if %seg1_done -> (f64, index) { @@ -251,8 +250,7 @@ func.func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>, -> f64 { // Helper constants for loops.
%c0 = arith.constant 0 : index - %i0 = arith.constant 0 : i32 - %i7 = arith.constant 7 : i32 + %c7 = arith.constant 7 : index %c8 = arith.constant 8 : index %data_zero = arith.constant 0.0 : f64 @@ -273,10 +271,10 @@ func.func @memref_dot_while(%m_A : memref<?xi64>, %m_B : memref<?xf64>, %v_C = vector.transfer_read %m_C[%b1], %index_padding : memref<?xi64>, vector<8xi64> - %segA_min = vector.extract %v_A[%i0] : i64 from vector<8xi64> - %segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64> - %segB_min = vector.extract %v_C[%i0] : i64 from vector<8xi64> - %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64> + %segA_min = vector.extract %v_A[%c0] : i64 from vector<8xi64> + %segA_max = vector.extract %v_A[%c7] : i64 from vector<8xi64> + %segB_min = vector.extract %v_C[%c0] : i64 from vector<8xi64> + %segB_max = vector.extract %v_C[%c7] : i64 from vector<8xi64> %seg1_done = arith.cmpi "slt", %segB_max, %segA_min : i64 %r2, %a2, %b2 = scf.if %seg1_done -> (f64, index, index) { @@ -340,7 +338,7 @@ func.func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64>, -> f64 { // Helper constants for loops. %c0 = arith.constant 0 : index - %i7 = arith.constant 7 : i32 + %c7 = arith.constant 7 : index %c8 = arith.constant 8 : index %data_zero = arith.constant 0.0 : f64 @@ -370,8 +368,8 @@ func.func @memref_dot_while_branchless(%m_A : memref<?xi64>, %m_B : memref<?xf64>, f64 %r2 = arith.addf %r1, %subresult : f64 - %segA_max = vector.extract %v_A[%i7] : i64 from vector<8xi64> - %segB_max = vector.extract %v_C[%i7] : i64 from vector<8xi64> + %segA_max = vector.extract %v_A[%c7] : i64 from vector<8xi64> + %segB_max = vector.extract %v_C[%c7] : i64 from vector<8xi64> %cond_a = arith.cmpi "sle", %segA_max, %segB_max : i64 %cond_a_i64 = arith.extui %cond_a : i1 to i64 diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir index 160a9ced46e2..4a4be24c2e3a 100644 --- a/mlir/test/Target/LLVMIR/amx.mlir +++ b/mlir/test/Target/LLVMIR/amx.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --convert-vector-to-llvm="enable-amx" --convert-to-llvm -reconcile-unrealized-casts \ +// RUN: mlir-opt %s --convert-vector-to-llvm="enable-x86" --convert-to-llvm -reconcile-unrealized-casts \ // RUN: | mlir-translate --mlir-to-llvmir \ // RUN: | FileCheck %s @@ -7,8 +7,8 @@ func.func @amx_tile_zero(%out: memref<?x?xf32>, %idx: index) { // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64) // CHECK: call void @llvm.x86.tilestored64.internal - %zero = amx.tile_zero : !amx.tile<16x16xf32> - amx.tile_store %out[%idx, %idx], %zero : memref<?x?xf32>, !amx.tile<16x16xf32> + %zero = x86.amx.tile_zero : !x86.amx.tile<16x16xf32> + x86.amx.tile_store %out[%idx, %idx], %zero : memref<?x?xf32>, !x86.amx.tile<16x16xf32> return } @@ -18,8 +18,8 @@ func.func @amx_tile_load_store(%base: memref<?x?xi8>, %out: memref<?x?xi8>, %idx: index) { // CHECK: call x86_amx @llvm.x86.tileloadd64.internal // CHECK: call void @llvm.x86.tilestored64.internal - %val = amx.tile_load %base[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8> - amx.tile_store %out[%idx, %idx], %val : memref<?x?xi8>, !amx.tile<16x64xi8> + %val = x86.amx.tile_load %base[%idx, %idx] : memref<?x?xi8> into !x86.amx.tile<16x64xi8> + x86.amx.tile_store %out[%idx, %idx], %val : memref<?x?xi8>, !x86.amx.tile<16x64xi8> return } @@ -29,10 +29,10 @@ func.func @amx_tile_load_store_strided(%base: memref<?xi8>, %out: memref<?xi8>, %idx: index, %stride: index) { // CHECK: call x86_amx @llvm.x86.tileloadd64.internal // CHECK: call void @llvm.x86.tilestored64.internal - %val = amx.tile_load %base[%idx], %stride - : memref<?xi8> into !amx.tile<16x64xi8> - amx.tile_store %out[%idx], %val, %stride - : memref<?xi8>,
!amx.tile<16x64xi8> + %val = x86.amx.tile_load %base[%idx], %stride + : memref<?xi8> into !x86.amx.tile<16x64xi8> + x86.amx.tile_store %out[%idx], %val, %stride + : memref<?xi8>, !x86.amx.tile<16x64xi8> return } @@ -42,15 +42,15 @@ func.func @amx_tile_mulf_bf16( %out: memref<?x?xf32>) { // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64) - %acc = amx.tile_zero : !amx.tile<16x16xf32> + %acc = x86.amx.tile_zero : !x86.amx.tile<16x16xf32> // CHECK-COUNT-2: call x86_amx @llvm.x86.tileloadd64.internal - %tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xbf16> into !amx.tile<16x32xbf16> - %tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xbf16> into !amx.tile<16x32xbf16> + %tA = x86.amx.tile_load %matA[%idx, %idx] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16> + %tB = x86.amx.tile_load %matB[%idx, %idx] : memref<?x?xbf16> into !x86.amx.tile<16x32xbf16> // CHECK: call x86_amx @llvm.x86.tdpbf16ps.internal - %tRes = amx.tile_mulf %tA, %tB, %acc - : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> + %tRes = x86.amx.tile_mulf %tA, %tB, %acc + : !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x32xbf16>, !x86.amx.tile<16x16xf32> // CHECK: call void @llvm.x86.tilestored64.internal - amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !amx.tile<16x16xf32> + x86.amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !x86.amx.tile<16x16xf32> return } @@ -60,15 +60,15 @@ func.func @amx_tile_mulf_f16( %out: memref<?x?xf32>) { // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64) - %acc = amx.tile_zero : !amx.tile<16x16xf32> + %acc = x86.amx.tile_zero : !x86.amx.tile<16x16xf32> // CHECK-COUNT-2: call x86_amx @llvm.x86.tileloadd64.internal - %tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xf16> into !amx.tile<16x32xf16> - %tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xf16> into !amx.tile<16x32xf16> + %tA = x86.amx.tile_load %matA[%idx, %idx] : memref<?x?xf16> into !x86.amx.tile<16x32xf16> + %tB = x86.amx.tile_load %matB[%idx, %idx] : memref<?x?xf16> into !x86.amx.tile<16x32xf16> // CHECK: call x86_amx @llvm.x86.tdpfp16ps.internal - %tRes = amx.tile_mulf %tA, %tB, %acc - : !amx.tile<16x32xf16>, !amx.tile<16x32xf16>, !amx.tile<16x16xf32> + %tRes = x86.amx.tile_mulf %tA, %tB, %acc + : !x86.amx.tile<16x32xf16>, !x86.amx.tile<16x32xf16>, !x86.amx.tile<16x16xf32> // CHECK: call void @llvm.x86.tilestored64.internal - amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !amx.tile<16x16xf32> + x86.amx.tile_store %out[%idx, %idx], %tRes : memref<?x?xf32>, !x86.amx.tile<16x16xf32> return } @@ -79,26 +79,26 @@ func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>, %c0 = arith.constant 0 : index %c16 = arith.constant 16 : index // CHECK-COUNT-3: call x86_amx @llvm.x86.tileloadd64.internal - %tA = amx.tile_load %matA[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8> - %tB = amx.tile_load %matB[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8> - %acc = amx.tile_load %matC[%idx, %idx] : memref<?x?xi32> into !amx.tile<16x16xi32> + %tA = x86.amx.tile_load %matA[%idx, %idx] : memref<?x?xi8> into !x86.amx.tile<16x64xi8> + %tB = x86.amx.tile_load %matB[%idx, %idx] : memref<?x?xi8> into !x86.amx.tile<16x64xi8> + %acc = x86.amx.tile_load %matC[%idx, %idx] : memref<?x?xi32> into !x86.amx.tile<16x16xi32> // CHECK: call x86_amx @llvm.x86.tdpbuud.internal // CHECK: call x86_amx @llvm.x86.tdpbssd.internal // CHECK: call x86_amx @llvm.x86.tdpbusd.internal // CHECK: call x86_amx @llvm.x86.tdpbsud.internal - %res = amx.tile_muli %tA zext, %tB zext, %acc - : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> - %res1 = amx.tile_muli %tA, %tB, %acc - : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> - %res2 = amx.tile_muli %tA zext,
%tB, %acc - : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> - %res3 = amx.tile_muli %tA, %tB zext, %acc - : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> + %res = x86.amx.tile_muli %tA zext, %tB zext, %acc + : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32> + %res1 = x86.amx.tile_muli %tA, %tB, %acc + : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32> + %res2 = x86.amx.tile_muli %tA zext, %tB, %acc + : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32> + %res3 = x86.amx.tile_muli %tA, %tB zext, %acc + : !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x64xi8>, !x86.amx.tile<16x16xi32> // CHECK-COUNT-4: call void @llvm.x86.tilestored64.internal - amx.tile_store %out[%c0, %c0], %res : memref<?x?xi32>, !amx.tile<16x16xi32> - amx.tile_store %out[%c0, %c16], %res1 : memref<?x?xi32>, !amx.tile<16x16xi32> - amx.tile_store %out[%c16, %c0], %res2 : memref<?x?xi32>, !amx.tile<16x16xi32> - amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi32>, !amx.tile<16x16xi32> + x86.amx.tile_store %out[%c0, %c0], %res : memref<?x?xi32>, !x86.amx.tile<16x16xi32> + x86.amx.tile_store %out[%c0, %c16], %res1 : memref<?x?xi32>, !x86.amx.tile<16x16xi32> + x86.amx.tile_store %out[%c16, %c0], %res2 : memref<?x?xi32>, !x86.amx.tile<16x16xi32> + x86.amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi32>, !x86.amx.tile<16x16xi32> return } @@ -108,16 +108,16 @@ func.func @amx_tile_type_through_cf(%src: memref<?x?xi8>, %out: memref<?x?xi8>, cf.cond_br %cond, ^bb1, ^bb2 ^bb1: // pred: ^bb0 // CHECK: call x86_amx @llvm.x86.tileloadd64.internal - %0 = amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8> - cf.br ^bb3(%0 : !amx.tile<16x64xi8>) + %0 = x86.amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !x86.amx.tile<16x64xi8> + cf.br ^bb3(%0 : !x86.amx.tile<16x64xi8>) ^bb2: // pred: ^bb0 // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64) - %1 = amx.tile_zero : !amx.tile<16x64xi8> - cf.br ^bb3(%1 : !amx.tile<16x64xi8>) -^bb3(%2: !amx.tile<16x64xi8>): // 2 preds: ^bb1, ^bb2 + %1 = x86.amx.tile_zero : !x86.amx.tile<16x64xi8> + cf.br ^bb3(%1 : !x86.amx.tile<16x64xi8>) +^bb3(%2: !x86.amx.tile<16x64xi8>): // 2 preds: ^bb1, ^bb2 cf.br ^bb4 ^bb4: // pred: ^bb3 // CHECK: call void @llvm.x86.tilestored64.internal - amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !amx.tile<16x64xi8> + x86.amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !x86.amx.tile<16x64xi8> return } diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in index 0f3a01487594..1bbe74ed3fa5 100644 --- a/mlir/test/lit.site.cfg.py.in +++ b/mlir/test/lit.site.cfg.py.in @@ -46,7 +46,7 @@ config.enable_vulkan_runner = @MLIR_ENABLE_VULKAN_RUNNER@ config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@ config.enable_python_stable_abi = @MLIR_ENABLE_PYTHON_STABLE_ABI@ config.intel_sde_executable = "@INTEL_SDE_EXECUTABLE@" -config.mlir_run_amx_tests = @MLIR_RUN_AMX_TESTS@ +config.mlir_run_x86_amx_tests = @MLIR_RUN_X86_AMX_TESTS@ config.mlir_run_arm_sve_tests = @MLIR_RUN_ARM_SVE_TESTS@ # This is a workaround for the fact that LIT's: # %if diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir index 9724eb42119a..fb344880598c 100644 --- a/mlir/test/mlir-opt/commandline.mlir +++ b/mlir/test/mlir-opt/commandline.mlir @@ -3,7 +3,6 @@ // CHECK-SAME: acc // CHECK-SAME: affine // CHECK-SAME: amdgpu -// CHECK-SAME: amx // CHECK-SAME: arith // CHECK-SAME: arm_neon // CHECK-SAME: arm_sme